Skip to content

Commit

Permalink
finalized details and added some more docs
Browse files Browse the repository at this point in the history
It still doesn't work fully. But the basis should be there.

Signed-off-by: Nick Papior <[email protected]>
  • Loading branch information
zerothi committed Oct 11, 2024
1 parent ec57db1 commit 55ca9a0
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 41 deletions.
32 changes: 24 additions & 8 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -635,36 +635,52 @@ def wrapped_method(*args, **kwargs):
assign_nested_attribute(cls, method_path, wrapped_method)

# I don't really see why this is required?
# head, tail, *_ = method_path.split(".")
# setattr(
# getattr(cls, head), tail, wrapped_method,
# )
head, tail, *_ = method_path.split(".")
setattr(
getattr(cls, head),
tail,
wrapped_method,
)


def assign_class_dispatcher_methods(
cls: object,
dispatcher_name: str,
dispatcher_name: Union[str, tuple[str, str]],
signature_add_self: bool = False,
as_attributes: bool = False,
):
"""Document all methods in a dispatcher class as nested methods in the owner class."""

if isinstance(dispatcher_name, str):
dispatcher_name = (dispatcher_name, "dispatch")

dispatcher_name, method_name = dispatcher_name
dispatcher = getattr(cls, dispatcher_name)

_log.info("assign_class_dispatcher_methods found dispatcher: {dispatcher}")
for key, method in dispatcher._dispatchs.items():
if not isinstance(key, str):
# TODO do not know yet what to do with object types used as extractions
continue

if method_name is None:
dispatch = method
else:
dispatch = getattr(method, method_name)

path = f"{dispatcher_name}.{key}"
_log.info("assign_class_dispatcher_methods assigning attribute: {path}")
# if dispatcher_name == "new":
# print(cls, dispatcher_name, path, method, dispatch, dispatch.__doc__)
# if dispatcher_name == "to":
# print(cls, dispatcher_name, path, method, dispatch, dispatch.__doc__)
if as_attributes:
assign_nested_attribute(cls, path, method.dispatch)
assign_nested_attribute(cls, path, dispatch)
else:
assign_nested_method(
cls,
path,
method.dispatch,
dispatch,
signature_add_self=signature_add_self,
)

Expand Down Expand Up @@ -748,7 +764,7 @@ def yield_types(obj: object, classes):

for name, attr in yield_types(obj, sisl._dispatcher.AbstractDispatcher):
# Fix the class dispatchers methods
assign_class_dispatcher_methods(obj, name, as_attributes=True)
assign_class_dispatcher_methods(obj, name, as_attributes=False)
# Collect all the different names where a dispatcher is associated.
# In this way we die if we add a new one, without documenting it!
_found_dispatch_attributes.add(name)
Expand Down
44 changes: 28 additions & 16 deletions src/sisl/_core/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -4443,18 +4443,12 @@ def __call__(self, parser, ns, value, option_string=None):

# Define base-class for this
class GeometryNewDispatch(AbstractDispatch):
"""Base dispatcher from class passing arguments to Geometry class
This forwards all `__call__` calls to `dispatch`
"""

def __call__(self, *args, **kwargs):
return self.dispatch(*args, **kwargs)
"""Base dispatcher from class passing arguments to Geometry class"""


# Bypass regular Geometry to be returned as is
class GeometryNewGeometryDispatch(GeometryNewDispatch):
def dispatch(self, geometry, copy=False):
def dispatch(self, geometry, copy: bool = False) -> Geometry:
"""Return Geometry, for sanitization purposes"""
cls = self._get_class()
if cls != geometry.__class__:
Expand All @@ -4473,7 +4467,7 @@ def dispatch(self, geometry, copy=False):


class GeometryNewFileDispatch(GeometryNewDispatch):
def dispatch(self, *args, **kwargs):
def dispatch(self, *args, **kwargs) -> Geometry:
"""Defer the `Geometry.read` method by passing down arguments"""
cls = self._get_class()
return cls.read(*args, **kwargs)
Expand All @@ -4485,8 +4479,8 @@ def dispatch(self, *args, **kwargs):


