Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sampler #17

Merged
merged 33 commits into from
Sep 18, 2024
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
16af54a
init draft for sampler
fgrunewald Jul 9, 2024
77a58ed
make deepcopy when merging
fgrunewald Jul 9, 2024
e937860
update docstrings
fgrunewald Jul 10, 2024
aebf84c
address comments and refactor slightly
fgrunewald Jul 12, 2024
6220711
add seeds
fgrunewald Jul 12, 2024
c5ba1e3
refactor random seed
fgrunewald Jul 23, 2024
6842eb4
address style comments
fgrunewald Jul 23, 2024
f4f148a
Merge branch 'master' into sampler
fgrunewald Jul 23, 2024
6a4951d
adjust open bonds function to optionally take target nodes
fgrunewald Jul 24, 2024
eb62ea9
when annontating atomnames make meta_graph optional; otherwise go off…
fgrunewald Jul 24, 2024
705691b
the fragid for cgsmiles fragments should be 0 in agreement with the c…
fgrunewald Jul 24, 2024
6fec0bc
update function call set_atom_names_atomistic according to new args
fgrunewald Jul 24, 2024
5fbb282
refactor sample
fgrunewald Jul 24, 2024
eff05ef
add more tests
fgrunewald Jul 24, 2024
e345cc7
update handling of terminal addition
fgrunewald Jul 26, 2024
f883d19
finalize tests for init function
fgrunewald Jul 26, 2024
c662435
update sampler
fgrunewald Aug 14, 2024
1228ce7
Merge branch 'master' into sampler
fgrunewald Aug 29, 2024
0eaccc6
expose sampler
fgrunewald Aug 29, 2024
654fb9b
keep proper track of fragid when adding terminals
fgrunewald Aug 29, 2024
6dfdd6f
update sampler and change meaning of bonding operators
fgrunewald Sep 10, 2024
c378034
change meaning of bonding operators and fix in test
fgrunewald Sep 10, 2024
8f50245
change naming in sampler and update doc strings
fgrunewald Sep 10, 2024
efa1a8a
update docstring
fgrunewald Sep 10, 2024
ccd2901
address comments
fgrunewald Sep 10, 2024
716ebe8
update docstring
fgrunewald Sep 10, 2024
d241a96
fix doscstrings
fgrunewald Sep 12, 2024
286a161
fix doscstrings
fgrunewald Sep 12, 2024
4c062bb
fix doscstrings and spelling
fgrunewald Sep 12, 2024
d9e6a32
typo sphinx link
fgrunewald Sep 16, 2024
7d49154
update tests
fgrunewald Sep 18, 2024
4993722
update tests and remove print
fgrunewald Sep 18, 2024
9cb3a45
Update cgsmiles/sample.py
fgrunewald Sep 18, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cgsmiles/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@
from .read_cgsmiles import read_cgsmiles
from .read_fragments import read_fragments
from .resolve import MoleculeResolver
from .sample import MoleculeSampler
10 changes: 7 additions & 3 deletions cgsmiles/cgsmiles_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,18 @@ def find_complementary_bonding_descriptor(bonding_descriptor):
compl = bonding_descriptor
return compl

def find_open_bonds(molecule):
def find_open_bonds(molecule, target_nodes=None):
fgrunewald marked this conversation as resolved.
Show resolved Hide resolved
"""
Collect all nodes which have an open bonding descriptor and store
them as keys with a list of nodes as values.
"""
if target_nodes is None:
target_nodes = list(molecule.nodes)
fgrunewald marked this conversation as resolved.
Show resolved Hide resolved

open_bonds_by_descriptor = defaultdict(list)
open_bonds = nx.get_node_attributes(molecule, 'bonding')
for node, bonding_types in open_bonds.items():
for bonding_types in bonding_types:
open_bonds_by_descriptor[bonding_types].append(node)
if node in target_nodes:
for bonding_types in bonding_types:
open_bonds_by_descriptor[bonding_types].append(node)
return open_bonds_by_descriptor
29 changes: 23 additions & 6 deletions cgsmiles/graph_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,15 +116,32 @@ def annotate_fragments(meta_graph, molecule):
return meta_graph


