Skip to content

Commit

Permalink
Merge branch 'main' into off-models-tests
Browse files Browse the repository at this point in the history
  • Loading branch information
IAlibay authored Oct 24, 2024
2 parents cfeede7 + e57f300 commit 15e0c9e
Show file tree
Hide file tree
Showing 7 changed files with 262 additions and 8 deletions.
1 change: 1 addition & 0 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ jobs:
- name: "Setup Micromamba"
uses: mamba-org/setup-micromamba@v1
with:
micromamba-binary-path: ~/.local/bin/micromamba
environment-file: environment.yml
environment-name: gufe
cache-environment: true
Expand Down
3 changes: 2 additions & 1 deletion .github/workflows/mypy.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ jobs:
cache-environment: true
cache-downloads: true
create-args: >-
python=3.9
python=3.10
rdkit=2023.09.5
init-shell: bash

- name: "Install steps"
Expand Down
2 changes: 1 addition & 1 deletion docs/environment.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@ dependencies:
- sphinx
- openmm
- pip:
- git+https://github.com/OpenFreeEnergy/ofe-sphinx-theme@main
- git+https://github.com/OpenFreeEnergy/ofe-sphinx-theme@a45f3edd5bc3e973c1a01b577c71efa1b62a65d6
4 changes: 2 additions & 2 deletions gufe/tests/test_ligandatommapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ def test_too_small_B(self, molA, molB):
class TestLigandAtomMapping(GufeTokenizableTestsMixin):
cls = LigandAtomMapping
repr = "LigandAtomMapping(componentA=SmallMoleculeComponent(name=), componentB=SmallMoleculeComponent(name=), componentA_to_componentB={0: 0, 1: 1}, annotations={'foo': 'bar'})"
key = "LigandAtomMapping-c333723fbbee702c641cb9dca9beae49"
key = "LigandAtomMapping-2c0aae226e3f69d2d1cf429abaefdb5b"

@pytest.fixture
def instance(self, annotated_simple_mapping):
Expand All @@ -289,4 +289,4 @@ def test_id_key(self, instance):
def test_keyed_dict(self, instance):
i2 = self.cls.from_dict(instance.to_dict())

assert instance.to_keyed_dict() == i2.to_keyed_dict()
assert instance.to_keyed_dict() == i2.to_keyed_dict()
2 changes: 1 addition & 1 deletion gufe/tests/test_smallmoleculecomponent.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def test_ensure_ofe_name(internal, rdkit_name, name, expected, recwarn):
class TestSmallMoleculeComponent(GufeTokenizableTestsMixin):

cls = SmallMoleculeComponent
key = "SmallMoleculeComponent-51068a89f4793e688ee26135a9b7fbb6"
key = "SmallMoleculeComponent-82d90fcdcbe76a4155b0ea42b9080ff2"
repr = "SmallMoleculeComponent(name=ethane)"

@pytest.fixture
Expand Down
62 changes: 61 additions & 1 deletion gufe/tests/test_tokenization.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
from gufe.tokenization import (
GufeTokenizable, GufeKey, tokenize, TOKENIZABLE_REGISTRY,
import_qualname, get_class, TOKENIZABLE_CLASS_REGISTRY, JSON_HANDLER,
get_all_gufe_objs,
get_all_gufe_objs, gufe_to_digraph, gufe_objects_from_shallow_dict,
KeyedChain,
)


Expand Down Expand Up @@ -380,6 +381,65 @@ def test_token(self):
assert k.token == 'bar'


def test_gufe_to_digraph(solvated_complex):
    """A solvated complex produces a DiGraph rooted at the complex itself."""
    dag = gufe_to_digraph(solvated_complex)

    shallow = solvated_complex.to_shallow_dict()
    dependencies = gufe_objects_from_shallow_dict(shallow)

    # the complex plus its three components
    assert len(dag.nodes) == 4
    assert len(dag.edges) == 3

    # every edge points from the complex to one of its direct dependencies
    for source, target in dag.edges:
        assert source is solvated_complex
        assert target in dependencies


def test_gufe_objects_from_shallow_dict(solvated_complex):
    """The shallow dict of a solvated complex yields exactly its components."""
    shallow_dict = solvated_complex.to_shallow_dict()
    gufe_objects = set(gufe_objects_from_shallow_dict(shallow_dict))

    assert len(gufe_objects) == 3
    # `gufe_objects` is already a set; compare directly instead of
    # re-wrapping it in set() a second time
    assert gufe_objects == set(solvated_complex.components.values())


