Skip to content

Commit

Permalink
added many more functions, also from Grid
Browse files Browse the repository at this point in the history
These accommodations required an upgrade of
the typing system.

Sadly there is a *bug*/*feature-request*
that broke >=3.9 with missing names.
It is related to this:
 https://bugs.python.org/issue41987

Basically the AtomsArgument can't be used
in the type-hinting because the ForwardRef lookup
tries to resolve it, regardless of whether we are
in a TYPE_CHECKING state or not. Hence it *must*
be defined before we can use it.

This happens, even if __future__ annotations
has been imported.

This forces us to define them via strings manually,
*sigh*.

Signed-off-by: Nick Papior <[email protected]>
  • Loading branch information
zerothi committed Jan 16, 2024
1 parent cce3ac3 commit 0816909
Show file tree
Hide file tree
Showing 10 changed files with 1,236 additions and 1,136 deletions.
753 changes: 738 additions & 15 deletions src/sisl/_core/_geometry_ufuncs.py

Large diffs are not rendered by default.

125 changes: 123 additions & 2 deletions src/sisl/_core/_grid_ufuncs.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,140 @@
# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at https://mozilla.org/MPL/2.0/.
from __future__ import annotations

from typing import TYPE_CHECKING, Optional, Sequence, Union

import numpy as np

import sisl._array as _a
from sisl._ufuncs import register_sisl_dispatch
from sisl.messages import SislError

from .grid import Grid

GridLike = "GridLike"
if TYPE_CHECKING:
from sisl.typing import GridLike

# Nothing gets exposed here
__all__ = []


@register_sisl_dispatch(module="sisl")
def copy(grid: Grid, dtype=None):
"""Copy the object, possibly changing the data-type"""
d = grid._sc_geometry_dict()
if dtype is None:
d["dtype"] = grid.dtype
else:
d["dtype"] = dtype
out = grid.__class__([1] * 3, **d)
# This also ensures the shape is copied!
out.grid = grid.grid.astype(dtype=d["dtype"])
return out


@register_sisl_dispatch(module="sisl")
def swapaxes(grid: Grid, axis_a: int, axis_b: int):
"""Swap two axes in the grid (also swaps axes in the lattice)
If ``swapaxes(0, 1)`` it returns the 0 in the 1 values.
Parameters
----------
axis_a, axis_b :
axes indices to be swapped
"""
# Create index vector
idx = _a.arangei(3)
idx[axis_b] = axis_a
idx[axis_a] = axis_b
s = np.copy(grid.shape)
d = grid._sc_geometry_dict()
d["lattice"] = d["lattice"].swapaxes(axis_a, axis_b)
d["dtype"] = grid.dtype
out = grid.__class__(s[idx], **d)
# We need to force the C-order or we loose the contiguity
out.grid = np.copy(np.swapaxes(grid.grid, axis_a, axis_b), order="C")
return out


@register_sisl_dispatch(module="sisl")
def sub(grid: Grid, idx: Union[int, Sequence[int]], axis: int):
"""Retains certain indices from a specified axis.
Works exactly opposite to `remove`.
Parameters
----------
idx :
the indices of the grid axis `axis` to be retained
axis :
the axis segment from which we retain the indices `idx`
"""
idx = _a.asarrayi(idx).ravel()
shift_geometry = False
if len(idx) > 1:
if np.allclose(np.diff(idx), 1):
shift_geometry = not grid.geometry is None

if shift_geometry:
out = grid._copy_sub(len(idx), axis)
min_xyz = out.dcell[axis, :] * idx[0]
# Now shift the geometry according to what is retained
geom = out.geometry.translate(-min_xyz)
geom.set_lattice(out.lattice)
out.set_geometry(geom)
else:
out = grid._copy_sub(len(idx), axis, scale_geometry=True)

# Remove the indices
# First create the opposite, index
if axis == 0:
out.grid[:, :, :] = grid.grid[idx, :, :]
elif axis == 1:
out.grid[:, :, :] = grid.grid[:, idx, :]
elif axis == 2:
out.grid[:, :, :] = grid.grid[:, :, idx]

return out


@register_sisl_dispatch(module="sisl")
def remove(grid: Grid, idx: Union[int, Sequence[int]], axis: int):
"""Removes certain indices from a specified axis.
Works exactly opposite to `sub`.
Parameters
----------
idx :
the indices of the grid axis `axis` to be removed
axis :
the axis segment from which we remove all indices `idx`
"""
ret_idx = np.delete(_a.arangei(grid.shape[axis]), _a.asarrayi(idx))
return grid.sub(ret_idx, axis)


@register_sisl_dispatch(module="sisl")
def append(grid: Grid, other: GridLike, axis: int):
"""Appends other `Grid` to this grid along axis"""
shape = list(grid.shape)
other = grid.new(other)
shape[axis] += other.shape[axis]
d = grid._sc_geometry_dict()
if "geometry" in d:
if not other.geometry is None:
d["geometry"] = d["geometry"].append(other.geometry, axis)
else:
d["geometry"] = other.geometry
d["lattice"] = grid.lattice.append(other.lattice, axis)
d["dtype"] = grid.dtype
return grid.__class__(shape, **d)


@register_sisl_dispatch(module="sisl")
def tile(grid: Grid, reps: int, axis: int):
"""Tile grid to create a bigger one
Expand All @@ -20,8 +143,6 @@ def tile(grid: Grid, reps: int, axis: int):
Parameters
----------
grid : Grid
the object to act on
reps :
number of tiles (repetitions)
axis :
Expand Down
Loading

0 comments on commit 0816909

Please sign in to comment.