Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: add cosmology to CBC result #867

Merged
merged 14 commits into from
Jan 23, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
157 changes: 142 additions & 15 deletions bilby/core/utils/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import pandas as pd

from .log import logger
from .introspection import infer_args_from_method


def check_directory_exists_and_if_not_mkdir(directory):
Expand Down Expand Up @@ -53,8 +52,8 @@ def default(self, obj):
return encode_astropy_cosmology(obj)
if isinstance(obj, units.Quantity):
return encode_astropy_quantity(obj)
if isinstance(obj, units.PrefixUnit):
return str(obj)
if isinstance(obj, (units.PrefixUnit, units.UnitBase, units.FunctionUnitBase)):
return encode_astropy_unit(obj)
except ImportError:
logger.debug("Cannot import astropy, cannot write cosmological priors")
if isinstance(obj, np.ndarray):
Expand Down Expand Up @@ -87,46 +86,133 @@ def default(self, obj):


def encode_astropy_cosmology(obj):
cls_name = obj.__class__.__name__
dct = {key: getattr(obj, key) for key in infer_args_from_method(obj.__init__)}
dct["__cosmology__"] = True
dct["__name__"] = cls_name
return dct
"""Encode an astropy cosmology object to a dictionary.

Adds the key :code:`__cosmology__` to the dictionary to indicate that the
object is a cosmology object.

.. versionchange:: 2.5.0
Now uses the :code:`to_format("mapping")` method to encode the
cosmology object.
"""
return {"__cosmology__": True, **obj.to_format("mapping")}


def encode_astropy_quantity(dct):
dct = dict(__astropy_quantity__=True, value=dct.value, unit=str(dct.unit))
"""Encode an astropy quantity object to a dictionary.

Adds the key :code:`__astropy_quantity__` to the dictionary to indicate that
the object is a quantity object.
"""
dct = dict(__astropy_quantity__=True, value=dct.value, unit=dct.unit.to_string())
if isinstance(dct["value"], np.ndarray):
dct["value"] = list(dct["value"])
dct["value"] = dct["value"].tolist()
return dct


def encode_astropy_unit(obj):
"""Encode an astropy unit object to a dictionary.

Adds the key :code:`__astropy_unit__` to the dictionary to indicate that the
object is a unit object.

.. versionadded:: 2.5.0
"""
try:
from astropy import units

# Based on the JsonCustomEncoder in astropy.units.misc
if obj == units.dimensionless_unscaled:
return dict(__astropy_unit__=True, unit="dimensionless_unit")
return dict(__astropy_unit__=True, unit=obj.to_string())

except ImportError:
logger.debug(
"Cannot import astropy, cosmological priors may not be properly loaded."
)
return obj


def decode_astropy_cosmology(dct):
"""Decode an astropy cosmology from a dictionary.

The dictionary should have been encoded using
:py:func:`~bibly.core.utils.io.encode_astropy_cosmology` and should have the
key :code:`__cosmology__`.

.. versionchange:: 2.5.0
Now uses the :code:`from_format` method to decode the cosmology object.
Still supports decoding result files that used the previous encoding.
"""
try:
from astropy import cosmology as cosmo

cosmo_cls = getattr(cosmo, dct["__name__"])
del dct["__cosmology__"], dct["__name__"]
return cosmo_cls(**dct)
del dct["__cosmology__"]
return cosmo.Cosmology.from_format(dct, format="mapping")
except ImportError:
logger.debug(
"Cannot import astropy, cosmological priors may not be " "properly loaded."
"Cannot import astropy, cosmological priors may not be properly loaded."
)
return dct
except KeyError:
# Support decoding result files that used the previous encoding
logger.warning(
"Failed to decode cosmology, falling back to legacy decoding. "
"Support for legacy decoding will be removed in a future release."
)
cosmo_cls = getattr(cosmo, dct["__name__"])
del dct["__name__"]
return cosmo_cls(**dct)


def decode_astropy_quantity(dct):
"""Decode an astropy quantity from a dictionary.

The dictionary should have been encoded using
:py:func:`~bilby.core.utils.io.encode_astropy_quantity` and should have the
key :code:`__astropy_quantity__`.
"""
try:
from astropy import units
from astropy.cosmology import units as cosmo_units

