Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[fmt] Guesser #4851

Merged
merged 3 commits into from
Dec 24, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 13 additions & 7 deletions package/MDAnalysis/guesser/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,11 @@ class FooGuesser(GuesserBase):
.. versionadded:: 2.8.0
"""

def __init__(cls, name, bases, classdict):
type.__init__(type, name, bases, classdict)

_GUESSERS[classdict['context'].upper()] = cls
_GUESSERS[classdict["context"].upper()] = cls


class GuesserBase(metaclass=_GuesserMeta):
Expand All @@ -87,7 +88,8 @@ class GuesserBase(metaclass=_GuesserMeta):
.. versionadded:: 2.8.0
"""
context = 'base'

context = "base"
_guesser_methods: Dict = {}

def __init__(self, universe=None, **kwargs):
Expand Down Expand Up @@ -149,8 +151,10 @@ def guess_attr(self, attr_to_guess, force_guess=False):
try:
guesser_method = self._guesser_methods[attr_to_guess]
except KeyError:
raise ValueError(f'{type(self).__name__} cannot guess this '
f'attribute: {attr_to_guess}')
raise ValueError(
f"{type(self).__name__} cannot guess this "
f"attribute: {attr_to_guess}"
)

# Connection attributes should be just returned as they are always
# appended to the Universe. ``force_guess`` handling should happen
Expand All @@ -161,7 +165,8 @@ def guess_attr(self, attr_to_guess, force_guess=False):
# check if the topology already has the attribute to partially guess it
if hasattr(self._universe.atoms, attr_to_guess) and not force_guess:
attr_values = np.array(
getattr(self._universe.atoms, attr_to_guess, None))
getattr(self._universe.atoms, attr_to_guess, None)
)

empty_values = top_attr.are_values_missing(attr_values)

Expand All @@ -175,8 +180,9 @@ def guess_attr(self, attr_to_guess, force_guess=False):

else:
logger.info(
f'There is no empty {attr_to_guess} values. Guesser did '
f'not guess any new values for {attr_to_guess} attribute')
f"There is no empty {attr_to_guess} values. Guesser did "
f"not guess any new values for {attr_to_guess} attribute"
)
return None
else:
return np.array(guesser_method())
Expand Down
153 changes: 91 additions & 62 deletions package/MDAnalysis/guesser/default_guesser.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,8 @@
.. versionadded:: 2.8.0

"""
context = 'default'

context = "default"

def __init__(
self,
Expand All @@ -170,25 +171,25 @@
vdwradii=None,
fudge_factor=0.55,
lower_bound=0.1,
**kwargs
**kwargs,
):
super().__init__(
universe,
box=box,
vdwradii=vdwradii,
fudge_factor=fudge_factor,
lower_bound=lower_bound,
**kwargs
**kwargs,
)
self._guesser_methods = {
'masses': self.guess_masses,
'types': self.guess_types,
'elements': self.guess_types,
'bonds': self.guess_bonds,
'angles': self.guess_angles,
'dihedrals': self.guess_dihedrals,
'impropers': self.guess_improper_dihedrals,
'aromaticities': self.guess_aromaticities,
"masses": self.guess_masses,
"types": self.guess_types,
"elements": self.guess_types,
"bonds": self.guess_bonds,
"angles": self.guess_angles,
"dihedrals": self.guess_dihedrals,
"impropers": self.guess_improper_dihedrals,
"aromaticities": self.guess_aromaticities,
}

def guess_masses(self, atom_types=None, indices_to_guess=None):
Expand Down Expand Up @@ -225,18 +226,21 @@
except NoDataError:
try:
atom_types = self.guess_types(
atom_types=self._universe.atoms.names)
atom_types=self._universe.atoms.names
)
except NoDataError:
raise NoDataError(
"there is no reference attributes"
" (elements, types, or names)"
" in this universe to guess mass from") from None
" in this universe to guess mass from"
) from None

if indices_to_guess is not None:
atom_types = atom_types[indices_to_guess]

masses = np.array([self.get_atom_mass(atom)
for atom in atom_types], dtype=np.float64)
masses = np.array(
[self.get_atom_mass(atom) for atom in atom_types], dtype=np.float64
)
return masses

def get_atom_mass(self, element):
Expand All @@ -256,7 +260,8 @@
"Unknown masses are set to 0.0 for current version, "
"this will be deprecated in version 3.0.0 and replaced by"
" Masse's no_value_label (np.nan)",
PendingDeprecationWarning)
PendingDeprecationWarning,
)
return 0.0

def guess_atom_mass(self, atomname):
Expand Down Expand Up @@ -295,13 +300,16 @@
except NoDataError:
raise NoDataError(
"there is no reference attributes in this universe "
"to guess types from") from None
"to guess types from"
) from None

if indices_to_guess is not None:
atom_types = atom_types[indices_to_guess]

return np.array([self.guess_atom_element(atom)
for atom in atom_types], dtype=object)
return np.array(
[self.guess_atom_element(atom) for atom in atom_types],
dtype=object,
)

def guess_atom_element(self, atomname):
"""Guess the element of the atom from the name.
Expand All @@ -315,7 +323,7 @@
still not found, we iteratively continue to remove the last character
or first character until we find a match. If ultimately no match
is found, the first character of the stripped name is returned.

