Skip to content

Commit

Permalink
Merge pull request marrink-lab#589 from csbrasnett/lazy-merge
Browse files Browse the repository at this point in the history
Lazy merge
  • Loading branch information
pckroon authored May 1, 2024
2 parents 7e359bd + b85f024 commit e9bf1c9
Show file tree
Hide file tree
Showing 3 changed files with 164 additions and 11 deletions.
23 changes: 18 additions & 5 deletions bin/martinize2
Original file line number Diff line number Diff line change
Expand Up @@ -316,12 +316,14 @@ def entry():
default=False,
help="Write separate topologies for identical chains",
)
file_group.add_argument(
chain_merging = file_group.add_argument(
"-merge",
dest="merge_chains",
type=lambda x: x.split(","),
type=str,
action="append",
help="Merge chains: e.g. -merge A,B,C (+)",
help="Merge chains: either a comma separated list of chains to merge e.g. -merge A,B,C (+), or -merge all\n"
"if instead all chains in the input file should be merged.\n"
"Can be given multiple times for different groups of chains to merge.",
)
file_group.add_argument(
"-resid",
Expand Down Expand Up @@ -974,8 +976,19 @@ def entry():
itp_paths = []
# Merge chains if required.
if args.merge_chains:
for chain_set in args.merge_chains:
vermouth.MergeChains(chain_set).run_system(system)
#if "all" is not in the list of chains to be merged
if "all" not in args.merge_chains:
input_chain_sets = [i.split(",") for i in args.merge_chains]
for chain_set in input_chain_sets:
vermouth.MergeChains(chains=chain_set, all_chains=False).run_system(system)
#if "all" is in the list and is the only argument
elif "all" in args.merge_chains and len(args.merge_chains) == 1:
vermouth.MergeChains(chains=[], all_chains=True).run_system(system)
#otherwise there are multiple arguments and we need to raise an ArgumentError
else:
raise argparse.ArgumentError(chain_merging,
message=("Multiple conflicting merging arguments given. "
"Either specify -merge all or -merge A,B,C (+)."))
vermouth.NameMolType(deduplicate=not args.keep_duplicate_itp).run_system(system)
defines = ()

Expand Down
29 changes: 23 additions & 6 deletions vermouth/processors/merge_chains.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,11 @@

from ..molecule import Molecule
from ..processors.processor import Processor
from ..log_helpers import StyleAdapter, get_logger
LOGGER = StyleAdapter(get_logger(__name__))


def merge_chains(system, chains):
def merge_chains(system, chains, all_chains):
"""
Merge molecules with the given chains as a single molecule.
Expand All @@ -42,15 +44,29 @@ def merge_chains(system, chains):
The system to modify.
chains: list[str]
A container of chain identifier.
all_chains: bool
If True, all chains will be merged.
"""
chains = set(chains)
if not all_chains and len(chains) > 0:
_chains = set(chains)
elif all_chains and len(chains) == 0:
_chains = set()
for molecule in system.molecules:
# Molecules can contain multiple chains
_chains.update(node.get('chain') for node in molecule.nodes.values())
else:
raise ValueError("Can specify specific chains or all chains, but not both")

if any(not c for c in _chains):
LOGGER.warning('One or more of your chains does not have a chain identifier in input file.')

merged = Molecule()
merged._force_field = system.force_field
has_merged = False
new_molecules = []
for molecule in system.molecules:
molecule_chains = set(node.get('chain') for node in molecule.nodes.values())
if molecule_chains.issubset(chains):
if molecule_chains.issubset(_chains):
if not has_merged:
merged.nrexcl = molecule.nrexcl
new_molecules.append(merged)
Expand All @@ -65,8 +81,9 @@ def merge_chains(system, chains):
class MergeChains(Processor):
name = 'MergeChains'

def __init__(self, chains):
self.chains = chains
def __init__(self, chains=None, all_chains=False):
self.chains = chains or []
self.all_chains = all_chains

def run_system(self, system):
merge_chains(system, self.chains)
merge_chains(system, self.chains, self.all_chains)
123 changes: 123 additions & 0 deletions vermouth/tests/test_merge_chains.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
# Copyright 2018 University of Groningen
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Contains unittests for vermouth.processors.merge_chains.
"""

import networkx as nx
import pytest
from vermouth.system import System
from vermouth.molecule import Molecule
from vermouth.forcefield import ForceField
from vermouth.processors.merge_chains import (
MergeChains
)
from vermouth.tests.datafiles import (
FF_UNIVERSAL_TEST,
)

@pytest.mark.parametrize('node_data, edge_data, merger, expected', [
(
[
{'chain': 'A', 'resname': 'ALA', 'resid': 1},
{'chain': 'A', 'resname': 'ALA', 'resid': 2},
{'chain': 'A', 'resname': 'ALA', 'resid': 3},
{'chain': 'B', 'resname': 'ALA', 'resid': 1},
{'chain': 'B', 'resname': 'ALA', 'resid': 2},
{'chain': 'B', 'resname': 'ALA', 'resid': 3}
],
[(0, 1), (1, 2), (3, 4), (4, 5)],
{"chains": ["A", "B"], "all_chains": False},
False
),
(
[
{'chain': 'A', 'resname': 'ALA', 'resid': 1},
{'chain': 'A', 'resname': 'ALA', 'resid': 2},
{'chain': 'A', 'resname': 'ALA', 'resid': 3},
{'chain': 'B', 'resname': 'ALA', 'resid': 1},
{'chain': 'B', 'resname': 'ALA', 'resid': 2},
{'chain': 'B', 'resname': 'ALA', 'resid': 3}
],
[(0, 1), (1, 2), (3, 4), (4, 5)],
{"chains": [], "all_chains": True},
False
),
(
[
{'chain': 'A', 'resname': 'ALA', 'resid': 1},
{'chain': 'A', 'resname': 'ALA', 'resid': 2},
{'chain': 'A', 'resname': 'ALA', 'resid': 3},
{'chain': None, 'resname': 'ALA', 'resid': 1},
{'chain': None, 'resname': 'ALA', 'resid': 2},
{'chain': None, 'resname': 'ALA', 'resid': 3}
],
[(0, 1), (1, 2), (3, 4), (4, 5)],
{"chains": [], "all_chains": True},
True
),
])
def test_merge(caplog, node_data, edge_data, merger, expected):
"""
Tests that the merging works as expected.
"""
system = System(force_field=ForceField(FF_UNIVERSAL_TEST))
mol = Molecule(force_field=system.force_field)
mol.add_nodes_from(enumerate(node_data))
mol.add_edges_from(edge_data)

mols = nx.connected_components(mol)
for nodes in mols:
system.add_molecule(mol.subgraph(nodes))

processor = MergeChains(**merger)
caplog.clear()
processor.run_system(system)

if expected:
assert any(rec.levelname == 'WARNING' for rec in caplog.records)
else:
assert caplog.records == []

def test_too_many_args():
"""
Tests that error is raised when too many arguments are given.
"""
node_data = [
{'chain': 'A', 'resname': 'ALA', 'resid': 1},
{'chain': 'A', 'resname': 'ALA', 'resid': 2},
{'chain': 'A', 'resname': 'ALA', 'resid': 3},
{'chain': 'B', 'resname': 'ALA', 'resid': 1},
{'chain': 'B', 'resname': 'ALA', 'resid': 2},
{'chain': 'B', 'resname': 'ALA', 'resid': 3}
]
edge_data = [(0, 1), (1, 2), (3, 4), (4, 5)]

system = System(force_field=ForceField(FF_UNIVERSAL_TEST))
mol = Molecule(force_field=system.force_field)
mol.add_nodes_from(enumerate(node_data))
mol.add_edges_from(edge_data)

mols = nx.connected_components(mol)
for nodes in mols:
system.add_molecule(mol.subgraph(nodes))

merger = {"chains": ["A", "B"], "all_chains": True}

processor = MergeChains(**merger)

with pytest.raises(ValueError):
processor.run_system(system)

0 comments on commit e9bf1c9

Please sign in to comment.