class GeometryNewAseDispatch(GeometryNewDispatch):
def dispatch(self, aseg, **kwargs):
"""Convert an ``ase`` object into a `Geometry`"""
def dispatch(self, aseg, **kwargs) -> Geometry:
"""Convert an `ase.Atoms` object into a `Geometry`"""
cls = self._get_class()
Z = aseg.get_atomic_numbers()
xyz = aseg.get_positions()
Expand All @@ -4511,7 +4505,7 @@ def dispatch(self, aseg, **kwargs):


class GeometryNewpymatgenDispatch(GeometryNewDispatch):
def dispatch(self, struct, **kwargs):
def dispatch(self, struct, **kwargs) -> Geometry:
"""Convert a ``pymatgen`` structure/molecule object into a `Geometry`"""
from pymatgen.core import Structure

Expand Down Expand Up @@ -4556,7 +4550,16 @@ class GeometryToDispatch(AbstractDispatch):


class GeometryToSileDispatch(GeometryToDispatch):
def dispatch(self, *args, **kwargs):
def dispatch(self, *args, **kwargs) -> None:
"""Writes the geometry to a sile with any optional arguments.
Examples
--------
>>> geom = si.geom.graphene()
>>> geom.to("hello.xyz")
>>> geom.to(pathlib.Path("hello.xyz"))
"""
geom = self._get_object()
return geom.write(*args, **kwargs)

Expand All @@ -4569,7 +4572,8 @@ def dispatch(self, *args, **kwargs):


class GeometryToAseDispatch(GeometryToDispatch):
def dispatch(self, **kwargs):
def dispatch(self, **kwargs) -> ase.Atoms:
"""Conversion of `Geometry` to an `ase.Atoms` object"""
from ase import Atoms as ase_Atoms

geom = self._get_object()
Expand All @@ -4591,7 +4595,13 @@ def dispatch(self, **kwargs):


class GeometryTopymatgenDispatch(GeometryToDispatch):
def dispatch(self, **kwargs):
def dispatch(
self, **kwargs
) -> Union[pymatgen.core.Molecule, pymatgen.core.Structure]:
"""Conversion of `Geometry` to a `pymatgen` object.
Depending on the periodicity, it can be `Molecule` or `Structure`.
"""
from pymatgen.core import Lattice, Molecule, Structure

from sisl._core.atom import PeriodicTable
Expand All @@ -4615,7 +4625,9 @@ def dispatch(self, **kwargs):


class GeometryToDataframeDispatch(GeometryToDispatch):
def dispatch(self, *args, **kwargs):
def dispatch(self, *args, **kwargs) -> pandas.DataFrame:
"""Convert the geometry to a `pandas.DataFrame` with values stored in columns"""

import pandas as pd

geom = self._get_object()
Expand Down
48 changes: 32 additions & 16 deletions src/sisl/_core/lattice.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@


class BoundaryCondition(IntEnum):
"""Enum for boundary conditions"""

UNKNOWN = auto()
PERIODIC = auto()
DIRICHLET = auto()
Expand Down Expand Up @@ -1165,17 +1167,11 @@ def __setstate__(self, d):

# Define base-class for this
class LatticeNewDispatch(AbstractDispatch):
"""Base dispatcher from class passing arguments to Geometry class
This forwards all `__call__` calls to `dispatch`
"""

def __call__(self, *args, **kwargs):
return self.dispatch(*args, **kwargs)
"""Base dispatcher from class passing arguments to Lattice class"""


class LatticeNewLatticeDispatch(LatticeNewDispatch):
def dispatch(self, lattice, copy: bool = False):
def dispatch(self, lattice, copy: bool = False) -> Lattice:
"""Return Lattice as-is, for sanitization purposes"""
cls = self._get_class()
if cls != lattice.__class__:
Expand All @@ -1195,7 +1191,14 @@ def dispatch(self, lattice, copy: bool = False):