If the input name is an empty string, an empty string is returned.

The table comes from CHARMM and AMBER atom
Expand All @@ -331,16 +339,16 @@
:func:`guess_atom_type`
:mod:`MDAnalysis.guesser.tables`
"""
NUMBERS = re.compile(r'[0-9]') # match numbers
SYMBOLS = re.compile(r'[*+-]') # match *, +, -
if atomname == '':
return ''
NUMBERS = re.compile(r"[0-9]") # match numbers
SYMBOLS = re.compile(r"[*+-]") # match *, +, -
if atomname == "":
return ""
try:
return tables.atomelements[atomname.upper()]
except KeyError:
# strip symbols and numbers
no_symbols = re.sub(SYMBOLS, '', atomname)
name = re.sub(NUMBERS, '', no_symbols).upper()
no_symbols = re.sub(SYMBOLS, "", atomname)
name = re.sub(NUMBERS, "", no_symbols).upper()

# just in case
if name in tables.atomelements:
Expand Down Expand Up @@ -393,7 +401,7 @@

Raises
------
:exc:`ValueError`
:exc:`ValueError`
If inputs are malformed or `vdwradii` data is missing.


Expand All @@ -410,32 +418,37 @@
if len(atoms) != len(coords):
raise ValueError("'atoms' and 'coord' must be the same length")

fudge_factor = self._kwargs.get('fudge_factor', 0.55)
fudge_factor = self._kwargs.get("fudge_factor", 0.55)

# so I don't permanently change it
vdwradii = tables.vdwradii.copy()
user_vdwradii = self._kwargs.get('vdwradii', None)
user_vdwradii = self._kwargs.get("vdwradii", None)
# this should make algo use their values over defaults
if user_vdwradii:
vdwradii.update(user_vdwradii)

# Try using types, then elements
if hasattr(atoms, 'types'):
if hasattr(atoms, "types"):
atomtypes = atoms.types
else:
atomtypes = self.guess_types(atom_types=atoms.names)

# check that all types have a defined vdw
if not all(val in vdwradii for val in set(atomtypes)):
raise ValueError(("vdw radii for types: " +
", ".join([t for t in set(atomtypes) if
t not in vdwradii]) +
". These can be defined manually using the" +
f" keyword 'vdwradii'"))
raise ValueError(
(
"vdw radii for types: "
+ ", ".join(
[t for t in set(atomtypes) if t not in vdwradii]
)
+ ". These can be defined manually using the"
+ f" keyword 'vdwradii'"
)
)

lower_bound = self._kwargs.get('lower_bound', 0.1)
lower_bound = self._kwargs.get("lower_bound", 0.1)

box = self._kwargs.get('box', None)
box = self._kwargs.get("box", None)

if box is not None:
box = np.asarray(box)
Expand All @@ -447,14 +460,14 @@

bonds = []

pairs, dist = distances.self_capped_distance(coords,
max_cutoff=2.0 * max_vdw,
min_cutoff=lower_bound,
box=box)
pairs, dist = distances.self_capped_distance(
coords, max_cutoff=2.0 * max_vdw, min_cutoff=lower_bound, box=box
)
for idx, (i, j) in enumerate(pairs):
d = (vdwradii[atomtypes[i]] +
vdwradii[atomtypes[j]]) * fudge_factor
if (dist[idx] < d):
d = (
vdwradii[atomtypes[i]] + vdwradii[atomtypes[j]]
) * fudge_factor
if dist[idx] < d:
bonds.append((atoms[i].index, atoms[j].index))
return tuple(bonds)

Expand All @@ -480,18 +493,21 @@
--------
:meth:`guess_bonds`

"""
"""
from ..core.universe import Universe

angles_found = set()

if bonds is None:
if hasattr(self._universe.atoms, 'bonds'):
if hasattr(self._universe.atoms, "bonds"):
bonds = self._universe.atoms.bonds
else:
temp_u = Universe.empty(n_atoms=len(self._universe.atoms))
temp_u.add_bonds(self.guess_bonds(
self._universe.atoms, self._universe.atoms.positions))
temp_u.add_bonds(
self.guess_bonds(
self._universe.atoms, self._universe.atoms.positions
)
)
bonds = temp_u.atoms.bonds

for b in bonds:
Expand All @@ -501,7 +517,8 @@
if other_b != b: # if not the same bond I start as
third_a = other_b.partner(atom)
desc = tuple(
[other_a.index, atom.index, third_a.index])
[other_a.index, atom.index, third_a.index]
)
# first index always less than last
if desc[0] > desc[-1]:
desc = desc[::-1]
Expand Down Expand Up @@ -530,15 +547,18 @@
from ..core.universe import Universe

if angles is None:
if hasattr(self._universe.atoms, 'angles'):
if hasattr(self._universe.atoms, "angles"):
angles = self._universe.atoms.angles

else:
temp_u = Universe.empty(n_atoms=len(self._universe.atoms))

temp_u.add_bonds(self.guess_bonds(
self._universe.atoms, self._universe.atoms.positions))

temp_u.add_bonds(
self.guess_bonds(
self._universe.atoms, self._universe.atoms.positions
)
)

temp_u.add_angles(self.guess_angles(temp_u.atoms.bonds))

angles = temp_u.atoms.angles
Expand All @@ -549,8 +569,9 @@
a_tup = tuple([a.index for a in b]) # angle as tuple of numbers
# if searching with b[0], want tuple of (b[2], b[1], b[0], +new)
# search the first and last atom of each angle
for atom, prefix in zip([b.atoms[0], b.atoms[-1]],
[a_tup[::-1], a_tup]):
for atom, prefix in zip(
[b.atoms[0], b.atoms[-1]], [a_tup[::-1], a_tup]
):
for other_b in atom.bonds:
if not other_b.partner(atom) in b:
third_a = other_b.partner(atom)
Expand Down Expand Up @@ -580,14 +601,17 @@
from ..core.universe import Universe

if angles is None:
if hasattr(self._universe.atoms, 'angles'):
if hasattr(self._universe.atoms, "angles"):
angles = self._universe.atoms.angles

else:
temp_u = Universe.empty(n_atoms=len(self._universe.atoms))

temp_u.add_bonds(self.guess_bonds(
self._universe.atoms, self._universe.atoms.positions))
temp_u.add_bonds(
self.guess_bonds(
self._universe.atoms, self._universe.atoms.positions
)
)

temp_u.add_angles(self.guess_angles(temp_u.atoms.bonds))

Expand Down Expand Up @@ -652,7 +676,12 @@

mol = atomgroup.convert_to("RDKIT")
from rdkit.Chem.rdPartialCharges import ComputeGasteigerCharges

ComputeGasteigerCharges(mol, throwOnParamFailure=True)
return np.array([atom.GetDoubleProp("_GasteigerCharge")
for atom in mol.GetAtoms()],
dtype=np.float32)
return np.array(

Check warning on line 681 in package/MDAnalysis/guesser/default_guesser.py

View check run for this annotation

Codecov / codecov/patch

package/MDAnalysis/guesser/default_guesser.py#L681

Added line #L681 was not covered by tests
[
atom.GetDoubleProp("_GasteigerCharge")
for atom in mol.GetAtoms()
],
dtype=np.float32,
)
1 change: 1 addition & 0 deletions package/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ tables\.py
| MDAnalysis/visualization/.*\.py
| MDAnalysis/lib/.*\.py^
| MDAnalysis/transformations/.*\.py
| MDAnalysis/guesser/.*\.py
)
'''
extend-exclude = '''
Expand Down
Loading
Loading