diff --git a/phiml/math/_functional.py b/phiml/math/_functional.py index 6a5cd237..375aba48 100644 --- a/phiml/math/_functional.py +++ b/phiml/math/_functional.py @@ -8,14 +8,14 @@ import numpy as np from . import _ops as math, all_available -from ._magic_ops import stack, pack_dims, expand, unpack_dim +from ._magic_ops import stack, slice_ from ._shape import EMPTY_SHAPE, Shape, spatial, instance, batch, channel, merge_shapes, DimFilter, shape from ._sparse import SparseCoordinateTensor from ._tensors import Tensor, disassemble_tree, assemble_tree, disassemble_tensors, assemble_tensors, variable_attributes, wrap, specs_equal, equality_by_shape_and_value from ._trace import ShiftLinTracer, matrix_from_function, LinearTraceInProgress -from ..backend import Backend, NUMPY -from ..backend._backend import get_spatial_derivative_order, functional_derivative_evaluation, ML_LOGGER from .magic import PhiTreeNode, Shapable +from ..backend import Backend +from ..backend._backend import get_spatial_derivative_order, functional_derivative_evaluation, ML_LOGGER from ..backend._dtype import DType X = TypeVar('X') @@ -1180,12 +1180,13 @@ def map_(function, *args, dims: DimFilter = shape, range=range, unwrap_scalars=T sliceable_kwargs = {k: v for k, v in kwargs.items() if isinstance(v, Shapable)} extra_args = [v for v in args if not isinstance(v, Shapable)] extra_kwargs = {k: v for k, v in kwargs.items() if not isinstance(v, Shapable)} - dims = merge_shapes(*sliceable_args, *sliceable_kwargs.values()).only(dims) + dims = merge_shapes(*sliceable_args, *sliceable_kwargs.values(), allow_varying_sizes=True).only(dims) + assert dims.well_defined, f"All arguments must have consistent sizes for all mapped dimensions. Trying to map along {dims} but some have varying sizes (marked as None)." assert dims.volume > 0, f"map dims must have volume > 0 but got {dims}" results = [] for _, idx in zip(range(dims.volume), dims.meshgrid()): - idx_args = [v[idx] for v in sliceable_args] - idx_kwargs = {k: v[idx] for k, v in sliceable_kwargs.items()} + idx_args = [slice_(v, idx) for v in sliceable_args] + idx_kwargs = {k: slice(v, idx) for k, v in sliceable_kwargs.items()} if unwrap_scalars: idx_args = [v.native() if isinstance(v, Tensor) and v.rank == 0 else v for v in idx_args] idx_kwargs = {k: v.native() if isinstance(v, Tensor) and v.rank == 0 else v for k, v in idx_kwargs.items()}