Skip to content

Commit

Permalink
Shape refactor 4
Browse files Browse the repository at this point in the history
  • Loading branch information
holl- committed Dec 21, 2024
1 parent c161696 commit 26296ab
Show file tree
Hide file tree
Showing 6 changed files with 195 additions and 117 deletions.
64 changes: 39 additions & 25 deletions phiml/math/_magic_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
from numbers import Number
from typing import TypeVar, Tuple, Dict, Union, Optional, Sequence, Any, Callable

from . import channel
from ._shape import Shape, DimFilter, batch, instance, shape, non_batch, merge_shapes, concat_shapes, spatial, parse_dim_order, dual, auto, parse_shape_spec, DIM_FUNCTIONS, INV_CHAR
from . import channel, EMPTY_SHAPE
from ._shape import Shape, DimFilter, batch, instance, shape, non_batch, merge_shapes, concat_shapes, spatial, parse_dim_order, dual, auto, parse_shape_spec, DIM_FUNCTIONS, \
INV_CHAR, concat_shapes_, Dim
from .magic import Sliceable, Shaped, Shapable, PhiTreeNode
from ..backend import choose_backend, NoBackendFound
from ..backend._dtype import DType
Expand Down Expand Up @@ -88,7 +89,7 @@ def unstack(value, dim: DimFilter) -> tuple:
(0.0, 0.0, 0.0, 0.0, 0.0)
"""
assert isinstance(value, Sliceable) and isinstance(value, Shaped), f"Cannot unstack {type(value).__name__}. Must be Sliceable and Shaped, see https://tum-pbs.github.io/PhiML/phiml/math/magic.html"
dims = shape(value).only(dim)
dims = shape(value).only(dim, reorder=True)
if dims.rank == 0:
return value,
if dims.rank == 1:
Expand All @@ -102,7 +103,7 @@ def unstack(value, dim: DimFilter) -> tuple:
else: # multiple dimensions
if hasattr(value, '__pack_dims__'):
packed_dim = batch('_unstack')
value_packed = value.__pack_dims__(dims.names, packed_dim, pos=None)
value_packed = value.__pack_dims__(dims, packed_dim, pos=None)
if value_packed is not NotImplemented:
return unstack(value_packed, packed_dim)
unstack_dim = _any_uniform_dim(dims)
Expand Down Expand Up @@ -547,41 +548,54 @@ def rename_dims(value: PhiTreeNodeType,
Same type as `value`.
"""
if isinstance(value, Shape):
return value._replace_names_and_types(dims, names)
old_dims, new_dims = _shape_replace(value, dims, names)
return value.replace(old_dims, new_dims)
elif isinstance(value, (Number, bool)):
return value
assert isinstance(value, Shapable) and isinstance(value, Shaped), f"value must be a Shape or Shapable but got {type(value).__name__}"
dims = shape(value).only(dims).names if callable(dims) else parse_dim_order(dims)
existing_dims = shape(value).only(dims, reorder=True)
if isinstance(names, str) and names.startswith('(') and names.endswith(')'):
item_names = [s.strip() for s in names[1:-1].split(',')]
names = [shape(value)[d].with_size(item_names) for d in dims]
elif isinstance(names, str):
names = parse_dim_order(names)
elif callable(names):
names = names(**existing_dims.untyped_dict)
dims = existing_dims
assert len(dims) == len(names), f"names and dims must be of equal length but got #dims={len(dims)} and #names={len(names)}"
if not existing_dims:
old_dims, new_dims = _shape_replace(shape(value), dims, names)
if not new_dims:
return value
existing_names = [n for i, n in enumerate(names) if dims[i] in existing_dims]
existing_names = existing_dims._replace_names_and_types(existing_dims, existing_names)
# --- First try __replace_dims__ ---
if hasattr(value, '__replace_dims__'):
result = value.__replace_dims__(existing_dims.names, existing_names, **kwargs)
result = value.__replace_dims__(old_dims.names, new_dims, **kwargs)
if result is not NotImplemented:
return result
# --- Next try Tree Node ---
if isinstance(value, PhiTreeNode):
return tree_map(rename_dims, value, all_attributes, treat_layout_as_leaf=True, dims=existing_dims, names=existing_names, **kwargs)
return tree_map(rename_dims, value, all_attributes, treat_layout_as_leaf=True, dims=old_dims, names=new_dims, **kwargs)
# --- Fallback: unstack and stack ---
if shape(value).only(existing_dims).volume > 8:
warnings.warn(f"rename_dims() default implementation is slow on large dimensions ({existing_dims}). Please implement __replace_dims__() for {type(value).__name__} as defined in phiml.math.magic", RuntimeWarning, stacklevel=2)
for old_name, new_dim in zip(existing_dims.names, existing_names):
if shape(value).only(old_dims).volume > 8:
warnings.warn(f"rename_dims() default implementation is slow on large dimensions ({old_dims}). Please implement __replace_dims__() for {type(value).__name__} as defined in phiml.math.magic", RuntimeWarning, stacklevel=2)
for old_name, new_dim in zip(old_dims.names, new_dims):
value = stack(unstack(value, old_name), new_dim, **kwargs)
return value