# Enable cosmology units such as redshift
units.add_enabled_units(cosmo_units)
if dct["value"] is None:
return None
else:
del dct["__astropy_quantity__"]
return units.Quantity(**dct)
except ImportError:
logger.debug(
"Cannot import astropy, cosmological priors may not be " "properly loaded."
"Cannot import astropy, cosmological priors may not be properly loaded."
)
return dct


def decode_astropy_unit(dct):
"""Decode an astropy unit from a dictionary.

The dictionary should have been encoded using
:py:func:`~bilby.core.utils.io.encode_astropy_unit` and should have the
key :code:`__astropy_unit__`.

.. versionadded:: 2.5.0
"""
try:
from astropy import units
from astropy.cosmology import units as cosmo_units

# Enable cosmology units such as redshift
units.add_enabled_units(cosmo_units)
if dct["unit"] == "dimensionless_unit":
return units.dimensionless_unscaled
else:
del dct["__astropy_unit__"]
return units.Unit(dct["unit"])
except ImportError:
logger.debug(
"Cannot import astropy, cosmological priors may not be properly loaded."
)
return dct

Expand Down Expand Up @@ -170,6 +256,8 @@ def decode_bilby_json(dct):
return decode_astropy_cosmology(dct)
if dct.get("__astropy_quantity__", False):
return decode_astropy_quantity(dct)
if dct.get("__astropy_unit__", False):
return decode_astropy_unit(dct)
if dct.get("__array__", False):
return np.asarray(dct["content"])
if dct.get("__complex__", False):
Expand Down Expand Up @@ -264,6 +352,13 @@ def encode_for_hdf5(key, item):
"""
from ..prior.dict import PriorDict

try:
from astropy import cosmology as cosmo, units
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not wild about having this import here, I think it could be avoided, but since it is only a debug message I'm happy to let it slide for a while.

unrelated: We could consider changing this function to use multiple dispatch at some point down the road, that would avoid ugly import tests and also allow downstream users to add their own encoding/decoding.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I agree. At the time, I couldn't come up with a clean alternative.

On multiple dispatch: that's an interesting option, might be something to bring up on a dev call.

except ImportError:
logger.debug("Cannot import astropy, cannot write cosmological priors")
cosmo = None
units = None

if isinstance(item, np.int_):
item = int(item)
elif isinstance(item, np.float64):
Expand Down Expand Up @@ -302,6 +397,12 @@ def encode_for_hdf5(key, item):
output = json.dumps(item._get_json_dict())
elif isinstance(item, pd.DataFrame):
output = item.to_dict(orient="list")
elif cosmo is not None and isinstance(item, cosmo.FLRW):
output = encode_astropy_cosmology(item)
elif units is not None and isinstance(item, units.Quantity):
output = encode_astropy_quantity(item)
elif units is not None and isinstance(item, (units.PrefixUnit, units.UnitBase, units.FunctionUnitBase)):
output = encode_astropy_unit(item)
elif inspect.isfunction(item) or inspect.isclass(item):
output = dict(
__module__=item.__module__, __name__=item.__name__, __class__=True
Expand All @@ -317,12 +418,34 @@ def encode_for_hdf5(key, item):
return output


def decode_hdf5_dict(output):
"""Decode a dictionary constructed from a HDF5 file.

This handles decoding of Bilby types and astropy types from the dictionary.

.. versionadded:: 2.5.0
"""
if ("__function__" in output) or ("__class__" in output):
default = ".".join([output["__module__"], output["__name__"]])
output = getattr(import_module(output["__module__"]), output["__name__"], default)
elif "__cosmology__" in output:
output = decode_astropy_cosmology(output)
elif "__astropy_quantity__" in output:
output = decode_astropy_quantity(output)
elif "__astropy_unit__" in output:
output = decode_astropy_unit(output)
return output


def recursively_load_dict_contents_from_group(h5file, path):
"""
Recursively load a HDF5 file into a dictionary