def set_atom_names_atomistic(meta_graph, molecule):
def set_atom_names_atomistic(molecule, meta_graph=None):
"""
Set atomnames according to commonly used convention
in molecular dynamics (MD) forcefields. This convention
is defined as element plus counter for atom in residue.

Parameters
----------
molecule: nx.Graph
the molecule for which to adjust the atomnames
meta_graph: nx.Graph
optional; get the fragments from the meta_graph
attributes which is faster in some cases
"""
for meta_node in meta_graph.nodes:
fraggraph = meta_graph.nodes[meta_node]['graph']
for idx, node in enumerate(fraggraph.nodes):
atomname = fraggraph.nodes[node]['element'] + str(idx)
fraggraph.nodes[node]['atomname'] = atomname
fraglist = defaultdict(list)
if meta_graph:
for meta_node in meta_graph.nodes:
fraggraph = meta_graph.nodes[meta_node]['graph']
fraglist[meta_node] += list(fraggraph.nodes)
else:
node_to_fragid = nx.get_node_attributes(molecule, 'fragid')
for node, fragids in node_to_fragid.items():
assert len(fragids) == 1
fraglist[fragids[0]].append(node)

for fragnodes in fraglist.values():
for idx, node in enumerate(fragnodes):
atomname = molecule.nodes[node]['element'] + str(idx)
molecule.nodes[node]['atomname'] = atomname
7 changes: 6 additions & 1 deletion cgsmiles/pysmiles_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import networkx as nx
import pysmiles

def compute_mass(molecule):
def compute_mass(input_molecule):
fgrunewald marked this conversation as resolved.
Show resolved Hide resolved
"""
Compute the mass of a molecule from the PTE.

Expand All @@ -15,6 +15,11 @@ def compute_mass(molecule):
float
the atomic mass
"""
molecule = input_molecule.copy()
print(molecule.nodes(data=True))
fgrunewald marked this conversation as resolved.
Show resolved Hide resolved
# we need to add the hydrogen atoms
# for computing the mass
rebuild_h_atoms(molecule)
mass = 0
for node in molecule.nodes:
element = molecule.nodes[node]['element']
Expand Down
2 changes: 1 addition & 1 deletion cgsmiles/read_fragments.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,6 @@ def fragment_iter(fragment_str, all_atom=True):
# we deal with a CG resolution graph
else:
mol_graph = read_cgsmiles(smile)
nx.set_node_attributes(mol_graph, 1, 'fragid')
fragnames = nx.get_node_attributes(mol_graph, 'fragname')
nx.set_node_attributes(mol_graph, fragnames, 'atomname')
nx.set_node_attributes(mol_graph, bonding_descrpt, 'bonding')
Expand All @@ -161,6 +160,7 @@ def fragment_iter(fragment_str, all_atom=True):
nx.set_node_attributes(mol_graph, atomnames, 'atomname')

nx.set_node_attributes(mol_graph, fragname, 'fragname')
nx.set_node_attributes(mol_graph, 0, 'fragid')
yield fragname, mol_graph

def read_fragments(fragment_str, all_atom=True, fragment_dict=None):
Expand Down
2 changes: 1 addition & 1 deletion cgsmiles/resolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,7 @@ def resolve(self):
# in all-atom MD there are common naming conventions
# that might be expected and hence we set them here
if all_atom:
set_atom_names_atomistic(self.meta_graph, self.molecule)
set_atom_names_atomistic(self.molecule, self.meta_graph)

