diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index c936e1e3..94962026 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -54,6 +54,7 @@ jobs: - name: "Setup Micromamba" uses: mamba-org/setup-micromamba@v1 with: + micromamba-binary-path: ~/.local/bin/micromamba environment-file: environment.yml environment-name: gufe cache-environment: true diff --git a/.github/workflows/mypy.yaml b/.github/workflows/mypy.yaml index df4e99d5..e9c20335 100644 --- a/.github/workflows/mypy.yaml +++ b/.github/workflows/mypy.yaml @@ -29,7 +29,8 @@ jobs: cache-environment: true cache-downloads: true create-args: >- - python=3.9 + python=3.10 + rdkit=2023.09.5 init-shell: bash - name: "Install steps" diff --git a/docs/environment.yaml b/docs/environment.yaml index dba1eabc..1cdea9da 100644 --- a/docs/environment.yaml +++ b/docs/environment.yaml @@ -9,4 +9,4 @@ dependencies: - sphinx - openmm - pip: - - git+https://github.com/OpenFreeEnergy/ofe-sphinx-theme@main + - git+https://github.com/OpenFreeEnergy/ofe-sphinx-theme@a45f3edd5bc3e973c1a01b577c71efa1b62a65d6 diff --git a/gufe/tests/test_ligandatommapping.py b/gufe/tests/test_ligandatommapping.py index edd7b75b..17a564bd 100644 --- a/gufe/tests/test_ligandatommapping.py +++ b/gufe/tests/test_ligandatommapping.py @@ -275,7 +275,7 @@ def test_too_small_B(self, molA, molB): class TestLigandAtomMapping(GufeTokenizableTestsMixin): cls = LigandAtomMapping repr = "LigandAtomMapping(componentA=SmallMoleculeComponent(name=), componentB=SmallMoleculeComponent(name=), componentA_to_componentB={0: 0, 1: 1}, annotations={'foo': 'bar'})" - key = "LigandAtomMapping-c333723fbbee702c641cb9dca9beae49" + key = "LigandAtomMapping-2c0aae226e3f69d2d1cf429abaefdb5b" @pytest.fixture def instance(self, annotated_simple_mapping): @@ -289,4 +289,4 @@ def test_id_key(self, instance): def test_keyed_dict(self, instance): i2 = self.cls.from_dict(instance.to_dict()) - assert instance.to_keyed_dict() == i2.to_keyed_dict() \ No newline at end of file + assert instance.to_keyed_dict() == i2.to_keyed_dict() diff --git a/gufe/tests/test_smallmoleculecomponent.py b/gufe/tests/test_smallmoleculecomponent.py index 9341824e..17a1f32c 100644 --- a/gufe/tests/test_smallmoleculecomponent.py +++ b/gufe/tests/test_smallmoleculecomponent.py @@ -72,7 +72,7 @@ def test_ensure_ofe_name(internal, rdkit_name, name, expected, recwarn): class TestSmallMoleculeComponent(GufeTokenizableTestsMixin): cls = SmallMoleculeComponent - key = "SmallMoleculeComponent-51068a89f4793e688ee26135a9b7fbb6" + key = "SmallMoleculeComponent-82d90fcdcbe76a4155b0ea42b9080ff2" repr = "SmallMoleculeComponent(name=ethane)" @pytest.fixture diff --git a/gufe/tests/test_tokenization.py b/gufe/tests/test_tokenization.py index 99ed9d21..8ccf95bb 100644 --- a/gufe/tests/test_tokenization.py +++ b/gufe/tests/test_tokenization.py @@ -10,7 +10,8 @@ from gufe.tokenization import ( GufeTokenizable, GufeKey, tokenize, TOKENIZABLE_REGISTRY, import_qualname, get_class, TOKENIZABLE_CLASS_REGISTRY, JSON_HANDLER, - get_all_gufe_objs, + get_all_gufe_objs, gufe_to_digraph, gufe_objects_from_shallow_dict, + KeyedChain, ) @@ -380,6 +381,65 @@ def test_token(self): assert k.token == 'bar' +def test_gufe_to_digraph(solvated_complex): + graph = gufe_to_digraph(solvated_complex) + + connected_objects = gufe_objects_from_shallow_dict( + solvated_complex.to_shallow_dict() + ) + + assert len(graph.nodes) == 4 + assert len(graph.edges) == 3 + + for node_a, node_b in graph.edges: + assert node_b in connected_objects + assert node_a is solvated_complex + + +def test_gufe_objects_from_shallow_dict(solvated_complex): + shallow_dict = solvated_complex.to_shallow_dict() + gufe_objects = set(gufe_objects_from_shallow_dict(shallow_dict)) + + assert len(gufe_objects) == 3 + assert set(gufe_objects) == set(solvated_complex.components.values()) + + +class TestKeyedChain: + + def test_from_gufe(self, benzene_variants_star_map): + contained_objects = list(get_all_gufe_objs(benzene_variants_star_map)) + expected_len = len(contained_objects) + + kc = KeyedChain.from_gufe(benzene_variants_star_map) + + assert len(kc) == expected_len + + original_keys = [obj.key for obj in contained_objects] + original_keyed_dicts = [ + obj.to_keyed_dict() for obj in contained_objects + ] + + kc_gufe_keys = set(kc.gufe_keys()) + kc_keyed_dicts = list(kc.keyed_dicts()) + + assert kc_gufe_keys == set(original_keys) + + for key, keyed_dict in zip(original_keys, original_keyed_dicts): + assert key in kc_gufe_keys + assert keyed_dict in kc_keyed_dicts + + def test_to_gufe(self, benzene_variants_star_map): + kc = KeyedChain.from_gufe(benzene_variants_star_map) + assert hash(kc.to_gufe()) == hash(benzene_variants_star_map) + + def test_get_item(self, benzene_variants_star_map): + kc = KeyedChain.from_gufe(benzene_variants_star_map) + + assert kc[0] == kc._keyed_chain[0] + assert kc[-1] == kc._keyed_chain[-1] + assert kc[:] == kc._keyed_chain[:] + + def test_datetime_to_json(): d = datetime.datetime.fromisoformat('2023-05-05T09:06:43.699068') diff --git a/gufe/tokenization.py b/gufe/tokenization.py index a4118ac3..06ed7f4c 100644 --- a/gufe/tokenization.py +++ b/gufe/tokenization.py @@ -9,10 +9,13 @@ import inspect import json import logging +import networkx as nx import re -import weakref import warnings -from typing import Any, Union +import weakref +from itertools import chain +from typing import Any, Union, List, Tuple, Dict, Generator +from typing_extensions import Self from gufe.custom_codecs import ( BYTES_CODEC, @@ -638,6 +641,195 @@ def token(self) -> str: return self.split('-')[1] +def gufe_objects_from_shallow_dict( + obj: Union[List, Dict, GufeTokenizable] +) -> List[GufeTokenizable]: + """Find GufeTokenizables within a shallow dict. + + This function recursively looks through the list/dict structures encoding + GufeTokenizables and returns list of all GufeTokenizables found + within those structures, which may be potentially nested. + + Parameters + ---------- + obj + The input data structure to recursively traverse. For the initial call + of this function, this should be the shallow dict of a GufeTokenizable. + Input of a GufeTokenizable will immediately return a base case. + + Returns + ------- + List[GufeTokenizable] + All GufeTokenizables found in the shallow dict representation of a + GufeTokenizable. + + """ + if isinstance(obj, GufeTokenizable): + return [obj] + + elif isinstance(obj, list): + return list( + chain.from_iterable([gufe_objects_from_shallow_dict(item) for item in obj]) + ) + + elif isinstance(obj, dict): + return list( + chain.from_iterable( + [gufe_objects_from_shallow_dict(item) for item in obj.values()] + ) + ) + + return [] + + +def gufe_to_digraph(gufe_obj): + """Recursively construct a DiGraph from a GufeTokenizable. + + The DiGraph encodes the dependency structure of the GufeTokenizable on + other GufeTokenizables. + """ + graph = nx.DiGraph() + shallow_dicts = {} + + def add_edges(o): + # if we've made a shallow dict before, we've already added this one + # and all its dependencies; return `None` to avoid going down the tree + # again + sd = shallow_dicts.get(o.key) + if sd is not None: + return None + + # if not, then we make the shallow dict only once, add it to our index, + # add edges to dependencies, and return it so we continue down the tree + sd = o.to_shallow_dict() + + shallow_dicts[o.key] = sd + + # add the object node in case there aren't any connections + graph.add_node(o) + connections = gufe_objects_from_shallow_dict(sd) + + for c in connections: + graph.add_edge(o, c) + + return sd + + sd = add_edges(gufe_obj) + _ = modify_dependencies(sd, add_edges, is_gufe_obj, mode="encode") + + return graph + + +class KeyedChain(object): + """Keyed chain representation encoder of a GufeTokenizable. + + The keyed chain representation of a GufeTokenizable provides a + topologically sorted list of gufe keys and GufeTokenizable keyed dicts + that can be used to fully recreate a GufeTokenizable without the need for a + populated TOKENIZATION_REGISTRY. + + The class wraps around a list of tuples containing the gufe key and the + keyed dict form of the GufeTokenizable. + + Examples + -------- + We can create a keyed chain representation from any GufeTokenizable, such + as: + + >>> from gufe.tokenization import KeyedChain + >>> s = SolventComponent() + >>> keyed_chain = KeyedChain.gufe_to_keyed_chain_rep(s) + >>> keyed_chain + [('SolventComponent-26b4034ad9dbd9f908dfc298ea8d449f', + {'smiles': 'O', + 'positive_ion': 'Na+', + 'negative_ion': 'Cl-', + 'ion_concentration': '0.15 molar', + 'neutralize': True, + '__qualname__': 'SolventComponent', + '__module__': 'gufe.components.solventcomponent', + ':version:': 1})] + + And we can do the reverse operation as well to go from a keyed chain + representation back to a GufeTokenizable: + + >>> KeyedChain(keyed_chain).to_gufe() + SolventComponent(name=O, Na+, Cl-) + + """ + + def __init__(self, keyed_chain): + self._keyed_chain = keyed_chain + + @classmethod + def from_gufe(cls, gufe_object: GufeTokenizable) -> Self: + """Initialize a KeyedChain from a GufeTokenizable.""" + return cls(cls.gufe_to_keyed_chain_rep(gufe_object)) + + def to_gufe(self) -> GufeTokenizable: + """Initialize a GufeTokenizable.""" + gts: Dict[str, GufeTokenizable] = {} + for gufe_key, keyed_dict in self: + gt = key_decode_dependencies(keyed_dict, registry=gts) + gts[gufe_key] = gt + return gt + + @classmethod + def from_keyed_chain_rep(cls, keyed_chain: List[Tuple[str, Dict]]) -> Self: + """Initialize a KeyedChain from a keyed chain representation.""" + return cls(keyed_chain) + + def to_keyed_chain_rep(self) -> List[Tuple[str, Dict]]: + """Return the keyed chain representation of this object.""" + return list(self) + + @staticmethod + def gufe_to_keyed_chain_rep( + gufe_object: GufeTokenizable, + ) -> List[Tuple[str, Dict]]: + """Create the keyed chain representation of a GufeTokenizable. + + This represents the GufeTokenizable as a list of two-element tuples + containing, as their first and second elements, the gufe key and keyed + dict form of the GufeTokenizable, respectively, and provides the + underlying structure used in the KeyedChain class. + + Parameters + ---------- + gufe_object + The GufeTokenizable for which the KeyedChain is generated. + + Returns + ------- + key_and_keyed_dicts + The keyed chain representation of a GufeTokenizable. + + """ + key_and_keyed_dicts = [ + (str(gt.key), gt.to_keyed_dict()) + for gt in nx.topological_sort(gufe_to_digraph(gufe_object)) + ][::-1] + return key_and_keyed_dicts + + def gufe_keys(self) -> Generator[str, None, None]: + """Create a generator that iterates over the gufe keys in the KeyedChain.""" + for key, _ in self: + yield key + + def keyed_dicts(self) -> Generator[Dict, None, None]: + """Create a generator that iterates over the keyed dicts in the KeyedChain.""" + for _, _dict in self: + yield _dict + + def __len__(self): + return len(self._keyed_chain) + + def __iter__(self): + return self._keyed_chain.__iter__() + + def __getitem__(self, index): + return self._keyed_chain[index] + # TOKENIZABLE_REGISTRY: Dict[str, weakref.ref[GufeTokenizable]] = {} TOKENIZABLE_REGISTRY: weakref.WeakValueDictionary[str, GufeTokenizable] = weakref.WeakValueDictionary()