.. versionadded:: 1.1.0

.. versionchanged: 2.5.0
Now decodes astropy and bilby types

Parameters
----------
h5file: h5py.File
Expand All @@ -345,6 +468,10 @@ def recursively_load_dict_contents_from_group(h5file, path):
output[key] = recursively_load_dict_contents_from_group(
h5file, path + key + "/"
)
# Some items may be encoded as dictionaries, so we need to decode them
# after the dictionary has been constructed.
# This includes decoding astropy and bilby types
output = decode_hdf5_dict(output)
return output


Expand Down
20 changes: 20 additions & 0 deletions bilby/gw/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,16 @@ class CompactBinaryCoalescenceResult(CoreResult):
of compact binaries.
"""
def __init__(self, **kwargs):

if "meta_data" not in kwargs:
kwargs["meta_data"] = dict()
if "global_meta_data" not in kwargs:
kwargs["meta_data"]["global_meta_data"] = dict()
# Ensure cosmology is always stored in the meta_data
if "cosmology" not in kwargs["meta_data"]["global_meta_data"]:
from .cosmology import get_cosmology
kwargs["meta_data"]["global_meta_data"]["cosmology"] = get_cosmology()

super(CompactBinaryCoalescenceResult, self).__init__(**kwargs)

def __get_from_nested_meta_data(self, *keys):
Expand Down Expand Up @@ -117,6 +127,16 @@ def parameter_conversion(self):
return self.__get_from_nested_meta_data(
'likelihood', 'parameter_conversion')

@property
def cosmology(self):
"""The global cosmology used in the analysis.

.. versionadded:: 2.5.0
"""
return self.__get_from_nested_meta_data(
'global_meta_data', 'cosmology'
)

def detector_injection_properties(self, detector):
""" Returns a dictionary of the injection properties for each detector

Expand Down
35 changes: 28 additions & 7 deletions test/core/result_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@
import shutil
import os
import json
import pytest
from unittest.mock import patch

import bilby
from bilby.core.result import ResultError


class TestJson(unittest.TestCase):

def setUp(self):
self.encoder = bilby.core.utils.BilbyJsonEncoder
self.decoder = bilby.core.utils.decode_bilby_json
Expand Down Expand Up @@ -50,6 +52,12 @@ def test_dataframe_encoding(self):


class TestResult(unittest.TestCase):

@pytest.fixture(autouse=True)
def init_outdir(self, tmp_path):
# Use pytest's tmp_path fixture to create a temporary directory
self.outdir = str(tmp_path / "test")

def setUp(self):
np.random.seed(7)
bilby.utils.command_line_args.bilby_test_mode = False
Expand All @@ -61,7 +69,6 @@ def setUp(self):
d=2,
)
)
self.outdir = "test_outdir"
result = bilby.core.result.Result(
label="label",
outdir=self.outdir,
Expand Down Expand Up @@ -543,18 +550,32 @@ class NotAResult(bilby.core.result.Result):
pass

result = bilby.run_sampler(
likelihood, priors, sampler='bilby_mcmc', nsamples=10, L1steps=1,
proposal_cycle="default_noGMnoKD", printdt=1,
likelihood,
priors,
sampler='bilby_mcmc',
outdir=self.outdir,
nsamples=10,
L1steps=1,
proposal_cycle="default_noGMnoKD",
printdt=1,
check_point_plot=False,
result_class=NotAResult)
result_class=NotAResult
)
# result should be specified result_class
assert isinstance(result, NotAResult)

cached_result = bilby.run_sampler(
likelihood, priors, sampler='bilby_mcmc', nsamples=10, L1steps=1,
proposal_cycle="default_noGMnoKD", printdt=1,
likelihood,
priors,
sampler='bilby_mcmc',
outdir=self.outdir,
nsamples=10,
L1steps=1,
proposal_cycle="default_noGMnoKD",
printdt=1,
check_point_plot=False,
result_class=NotAResult)
result_class=NotAResult
)

# so should a result loaded from cache
assert isinstance(cached_result, NotAResult)
Expand Down
Loading
Loading