Skip to content

Commit

Permalink
[fmt] Guesser (#4851)
Browse files Browse the repository at this point in the history
  • Loading branch information
RMeli authored Dec 24, 2024
1 parent 55cce24 commit a10e23e
Show file tree
Hide file tree
Showing 6 changed files with 304 additions and 233 deletions.
20 changes: 13 additions & 7 deletions package/MDAnalysis/guesser/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,11 @@ class FooGuesser(GuesserBase):
.. versionadded:: 2.8.0
"""

def __init__(cls, name, bases, classdict):
type.__init__(type, name, bases, classdict)

_GUESSERS[classdict['context'].upper()] = cls
_GUESSERS[classdict["context"].upper()] = cls


class GuesserBase(metaclass=_GuesserMeta):
Expand All @@ -87,7 +88,8 @@ class GuesserBase(metaclass=_GuesserMeta):
.. versionadded:: 2.8.0
"""
context = 'base'

context = "base"
_guesser_methods: Dict = {}

def __init__(self, universe=None, **kwargs):
Expand Down Expand Up @@ -149,8 +151,10 @@ def guess_attr(self, attr_to_guess, force_guess=False):
try:
guesser_method = self._guesser_methods[attr_to_guess]
except KeyError:
raise ValueError(f'{type(self).__name__} cannot guess this '
f'attribute: {attr_to_guess}')
raise ValueError(
f"{type(self).__name__} cannot guess this "
f"attribute: {attr_to_guess}"
)

# Connection attributes should be just returned as they are always
# appended to the Universe. ``force_guess`` handling should happen
Expand All @@ -161,7 +165,8 @@ def guess_attr(self, attr_to_guess, force_guess=False):
# check if the topology already has the attribute to partially guess it
if hasattr(self._universe.atoms, attr_to_guess) and not force_guess:
attr_values = np.array(
getattr(self._universe.atoms, attr_to_guess, None))
getattr(self._universe.atoms, attr_to_guess, None)
)

empty_values = top_attr.are_values_missing(attr_values)

Expand All @@ -175,8 +180,9 @@ def guess_attr(self, attr_to_guess, force_guess=False):