class TestKeyedChain:
    """Round-trip and container behaviour of the KeyedChain encoder."""

    def test_from_gufe(self, benzene_variants_star_map):
        all_objects = list(get_all_gufe_objs(benzene_variants_star_map))

        kc = KeyedChain.from_gufe(benzene_variants_star_map)

        # one (key, keyed_dict) entry per contained GufeTokenizable
        assert len(kc) == len(all_objects)

        keys = [obj.key for obj in all_objects]
        keyed_dicts = [obj.to_keyed_dict() for obj in all_objects]

        chain_keys = set(kc.gufe_keys())
        chain_dicts = list(kc.keyed_dicts())

        assert chain_keys == set(keys)

        for key, keyed_dict in zip(keys, keyed_dicts):
            assert key in chain_keys
            assert keyed_dict in chain_dicts

    def test_to_gufe(self, benzene_variants_star_map):
        rebuilt = KeyedChain.from_gufe(benzene_variants_star_map).to_gufe()
        assert hash(rebuilt) == hash(benzene_variants_star_map)

    def test_get_item(self, benzene_variants_star_map):
        kc = KeyedChain.from_gufe(benzene_variants_star_map)

        # indexing and slicing delegate to the wrapped list
        assert kc[0] == kc._keyed_chain[0]
        assert kc[-1] == kc._keyed_chain[-1]
        assert kc[:] == kc._keyed_chain[:]


def test_datetime_to_json():
d = datetime.datetime.fromisoformat('2023-05-05T09:06:43.699068')

Expand Down
196 changes: 194 additions & 2 deletions gufe/tokenization.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,13 @@
import inspect
import json
import logging
import networkx as nx
import re
import weakref
import warnings
from typing import Any, Union
import weakref
from itertools import chain
from typing import Any, Union, List, Tuple, Dict, Generator
from typing_extensions import Self