def _shape_replace(shape: Shape, dims: DimFilter, new: DimFilter) -> Tuple[Shape, Shape]: # _replace_names_and_types
if callable(dims):
existing = dims(shape)
elif isinstance(dims, Shape):
existing = dims.only(shape)
else:
dims = parse_dim_order(dims)
existing = shape.only(dims, reorder=True)
if not existing:
return EMPTY_SHAPE, EMPTY_SHAPE
if isinstance(new, str) and new.startswith('(') and new.endswith(')'):
item_names = [s.strip() for s in new[1:-1].split(',')]
new = concat_shapes_(*[d.with_size(item_names) for d in existing])
elif isinstance(new, str):
new = parse_dim_order(new)
assert len(new) == len(dims), f"Number of names {new} does not match dims to replace {dims}"
new = concat_shapes_(*[Dim(n, dim.size, dim.dim_type, dim.slice_names) for dim, n in zip(existing, new)])
elif callable(new):
new = new(**existing.untyped_dict)
assert len(dims) == len(new), f"Number of names {new} does not match dims to replace {dims}"
return existing, new



def b2i(value: PhiTreeNodeType) -> PhiTreeNodeType:
""" Change the type of all *batch* dimensions of `value` to *instance* dimensions. See `rename_dims`. """
return rename_dims(value, batch, instance)
Expand Down Expand Up @@ -668,7 +682,7 @@ def pack_dims(value, dims: DimFilter, packed_dim: Union[Shape, str], pos: Option
return unpack_dim(value, dims, packed_dim, **kwargs)
# --- First try __pack_dims__ ---
if hasattr(value, '__pack_dims__'):
result = value.__pack_dims__(dims.names, packed_dim, pos, **kwargs)
result = value.__pack_dims__(dims, packed_dim, pos, **kwargs)
if result is not NotImplemented:
return result
# --- Next try Tree Node ---
Expand Down
4 changes: 2 additions & 2 deletions phiml/math/_nd.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from . import extrapolation as extrapolation
from ._magic_ops import stack, rename_dims, concat, tree_map, value_attributes
from ._ops import choose_backend_t, reshaped_native, reshaped_tensor
from ._shape import Shape, channel, batch, spatial, DimFilter, parse_dim_order, instance, dual, auto, non_batch
from ._shape import Shape, channel, batch, spatial, DimFilter, parse_dim_order, instance, dual, auto, non_batch, after_gather
from ._tensors import Tensor, wrap, tensor, reshaped_numpy
from .extrapolation import Extrapolation
from .magic import PhiTreeNode
Expand Down Expand Up @@ -910,6 +910,6 @@ def perform_query(np_query):
def perform_query(np_vectors, np_query):
return KDTree(np_vectors).query(np_query)[1]
native_idx = b.numpy_call(perform_query, (query.shape.without(batch(vectors)).non_channel.volume,), DType(int, 64), native_vectors, native_query)
native_multi_idx = choose_backend(native_idx).unravel_index(native_idx, vectors.shape.after_gather(i).non_channel.sizes)
native_multi_idx = choose_backend(native_idx).unravel_index(native_idx, after_gather(vectors.shape, i).non_channel.sizes)
result.append(reshaped_tensor(native_multi_idx, [query_i.shape.non_channel, index_dim or math.EMPTY_SHAPE]))
return stack(result, batch(vectors))
Loading

0 comments on commit 26296ab

Please sign in to comment.