Skip to content

Commit

Permalink
Merge pull request #4 from gruenewald-lab/write_cgsmiles
Browse files Browse the repository at this point in the history
[5] Writer
  • Loading branch information
fgrunewald authored Oct 8, 2024
2 parents 4430b6b + 116b5e9 commit 3465bea
Show file tree
Hide file tree
Showing 5 changed files with 322 additions and 6 deletions.
6 changes: 3 additions & 3 deletions cgsmiles/graph_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,9 @@ def sort_nodes_by_attr(graph,
nx.Graph
graph with nodes sorted in correct order
"""
attr_values = nx.get_node_attributes(graph, sort_attr)
sorted_ids = sorted(attr_values, key=lambda item: (attr_values[item], item))
mapping = {old: new for new, old in enumerate(sorted_ids)}
fragids = nx.get_node_attributes(graph, sort_attr)
sorted_ids = sorted(fragids.items(), key=lambda item: (item[1], item[0]))
mapping = {old[0]: new for new, old in enumerate(sorted_ids)}
new_graph = nx.relabel_nodes(graph, mapping, copy=True)
for attr, is_list in relative_attr:
attr_dict = nx.get_node_attributes(new_graph, attr)
Expand Down
3 changes: 2 additions & 1 deletion cgsmiles/read_fragments.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,8 @@ def strip_bonding_descriptors(fragment_string):
elif current_order:
order = current_order
current_order = None
# we need to remove the symbol from the clean string
smile = smile[:-1]
else:
order = 1
bonding_descrpt[prev_node].append(bond_descrp + str(order))
Expand Down Expand Up @@ -243,7 +245,6 @@ def fragment_iter(fragment_str, all_atom=True):
ez_isomer_class = {idx: val[-1] for idx, val in ez_isomers.items()}
nx.set_node_attributes(mol_graph, ez_isomer_atoms, 'ez_isomer_atoms')
nx.set_node_attributes(mol_graph, ez_isomer_class, 'ez_isomer_class')
print("hcount", mol_graph.nodes(data='hcount'))
# we deal with a CG resolution graph
else:
mol_graph = read_cgsmiles(smile)
Expand Down
4 changes: 2 additions & 2 deletions cgsmiles/tests/test_cgsmile_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,13 +222,13 @@ def test_read_cgsmiles(smile, nodes, edges, orders):
None),
# \ fragment split
("[>]CC(\F)=[<]",
"CC(F)=",
"CC(F)",
{0: [">1"], 1: ["<2"]},
None,
{2: (2, 1, '\\')}),
# / fragment split
("[>]CC(/F)=[<]",
"CC(F)=",
"CC(F)",
{0: [">1"], 1: ["<2"]},
None,
{2: (2, 1, '/')}),
Expand Down
73 changes: 73 additions & 0 deletions cgsmiles/tests/test_write_cgsmiles.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import pytest
import networkx as nx
from pysmiles.testhelper import assertEqualGraphs
from cgsmiles.read_fragments import read_fragments
from cgsmiles.read_cgsmiles import read_cgsmiles
from cgsmiles.write_cgsmiles import (write_cgsmiles_fragments,
write_cgsmiles_graph,
write_cgsmiles)
from cgsmiles import MoleculeResolver

@pytest.mark.parametrize('input_string',(
# smiple linear seqeunce
"{#PEO=[$]COC[$],#OHter=[$]O}",
# two bonding IDs
"{#PEO=[$][$A]COC[$][$B],#OHter=[$]O}",
# something with bond order
"{#PEO=[$]=COC[$A],#OHter=[$A]O,#PI=[$]=C}",
# something with a shash operator
"{#TC5=[!]CCC[!],#TN6a=[!]CNC[!]}",
# something with aromatic fragments
"{#TC5=[!]ccc[!],#TN6a=[!]cnc[!]}",
))
def test_write_fragments(input_string):
frag_dict = read_fragments(input_string)
out_string = write_cgsmiles_fragments(frag_dict, smiles_format=True)
frag_dict_out = read_fragments(out_string)
assert set(frag_dict_out) == set(frag_dict)
for fragname in frag_dict:
assertEqualGraphs(frag_dict_out[fragname], frag_dict[fragname])

@pytest.mark.parametrize('input_string',(
# smiple linear seqeunce
"{[#PEO][#PMA]}",
# ring
"{[#TC5]1[#TC5][#TC5][#TC5][#TC5]1}",
# branched
"{[#PE][#PMA]([#PEO][#PEO][#PEO])[#PE]}",
# branched nested
"{[#PE][#PMA]([#PEO][#PEO]([#OMA][#OMA]1[#OMA][#OMA]1))[#PE]}",
# special cycle
"{[#PE]1[#PMA]1}",
# special triple cycle
"{[#A]12[#B]12}",
))
def test_write_mol_graphs(input_string):
mol_graph = read_cgsmiles(input_string)
out_string = write_cgsmiles_graph(mol_graph)
out_graph = read_cgsmiles(out_string)
assertEqualGraphs(mol_graph, out_graph)

@pytest.mark.parametrize('input_string',(
# smiple linear seqeunce
"{[#PEO][#PMMA][#PEO][#PMMA]}.{#PEO=[>]COC[<],#PMMA=[>]CC(C)[<]C(=O)OC}",
# something with ring
"{[#TC5]1[#TC5][#TC5]1}.{#TC5=[$]cc[$]}",))
def test_write_cgsmiles(input_string):
resolver = MoleculeResolver.from_string(input_string)
fragment_dicts = resolver.fragment_dicts
molecule = resolver.molecule
output_string = write_cgsmiles(molecule, fragment_dicts)
out_resolver = MoleculeResolver.from_string(output_string)
out_mol = out_resolver.molecule
assertEqualGraphs(molecule, out_mol)
out_fragments = out_resolver.fragment_dicts
assert len(fragment_dicts) == len(out_fragments)
for frag_dict, frag_dict_out in zip(fragment_dicts, out_fragments):
assert set(frag_dict_out) == set(frag_dict)
for fragname in frag_dict:
# we cannot be sure that the atomnames are the same because they
# will depend on the order
nx.set_node_attributes(frag_dict_out[fragname], None, "atomname")
nx.set_node_attributes(frag_dict[fragname], None, "atomname")
assertEqualGraphs(frag_dict_out[fragname], frag_dict[fragname])
242 changes: 242 additions & 0 deletions cgsmiles/write_cgsmiles.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,242 @@
import logging
from collections import defaultdict
import networkx as nx
from pysmiles.smiles_helper import format_atom
from pysmiles.write_smiles import _get_ring_marker,_write_edge_symbol

logger = logging.getLogger(__name__)

order_to_symbol = {0: '.', 1: '-', 1.5: ':', 2: '=', 3: '#', 4: '$'}

def format_node(molecule, current):
"""
Format a node from a `molecule` graph according to
the CGSmiles syntax. The attribute fragname has to
be set for the `current` node.
Parameters
----------
molecule: nx.Graph
current: abc.hashbale
Returns
-------
str
the formatted string
"""
node = "[#{}]".format(molecule.nodes[current]['fragname'])
return node

def format_bonding(bonding):
"""
Given the list of bonding descriptors format them
such that they can be added after a node/atom. This
function wraps the descriptor in [ ] braces and makes
sure that the bond order annotation is removed.
Parameters
----------
bonding: list[str]
list of bonding descriptors
Returns
-------
str
the formatted bonding descriptor string
"""
bond_str = ""
for bonding_descrpt in bonding:
bond_order = bonding_descrpt[-1]
order_symb = order_to_symbol[int(bond_order)]
if order_symb != '-':
bond_str = order_symb
bond_str += "["+str(bonding_descrpt[:-1])+"]"
return bond_str

def write_graph(molecule, smiles_format=False, default_element='*'):
"""
Creates a CGsmiles string describing `molecule`.
`molecule` should be a single connected component.
Parameters
----------
molecule : nx.Graph
The molecule for which a CGsmiles string should be generated.
smiles_format:
If the nodes are written using the OpenSmiles standard format.
Returns
-------
str
The CGSmiles string describing `molecule`.
"""
start = min(molecule)
dfs_successors = nx.dfs_successors(molecule, source=start)

predecessors = defaultdict(list)
for node_key, successors in dfs_successors.items():
for successor in successors:
predecessors[successor].append(node_key)
predecessors = dict(predecessors)
# We need to figure out which edges we won't cross when doing the dfs.
# These are the edges we'll need to add to the smiles using ring markers.
edges = set()
for n_idx, n_jdxs in dfs_successors.items():
for n_jdx in n_jdxs:
edges.add(frozenset((n_idx, n_jdx)))
total_edges = set(map(frozenset, molecule.edges))
ring_edges = list(total_edges - edges)
# in cgsmiles graphs only bonds of order 1 and 2
# exists; order 2 means we have a ring at the
# higher resolution. These orders are therefore
# represented as rings and that requires to
# add them to the ring list
if not smiles_format:
for edge in molecule.edges:
if molecule.edges[edge]['order'] != 1:
for n in range(1, molecule.edges[edge]['order']):
ring_edges.append(frozenset(edge))

atom_to_ring_idx = defaultdict(list)
ring_idx_to_bond = {}
ring_idx_to_marker = {}
for ring_idx, (n_idx, n_jdx) in enumerate(ring_edges, 1):
atom_to_ring_idx[n_idx].append(ring_idx)
atom_to_ring_idx[n_jdx].append(ring_idx)
ring_idx_to_bond[ring_idx] = (n_idx, n_jdx)

branch_depth = 0
branches = set()
to_visit = [start]
smiles = ''

while to_visit:
current = to_visit.pop()
if current in branches:
branch_depth += 1
smiles += '('
branches.remove(current)

if current in predecessors:
# It's not the first atom we're visiting, so we want to see if the
# edge we last crossed to get here is interesting.
previous = predecessors[current]
assert len(previous) == 1
previous = previous[0]
if smiles_format and _write_edge_symbol(molecule, previous, current):
order = molecule.edges[previous, current].get('order', 1)
smiles += order_to_symbol[order]

if smiles_format:
smiles += format_atom(molecule, current, default_element)
else:
smiles += format_node(molecule, current)

# we add the bonding descriptors if there are any
if molecule.nodes[current].get('bonding', False):
smiles += format_bonding(molecule.nodes[current]['bonding'])

if current in atom_to_ring_idx:
# We're going to need to write a ring number
ring_idxs = atom_to_ring_idx[current]
for ring_idx in ring_idxs:
ring_bond = ring_idx_to_bond[ring_idx]
if ring_idx not in ring_idx_to_marker:
marker = _get_ring_marker(ring_idx_to_marker.values())
ring_idx_to_marker[ring_idx] = marker
new_marker = True
else:
marker = ring_idx_to_marker.pop(ring_idx)
new_marker = False

if smiles_format and _write_edge_symbol(molecule, *ring_bond) and new_marker:
order = molecule.edges[ring_bond].get('order', 1)
smiles += order_to_symbol[order]

smiles += str(marker) if marker < 10 else '%{}'.format(marker)

if current in dfs_successors:
# Proceed to the next node in this branch
next_nodes = dfs_successors[current]
# ... and if needed, remember to return here later
branches.update(next_nodes[1:])
to_visit.extend(next_nodes)
elif branch_depth:
# We're finished with this branch.
smiles += ')'
branch_depth -= 1

smiles += ')' * branch_depth
return smiles

def write_cgsmiles_graph(molecule):
"""
Write a CGSmiles graph sans fragments at
different resolution.
Parameters
----------
molecule: nx.Graph
a molecule where each node as a fragname attribute
that is used as name in the CGSmiles string.
Returns
-------
str
the CGSmiles string
"""

cgsmiles_str = write_graph(molecule)
return "{" + cgsmiles_str + "}"

def write_cgsmiles_fragments(fragment_dict, smiles_format=True):
"""
Write fragments of molecule graph. To identify the fragments
all nodes with the same `fragname` and `fragid` attributes
are considered as fragment. Bonding between fragments is
extracted from the `bonding` edge attributes.
Parameters
----------
fragment_dict: dict[str, nx.Graph]
a dict of fragment graphs
smiles_format: bool
write all atom SMILES if True (default) otherwise
write CGSmiles
Returns
-------
str
"""
fragment_str = ""
for fragname, frag_graph in fragment_dict.items():
fragment_str += f"#{fragname}="
# format graph depending on resolution
fragment_str += write_graph(frag_graph, smiles_format=smiles_format) + ","
fragment_str = "{" + fragment_str[:-1] + "}"
return fragment_str

def write_cgsmiles(molecule_graph, fragments, last_all_atom=True):
"""
Write a CGSmiles string given a low resolution molecule graph
and any number of higher resolutions provided as fragment dicts.
Parameters
----------
molecule_graph: nx.Graph
fragments: list[dict[nx.Graph]]
a list of fragment dicts
last_all_atom: bool
if the last set of fragments is at the all_atom level
Returns
-------
str
CGSmiles string
"""
final_str = write_cgsmiles_graph(molecule_graph)
for layer, fragment in enumerate(fragments):
all_atom = (layer == len(fragments)-1) and last_all_atom
fragment_str = write_cgsmiles_fragments(fragment, smiles_format=all_atom)
final_str += "." + fragment_str
return final_str

0 comments on commit 3465bea

Please sign in to comment.