# increment the resolution counter
self.resolution_counter += 1
Expand Down
133 changes: 86 additions & 47 deletions cgsmiles/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,13 @@ class MoleculeSampler:
"""
def __init__(self,
fragment_dict,
target_weight,
bonding_probabilities,
branch_term_probs=None,
terminal_fragments=[],
bond_term_probs=None,
fragment_masses=None,
termination_probabilities=None,
start=None,
all_atom=True):
all_atom=True,
seed=None):
fgrunewald marked this conversation as resolved.
Show resolved Hide resolved

"""
Parameters
Expand All @@ -39,31 +40,31 @@ def __init__(self,
masses of the molecule fragments; if all_atom is True
these can be left out and are automatically computed from
the element masses
termination_probabilities: dict[str, float]
probability that a fragment is a chain terminal, which means
all descriptors will be removed from that fragment and
terminate chain growth. No additional terminal residue is
specified; this should be used for coarse-grained polymers
without specific end-group.
branch_term_probs: dict[str, float]
probability that a branched fragment is a chain terminal;
if terminal_probabilities are given
bond_term_probs: dict[str, float]
probability that a certain bonding descriptor connection
is present at the terminal
start: str
fragment name of the fragment to start with
all_atom: bool
if the fragments are all-atom resolution
seed: int
set random seed for all processes; default is None
"""
# first initialize the random number generator
random.seed(a=seed)
fgrunewald marked this conversation as resolved.
Show resolved Hide resolved
self.fragment_dict = fragment_dict
self.bonding_probabilities = bonding_probabilities
self.termination_probabilities = termination_probabilities
self.branch_term_probs = branch_term_probs
self.bond_term_probs = bond_term_probs
self.all_atom = all_atom
self.target_weight = target_weight
self.start = start
self.current_open_bonds = defaultdict(list)
self.current_weight = 0

# we need to make sure that we have the molecular
# masses so we can compute the target weight
self.fragments_by_bonding = defaultdict(list)
self.terminals_by_bonding = defaultdict(list)
if fragment_masses:
guess_mass_from_PTE = False
self.fragment_masses = fragment_masses
Expand All @@ -85,9 +86,12 @@ def __init__(self,
bondings = nx.get_node_attributes(fraggraph, "bonding")
for node, bondings in bondings.items():
for bonding in bondings:
self.fragments_by_bonding[bonding].append((fragname, node))
if fragname in terminal_fragments:
self.terminals_by_bonding[bonding].append((fragname, node))
else:
self.fragments_by_bonding[bonding].append((fragname, node))

def grow_chain(self, molecule, seed=None):
def add_fragment(self, molecule, open_bonds, fragments, bonding_probabilities):
"""
Pick an open bonding descriptor according to `bonding_probabilities`
and then pick a fragment that has the complementary bonding descriptor.
Expand All @@ -96,6 +100,13 @@ def grow_chain(self, molecule, seed=None):
----------
molecule: nx.Graph
the molecule to extend
open_bonds: dict[list[abc.hashable]]
a dict of active bonding descriptors with a list of nodes
in molecule as value
fragments: dict[list[str]]
a dict of fragment names indexed by their bonding descriptors
bonding_probabilities:
the probabilities that a bonding connector forms a bond

Returns
-------
Expand All @@ -106,30 +117,58 @@ def grow_chain(self, molecule, seed=None):
"""
# 1. get the probabilities of any bonding descriptor on the chain to
# form the new bond
probs = np.array([self.bonding_probabilities[bond_type[:-1]] for bond_type in self.current_open_bonds])
probs = np.array([bonding_probabilities[bond_type[:-1]] for bond_type in open_bonds])
probs = probs / sum(probs)
# 2. pick a random bonding descriptor according to these probs
bonding = np.random.choice(list(self.current_open_bonds.keys()), p=probs)
bonding = random.choices(list(open_bonds.keys()), weights=probs)[0]
# 3. get a corresponding node; it may be that one descriptor is found on
# several nodes
random.seed(a=seed)
source_node = random.choice(self.current_open_bonds[bonding])
source_node = random.choice(open_bonds[bonding])
# 4. get the complementary matching bonding descriptor
compl_bonding = find_complementary_bonding_descriptor(bonding)
# 5. pick a new fragment that has such bonding descriptor
random.seed(a=seed)
fragname, target_node = random.choice(self.fragments_by_bonding[compl_bonding])
fragname, target_node = random.choice(fragments[compl_bonding])
# 6. add the new fragment and do some book-keeping
correspondence = merge_graphs(molecule, self.fragment_dict[fragname])
molecule.add_edge(source_node,
correspondence[target_node],
bonding=(bonding, compl_bonding))
bonding=(bonding, compl_bonding),
order = int(bonding[-1]))
molecule.nodes[source_node]['bonding'].remove(bonding)
molecule.nodes[correspondence[target_node]]['bonding'].remove(compl_bonding)
self.current_open_bonds = find_open_bonds(molecule)
return molecule, fragname

def terminate_branch(self, molecule, fragname, fragid, seed=None):
def terminate_fragment(self, molecule, fragid):
"""
If bonding probabilities for terminal residues are given
select one terminal to add to the given fragment. If no
terminal bonding probabilities are defined the active bonding
descriptors of all nodes will be removed.

Parameters
----------
molecule: nx.Graph
the molecule graph
fragid: int
the id of the fragment
"""
target_nodes = [node for node in molecule.nodes if fragid in molecule.nodes[node]['fragid']]
fgrunewald marked this conversation as resolved.
Show resolved Hide resolved
open_bonds = find_open_bonds(molecule, target_nodes=target_nodes)
# if terminal fragment bonding probabilities are given; add them here
if self.bond_term_probs:
self.add_fragment(molecule,
open_bonds,
self.terminals_by_bonding,
self.bond_term_probs)
fragid += 1

for node in target_nodes:
if 'bonding' in molecule.nodes[node]:
del molecule.nodes[node]['bonding']

return fragid

def terminate_branch(self, molecule, fragname, fragid):
"""
Probabilistically terminate a branch by removing all
bonding descriptors from the last fragment.
Expand All @@ -147,58 +186,59 @@ def terminate_branch(self, molecule, fragname, fragid, seed=None):
-------
nx.Graph
"""
term_prob = self.termination_probabilities.get(fragname, 0)
random.seed(a=seed)
term_prob = self.branch_term_probs.get(fragname, -1)
# probability check for termination
if random.random() < term_prob:
if random.random() <= term_prob:
# check if there are more open bonding descriptors
# if the number is the same as would get removed
# then we are not on a branch
active_bonds = nx.get_node_attributes(molecule, 'bonding')
target_nodes = [ node for node in active_bonds if molecule.nodes[node]['fragid'] == fragid]
target_nodes = [node for node in active_bonds if fragid in molecule.nodes[node]['fragid']]
if len(target_nodes) < len(active_bonds):
for node in target_nodes:
del molecule.nodes[node]['bonding']
self.current_open_bonds = find_open_bonds(molecule)
return molecule
fragid = self.terminate_fragment(molecule, fragid)
return molecule, fragid

def sample(self, target_weight, seed=None):
def sample(self, target_weight, start_fragment=None):
"""
From a list of cgsmiles fragment graphs generate a new random molecule
by stitching them together.

Parameters
----------
target_weight
target_weight: int
the weight of the polymer to generate
start_fragment: str
the fragment name to start with

Returns
-------
nx.Graph
the graph of the molecule
"""
molecule = nx.Graph()
if self.start:
fragment = self.fragment_dict[self.start]
if start_fragment:
fragment = self.fragment_dict[start_fragment]
else:
# initialize the molecule; all fragments have the same probability
random.seed(a=seed)
fragname = random.choice(list(self.fragment_dict.keys()))
fragment = self.fragment_dict[fragname]

merge_graphs(molecule, fragment)
self.current_open_bonds = find_open_bonds(molecule)
open_bonds = find_open_bonds(molecule)

current_weight = 0

# next we add monomers one after the other
fragid = 1
while current_weight < target_weight:
molecule, fragname = self.grow_chain(molecule, seed=seed)
if self.termination_probabilities:
molecule = self.terminate_branch(molecule, fragname, fragid, seed=seed)
fragid += 1
open_bonds = find_open_bonds(molecule)
molecule, fragname = self.add_fragment(molecule,
open_bonds,
self.fragments_by_bonding,
self.bonding_probabilities)
molecule, fragid = self.terminate_branch(molecule, fragname, fragid)
current_weight += self.fragment_masses[fragname]
fragid += 1

if self.all_atom:
rebuild_h_atoms(molecule)
Expand All @@ -208,8 +248,8 @@ def sample(self, target_weight, seed=None):

# in all-atom MD there are common naming conventions
# that might be expected and hence we set them here
# if self.all_atom:
# set_atom_names_atomistic(self.molecule)
if self.all_atom:
set_atom_names_atomistic(molecule)

return molecule

Expand Down Expand Up @@ -239,7 +279,6 @@ def from_fragment_string(cls,
all_atom = kwargs.get('all_atom', True)
fragment_dict = read_fragments(fragment_strings[0],
all_atom=all_atom)

sampler = cls(fragment_dict,
**kwargs)

Expand Down
Loading
Loading