class LatticeNewListLikeDispatch(LatticeNewDispatch):
def dispatch(self, cell, *args, **kwargs):
def dispatch(self, cell, *args, **kwargs) -> Lattice:
"""Converts simple `array-like` variables to a `Lattice`
Examples
--------
>>> Lattice.new([1, 2, 3]) == Lattice([1, 2, 3])
"""
return Lattice(cell, *args, **kwargs)


Expand All @@ -1210,7 +1213,8 @@ def dispatch(self, cell, *args, **kwargs):


class LatticeNewAseDispatch(LatticeNewDispatch):
def dispatch(self, aseg):
def dispatch(self, aseg) -> Lattice:
"""`ase.Cell` conversion to `Lattice`"""
cls = self._get_class(allow_instance=True)
cell = aseg.get_cell()
nsc = [3 if pbc else 1 for pbc in aseg.pbc]
Expand All @@ -1234,7 +1238,7 @@ def dispatch(self, aseg):


class LatticeNewFileDispatch(LatticeNewDispatch):
def dispatch(self, *args, **kwargs):
def dispatch(self, *args, **kwargs) -> Lattice:
"""Defer the `Lattice.read` method by passing down arguments"""
cls = self._get_class()
return cls.read(*args, **kwargs)
Expand All @@ -1250,7 +1254,8 @@ class LatticeToDispatch(AbstractDispatch):


class LatticeToAseDispatch(LatticeToDispatch):
def dispatch(self, **kwargs):
def dispatch(self, **kwargs) -> ase.Cell:
"""`Lattice` conversion to an `ase.Cell` object."""
from ase import Cell as ase_Cell

lattice = self._get_object()
Expand All @@ -1261,7 +1266,16 @@ def dispatch(self, **kwargs):


class LatticeToSileDispatch(LatticeToDispatch):
def dispatch(self, *args, **kwargs):
def dispatch(self, *args, **kwargs) -> Any:
"""`Lattice` writing to a sile.
Examples
--------
>>> geom = si.geom.graphene()
>>> geom.lattice.to("hello.xyz")
>>> geom.lattice.to(pathlib.Path("hello.xyz"))
"""
lattice = self._get_object()
return lattice.write(*args, **kwargs)

Expand All @@ -1274,7 +1288,8 @@ def dispatch(self, *args, **kwargs):


class LatticeToCuboidDispatch(LatticeToDispatch):
def dispatch(self, center=None, origin=None, orthogonal=False):
def dispatch(self, center=None, origin=None, orthogonal=False) -> Cuboid:
"""Convert lattice parameters to a `Cuboid`"""
lattice = self._get_object()

cell = lattice.cell.copy()
Expand Down Expand Up @@ -1436,12 +1451,13 @@ def boundary_condition(self, boundary_condition: Sequence[BoundaryConditionType]

@property
def pbc(self) -> np.ndarray:
f"""{Lattice.pbc.__doc__}"""
__doc__ = Lattice.pbc.__doc__
return self.lattice.pbc


class LatticeNewLatticeChildDispatch(LatticeNewDispatch):
def dispatch(self, obj, copy: bool = False):
def dispatch(self, obj, copy: bool = False) -> Lattice:
"""Extraction of `Lattice` object from a `LatticeChild` object."""
# for sanitation purposes
if copy:
return obj.lattice.copy()
Expand Down
2 changes: 1 addition & 1 deletion src/sisl/_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,7 +524,7 @@ def register(self, key, dispatch, default=False, overwrite=True, to_class=True):
if isinstance(cls_dispatch, ClassDispatcher):
cls_dispatch.register(key, dispatch, overwrite=overwrite)

def __call__(self, obj, *args, **kwargs):
def __call__(self, obj: Any, *args, **kwargs) -> Any:
# A call on a TypeDispatcher forces at least a single argument
# where the type is being dispatched.

Expand Down

0 comments on commit 55ca9a0

Please sign in to comment.