else:
logger.info(
f'There is no empty {attr_to_guess} values. Guesser did '
f'not guess any new values for {attr_to_guess} attribute')
f"There is no empty {attr_to_guess} values. Guesser did "
f"not guess any new values for {attr_to_guess} attribute"
)
return None
else:
return np.array(guesser_method())
Expand Down
153 changes: 91 additions & 62 deletions package/MDAnalysis/guesser/default_guesser.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,8 @@ class DefaultGuesser(GuesserBase):
.. versionadded:: 2.8.0
"""
context = 'default'

context = "default"

def __init__(
self,
Expand All @@ -170,25 +171,25 @@ def __init__(
vdwradii=None,
fudge_factor=0.55,
lower_bound=0.1,
**kwargs
**kwargs,
):
super().__init__(
universe,
box=box,
vdwradii=vdwradii,
fudge_factor=fudge_factor,
lower_bound=lower_bound,
**kwargs
**kwargs,
)
self._guesser_methods = {
'masses': self.guess_masses,
'types': self.guess_types,
'elements': self.guess_types,
'bonds': self.guess_bonds,
'angles': self.guess_angles,
'dihedrals': self.guess_dihedrals,
'impropers': self.guess_improper_dihedrals,
'aromaticities': self.guess_aromaticities,
"masses": self.guess_masses,
"types": self.guess_types,
"elements": self.guess_types,
"bonds": self.guess_bonds,
"angles": self.guess_angles,
"dihedrals": self.guess_dihedrals,
"impropers": self.guess_improper_dihedrals,
"aromaticities": self.guess_aromaticities,
}

def guess_masses(self, atom_types=None, indices_to_guess=None):
Expand Down Expand Up @@ -225,18 +226,21 @@ def guess_masses(self, atom_types=None, indices_to_guess=None):
except NoDataError:
try:
atom_types = self.guess_types(
atom_types=self._universe.atoms.names)
atom_types=self._universe.atoms.names
)
except NoDataError:
raise NoDataError(
"there is no reference attributes"
" (elements, types, or names)"
" in this universe to guess mass from") from None
" in this universe to guess mass from"
) from None

if indices_to_guess is not None:
atom_types = atom_types[indices_to_guess]

masses = np.array([self.get_atom_mass(atom)
for atom in atom_types], dtype=np.float64)
masses = np.array(
[self.get_atom_mass(atom) for atom in atom_types], dtype=np.float64
)
return masses

def get_atom_mass(self, element):
Expand All @@ -256,7 +260,8 @@ def get_atom_mass(self, element):
"Unknown masses are set to 0.0 for current version, "
"this will be deprecated in version 3.0.0 and replaced by"
" Masse's no_value_label (np.nan)",
PendingDeprecationWarning)
PendingDeprecationWarning,
)
return 0.0

def guess_atom_mass(self, atomname):
Expand Down Expand Up @@ -295,13 +300,16 @@ def guess_types(self, atom_types=None, indices_to_guess=None):
except NoDataError:
raise NoDataError(
"there is no reference attributes in this universe "
"to guess types from") from None
"to guess types from"
) from None

if indices_to_guess is not None:
atom_types = atom_types[indices_to_guess]

return np.array([self.guess_atom_element(atom)
for atom in atom_types], dtype=object)
return np.array(
[self.guess_atom_element(atom) for atom in atom_types],
dtype=object,
)

def guess_atom_element(self, atomname):
"""Guess the element of the atom from the name.
Expand All @@ -315,7 +323,7 @@ def guess_atom_element(self, atomname):
still not found, we iteratively continue to remove the last character
or first character until we find a match. If ultimately no match
is found, the first character of the stripped name is returned.
If the input name is an empty string, an empty string is returned.
The table comes from CHARMM and AMBER atom
Expand All @@ -331,16 +339,16 @@ def guess_atom_element(self, atomname):
:func:`guess_atom_type`
:mod:`MDAnalysis.guesser.tables`
"""
NUMBERS = re.compile(r'[0-9]') # match numbers
SYMBOLS = re.compile(r'[*+-]') # match *, +, -
if atomname == '':
return ''
NUMBERS = re.compile(r"[0-9]") # match numbers
SYMBOLS = re.compile(r"[*+-]") # match *, +, -
if atomname == "":
return ""
try:
return tables.atomelements[atomname.upper()]
except KeyError:
# strip symbols and numbers
no_symbols = re.sub(SYMBOLS, '', atomname)
name = re.sub(NUMBERS, '', no_symbols).upper()
no_symbols = re.sub(SYMBOLS, "", atomname)
name = re.sub(NUMBERS, "", no_symbols).upper()

# just in case
if name in tables.atomelements:
Expand Down Expand Up @@ -393,7 +401,7 @@ def guess_bonds(self, atoms=None, coords=None):
Raises
------
:exc:`ValueError`
:exc:`ValueError`
If inputs are malformed or `vdwradii` data is missing.
Expand All @@ -410,32 +418,37 @@ def guess_bonds(self, atoms=None, coords=None):
if len(atoms) != len(coords):
raise ValueError("'atoms' and 'coord' must be the same length")

fudge_factor = self._kwargs.get('fudge_factor', 0.55)
fudge_factor = self._kwargs.get("fudge_factor", 0.55)

# so I don't permanently change it
vdwradii = tables.vdwradii.copy()
user_vdwradii = self._kwargs.get('vdwradii', None)
user_vdwradii = self._kwargs.get("vdwradii", None)
# this should make algo use their values over defaults
if user_vdwradii:
vdwradii.update(user_vdwradii)

# Try using types, then elements
if hasattr(atoms, 'types'):
if hasattr(atoms, "types"):
atomtypes = atoms.types
else:
atomtypes = self.guess_types(atom_types=atoms.names)

# check that all types have a defined vdw
if not all(val in vdwradii for val in set(atomtypes)):
raise ValueError(("vdw radii for types: " +
", ".join([t for t in set(atomtypes) if
t not in vdwradii]) +
". These can be defined manually using the" +
f" keyword 'vdwradii'"))
raise ValueError(
(
"vdw radii for types: "
+ ", ".join(
[t for t in set(atomtypes) if t not in vdwradii]
)
+ ". These can be defined manually using the"
+ f" keyword 'vdwradii'"
)
)

lower_bound = self._kwargs.get('lower_bound', 0.1)
lower_bound = self._kwargs.get("lower_bound", 0.1)

box = self._kwargs.get('box', None)
box = self._kwargs.get("box", None)

if box is not None:
box = np.asarray(box)
Expand All @@ -447,14 +460,14 @@ def guess_bonds(self, atoms=None, coords=None):

bonds = []

pairs, dist = distances.self_capped_distance(coords,
max_cutoff=2.0 * max_vdw,
min_cutoff=lower_bound,
box=box)
pairs, dist = distances.self_capped_distance(
coords, max_cutoff=2.0 * max_vdw, min_cutoff=lower_bound, box=box
)
for idx, (i, j) in enumerate(pairs):
d = (vdwradii[atomtypes[i]] +
vdwradii[atomtypes[j]]) * fudge_factor
if (dist[idx] < d):
d = (
vdwradii[atomtypes[i]] + vdwradii[atomtypes[j]]
) * fudge_factor
if dist[idx] < d:
bonds.append((atoms[i].index, atoms[j].index))
return tuple(bonds)

Expand All @@ -480,18 +493,21 @@ def guess_angles(self, bonds=None):
--------
:meth:`guess_bonds`
"""
"""
from ..core.universe import Universe

angles_found = set()

if bonds is None:
if hasattr(self._universe.atoms, 'bonds'):
if hasattr(self._universe.atoms, "bonds"):
bonds = self._universe.atoms.bonds
else:
temp_u = Universe.empty(n_atoms=len(self._universe.atoms))
temp_u.add_bonds(self.guess_bonds(
self._universe.atoms, self._universe.atoms.positions))
temp_u.add_bonds(
self.guess_bonds(
self._universe.atoms, self._universe.atoms.positions
)
)
bonds = temp_u.atoms.bonds

for b in bonds:
Expand All @@ -501,7 +517,8 @@ def guess_angles(self, bonds=None):
if other_b != b: # if not the same bond I start as
third_a = other_b.partner(atom)
desc = tuple(
[other_a.index, atom.index, third_a.index])
[other_a.index, atom.index, third_a.index]
)
# first index always less than last
if desc[0] > desc[-1]:
desc = desc[::-1]
Expand Down Expand Up @@ -530,15 +547,18 @@ def guess_dihedrals(self, angles=None):
from ..core.universe import Universe

if angles is None:
if hasattr(self._universe.atoms, 'angles'):
if hasattr(self._universe.atoms, "angles"):
angles = self._universe.atoms.angles

else:
temp_u = Universe.empty(n_atoms=len(self._universe.atoms))

temp_u.add_bonds(self.guess_bonds(
self._universe.atoms, self._universe.atoms.positions))

temp_u.add_bonds(
self.guess_bonds(
self._universe.atoms, self._universe.atoms.positions
)
)

temp_u.add_angles(self.guess_angles(temp_u.atoms.bonds))

angles = temp_u.atoms.angles
Expand All @@ -549,8 +569,9 @@ def guess_dihedrals(self, angles=None):
a_tup = tuple([a.index for a in b]) # angle as tuple of numbers
# if searching with b[0], want tuple of (b[2], b[1], b[0], +new)
# search the first and last atom of each angle
for atom, prefix in zip([b.atoms[0], b.atoms[-1]],
[a_tup[::-1], a_tup]):
for atom, prefix in zip(
[b.atoms[0], b.atoms[-1]], [a_tup[::-1], a_tup]
):
for other_b in atom.bonds:
if not other_b.partner(atom) in b:
third_a = other_b.partner(atom)
Expand Down Expand Up @@ -580,14 +601,17 @@ def guess_improper_dihedrals(self, angles=None):
from ..core.universe import Universe

if angles is None:
if hasattr(self._universe.atoms, 'angles'):
if hasattr(self._universe.atoms, "angles"):
angles = self._universe.atoms.angles

else:
temp_u = Universe.empty(n_atoms=len(self._universe.atoms))

temp_u.add_bonds(self.guess_bonds(
self._universe.atoms, self._universe.atoms.positions))
temp_u.add_bonds(
self.guess_bonds(
self._universe.atoms, self._universe.atoms.positions
)
)

temp_u.add_angles(self.guess_angles(temp_u.atoms.bonds))

Expand Down Expand Up @@ -652,7 +676,12 @@ def guess_gasteiger_charges(self, atomgroup):

mol = atomgroup.convert_to("RDKIT")
from rdkit.Chem.rdPartialCharges import ComputeGasteigerCharges

ComputeGasteigerCharges(mol, throwOnParamFailure=True)
return np.array([atom.GetDoubleProp("_GasteigerCharge")
for atom in mol.GetAtoms()],
dtype=np.float32)
return np.array(
[
atom.GetDoubleProp("_GasteigerCharge")
for atom in mol.GetAtoms()
],
dtype=np.float32,
)
1 change: 1 addition & 0 deletions package/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ tables\.py
| MDAnalysis/visualization/.*\.py
| MDAnalysis/lib/.*\.py^
| MDAnalysis/transformations/.*\.py
| MDAnalysis/guesser/.*\.py
| MDAnalysis/converters/.*\.py
| MDAnalysis/selections/.*\.py
)
Expand Down
Loading

0 comments on commit a10e23e

Please sign in to comment.