from gufe.custom_codecs import (
BYTES_CODEC,
Expand Down Expand Up @@ -638,6 +641,195 @@ def token(self) -> str:
return self.split('-')[1]


def gufe_objects_from_shallow_dict(
    obj: Union[List, Dict, GufeTokenizable]
) -> List[GufeTokenizable]:
    """Find GufeTokenizables within a shallow dict.

    This function recursively looks through the list/dict structures encoding
    GufeTokenizables and returns a list of all GufeTokenizables found
    within those structures, which may be potentially nested.

    Parameters
    ----------
    obj
        The input data structure to recursively traverse. For the initial call
        of this function, this should be the shallow dict of a GufeTokenizable.
        Input of a GufeTokenizable will immediately return a base case.

    Returns
    -------
    List[GufeTokenizable]
        All GufeTokenizables found in the shallow dict representation of a
        GufeTokenizable.
    """
    # base case: the object itself is a GufeTokenizable
    if isinstance(obj, GufeTokenizable):
        return [obj]

    # recurse into list items and dict values; dict keys are plain strings
    # in a shallow dict and are not traversed
    if isinstance(obj, list):
        values = obj
    elif isinstance(obj, dict):
        values = obj.values()
    else:
        # any other leaf (str, int, None, ...) contains no GufeTokenizables
        return []

    # generator avoids materializing an intermediate list per level
    return list(
        chain.from_iterable(gufe_objects_from_shallow_dict(v) for v in values)
    )


def gufe_to_digraph(gufe_obj):
    """Recursively construct a DiGraph from a GufeTokenizable.

    The DiGraph encodes the dependency structure of the GufeTokenizable on
    other GufeTokenizables.

    Parameters
    ----------
    gufe_obj : GufeTokenizable
        The object whose dependency graph is built.

    Returns
    -------
    nx.DiGraph
        Graph with a node per GufeTokenizable reached from ``gufe_obj`` and
        an edge from each object to every GufeTokenizable found in its
        shallow dict.
    """
    graph = nx.DiGraph()
    # index of shallow dicts already produced, keyed by gufe key; doubles as
    # a "visited" set so shared dependencies are only walked once
    shallow_dicts = {}

    def add_edges(o):
        # if we've made a shallow dict before, we've already added this one
        # and all its dependencies; return `None` to avoid going down the tree
        # again
        sd = shallow_dicts.get(o.key)
        if sd is not None:
            return None

        # if not, then we make the shallow dict only once, add it to our index,
        # add edges to dependencies, and return it so we continue down the tree
        sd = o.to_shallow_dict()

        shallow_dicts[o.key] = sd

        # add the object node in case there aren't any connections
        graph.add_node(o)
        connections = gufe_objects_from_shallow_dict(sd)

        for c in connections:
            graph.add_edge(o, c)

        return sd

    sd = add_edges(gufe_obj)
    # NOTE(review): presumably modify_dependencies applies `add_edges` to each
    # nested GufeTokenizable matched by `is_gufe_obj` while encoding, driving
    # the recursive traversal below the root — confirm against its definition
    _ = modify_dependencies(sd, add_edges, is_gufe_obj, mode="encode")

    return graph


class KeyedChain:
    """Keyed chain representation encoder of a GufeTokenizable.

    The keyed chain representation of a GufeTokenizable provides a
    topologically sorted list of gufe keys and GufeTokenizable keyed dicts
    that can be used to fully recreate a GufeTokenizable without the need for
    a populated TOKENIZABLE_REGISTRY.

    The class wraps around a list of tuples containing the gufe key and the
    keyed dict form of the GufeTokenizable.

    Examples
    --------
    We can create a keyed chain representation from any GufeTokenizable, such
    as:

    >>> from gufe.tokenization import KeyedChain
    >>> s = SolventComponent()
    >>> keyed_chain = KeyedChain.gufe_to_keyed_chain_rep(s)
    >>> keyed_chain
    [('SolventComponent-26b4034ad9dbd9f908dfc298ea8d449f',
      {'smiles': 'O',
       'positive_ion': 'Na+',
       'negative_ion': 'Cl-',
       'ion_concentration': '0.15 molar',
       'neutralize': True,
       '__qualname__': 'SolventComponent',
       '__module__': 'gufe.components.solventcomponent',
       ':version:': 1})]

    And we can do the reverse operation as well to go from a keyed chain
    representation back to a GufeTokenizable:

    >>> KeyedChain(keyed_chain).to_gufe()
    SolventComponent(name=O, Na+, Cl-)

    """

    def __init__(self, keyed_chain):
        # list of (gufe_key, keyed_dict) tuples, dependencies first
        self._keyed_chain = keyed_chain

    @classmethod
    def from_gufe(cls, gufe_object: "GufeTokenizable") -> "Self":
        """Initialize a KeyedChain from a GufeTokenizable."""
        return cls(cls.gufe_to_keyed_chain_rep(gufe_object))

    def to_gufe(self) -> "GufeTokenizable":
        """Initialize a GufeTokenizable.

        Entries are decoded in order, so each object's dependencies are
        already registered in ``gts`` by the time it is reached; the final
        (root) object is returned.
        """
        gts: Dict[str, GufeTokenizable] = {}
        for gufe_key, keyed_dict in self:
            gt = key_decode_dependencies(keyed_dict, registry=gts)
            gts[gufe_key] = gt
        # NOTE: an empty keyed chain would leave `gt` unbound and raise here
        return gt

    @classmethod
    def from_keyed_chain_rep(cls, keyed_chain: List[Tuple[str, Dict]]) -> "Self":
        """Initialize a KeyedChain from a keyed chain representation."""
        return cls(keyed_chain)

    def to_keyed_chain_rep(self) -> List[Tuple[str, Dict]]:
        """Return the keyed chain representation of this object."""
        return list(self)

    @staticmethod
    def gufe_to_keyed_chain_rep(
        gufe_object: "GufeTokenizable",
    ) -> List[Tuple[str, Dict]]:
        """Create the keyed chain representation of a GufeTokenizable.

        This represents the GufeTokenizable as a list of two-element tuples
        containing, as their first and second elements, the gufe key and keyed
        dict form of the GufeTokenizable, respectively, and provides the
        underlying structure used in the KeyedChain class.

        Parameters
        ----------
        gufe_object
            The GufeTokenizable for which the KeyedChain is generated.

        Returns
        -------
        key_and_keyed_dicts
            The keyed chain representation of a GufeTokenizable.
        """
        # reverse the topological order so dependencies precede the objects
        # that reference them
        key_and_keyed_dicts = [
            (str(gt.key), gt.to_keyed_dict())
            for gt in nx.topological_sort(gufe_to_digraph(gufe_object))
        ][::-1]
        return key_and_keyed_dicts

    def gufe_keys(self) -> Generator[str, None, None]:
        """Create a generator that iterates over the gufe keys in the KeyedChain."""
        for key, _ in self:
            yield key

    def keyed_dicts(self) -> Generator[Dict, None, None]:
        """Create a generator that iterates over the keyed dicts in the KeyedChain."""
        for _, keyed_dict in self:
            yield keyed_dict

    def __len__(self):
        return len(self._keyed_chain)

    def __iter__(self):
        return iter(self._keyed_chain)

    def __getitem__(self, index):
        return self._keyed_chain[index]


# TOKENIZABLE_REGISTRY: Dict[str, weakref.ref[GufeTokenizable]] = {}
TOKENIZABLE_REGISTRY: weakref.WeakValueDictionary[str, GufeTokenizable] = weakref.WeakValueDictionary()
Expand Down

0 comments on commit 15e0c9e

Please sign in to comment.