Skip to content

Commit

Permalink
Support objects without __getitem__ in map()
Browse files Browse the repository at this point in the history
  • Loading branch information
holl- committed Nov 24, 2023
1 parent a5014f8 commit e952a35
Showing 1 changed file with 7 additions and 6 deletions.
13 changes: 7 additions & 6 deletions phiml/math/_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,14 @@
import numpy as np

from . import _ops as math, all_available
from ._magic_ops import stack, pack_dims, expand, unpack_dim
from ._magic_ops import stack, slice_
from ._shape import EMPTY_SHAPE, Shape, spatial, instance, batch, channel, merge_shapes, DimFilter, shape
from ._sparse import SparseCoordinateTensor
from ._tensors import Tensor, disassemble_tree, assemble_tree, disassemble_tensors, assemble_tensors, variable_attributes, wrap, specs_equal, equality_by_shape_and_value
from ._trace import ShiftLinTracer, matrix_from_function, LinearTraceInProgress
from ..backend import Backend, NUMPY
from ..backend._backend import get_spatial_derivative_order, functional_derivative_evaluation, ML_LOGGER
from .magic import PhiTreeNode, Shapable
from ..backend import Backend
from ..backend._backend import get_spatial_derivative_order, functional_derivative_evaluation, ML_LOGGER
from ..backend._dtype import DType

X = TypeVar('X')
Expand Down Expand Up @@ -1180,12 +1180,13 @@ def map_(function, *args, dims: DimFilter = shape, range=range, unwrap_scalars=T
sliceable_kwargs = {k: v for k, v in kwargs.items() if isinstance(v, Shapable)}
extra_args = [v for v in args if not isinstance(v, Shapable)]
extra_kwargs = {k: v for k, v in kwargs.items() if not isinstance(v, Shapable)}
dims = merge_shapes(*sliceable_args, *sliceable_kwargs.values()).only(dims)
dims = merge_shapes(*sliceable_args, *sliceable_kwargs.values(), allow_varying_sizes=True).only(dims)
assert dims.well_defined, f"All arguments must have consistent sizes for all mapped dimensions. Trying to map along {dims} but some have varying sizes (marked as None)."
assert dims.volume > 0, f"map dims must have volume > 0 but got {dims}"
results = []
for _, idx in zip(range(dims.volume), dims.meshgrid()):
idx_args = [v[idx] for v in sliceable_args]
idx_kwargs = {k: v[idx] for k, v in sliceable_kwargs.items()}
idx_args = [slice_(v, idx) for v in sliceable_args]
idx_kwargs = {k: slice(v, idx) for k, v in sliceable_kwargs.items()}
if unwrap_scalars:
idx_args = [v.native() if isinstance(v, Tensor) and v.rank == 0 else v for v in idx_args]
idx_kwargs = {k: v.native() if isinstance(v, Tensor) and v.rank == 0 else v for k, v in idx_kwargs.items()}
Expand Down

0 comments on commit e952a35

Please sign in to comment.