Refactor Bug fixes
holl- committed Dec 26, 2024
1 parent f38445a commit 39a7ab9
Showing 4 changed files with 88 additions and 24 deletions.
24 changes: 13 additions & 11 deletions phiml/math/_ops.py
@@ -12,7 +12,8 @@
from ._magic_ops import expand, pack_dims, unpack_dim, cast, value_attributes, bool_to_int, tree_map, concat, stack, unstack, rename_dims, slice_, all_attributes, squeeze, ipack
from ._shape import (Shape, EMPTY_SHAPE,
spatial, batch, channel, instance, merge_shapes, parse_dim_order, concat_shapes,
IncompatibleShapes, DimFilter, non_batch, dual, shape, shape as get_shape, primal, auto, non_spatial, non_dual, resolve_index, concat_shapes_, SHAPE_TYPES)
IncompatibleShapes, DimFilter, non_batch, dual, shape, shape as get_shape, primal, auto, non_spatial, non_dual, resolve_index, concat_shapes_, SHAPE_TYPES,
Dim)
from . import extrapolation as e_
from ._tensors import (Tensor, wrap, tensor, broadcastable_native_tensors, Dense, TensorStack,
custom_op2, compatible_tensor, variable_attributes, disassemble_tree, assemble_tree,
@@ -1164,31 +1165,32 @@ def broadcast_op(operation: Callable,
else:
if isinstance(iter_dims, SHAPE_TYPES):
iter_dims = iter_dims.names
dim = next(iter(iter_dims))
dim_name = next(iter(iter_dims))
dim_type = None
size = None
item_names = None
unstacked = []
for tensor in tensors:
if dim in tensor.shape.names:
unstacked_tensor = tensor._unstack(dim)
if dim_name in tensor.shape:
dim = tensor.shape[dim_name]
unstacked_tensor = tensor._unstack(dim_name)
unstacked.append(unstacked_tensor)
if size is None:
size = len(unstacked_tensor)
dim_type = tensor.shape.get_type(dim)
dim_type = dim.dim_type
else:
assert size == len(unstacked_tensor)
assert dim_type == tensor.shape.get_type(dim)
assert dim_type == dim.dim_type
if item_names is None:
item_names = tensor.shape.get_item_names(dim)
item_names = dim.slice_names
else:
unstacked.append(tensor)
result_unstacked = []
for i in range(size):
gathered = [t[i] if isinstance(t, tuple) else t for t in unstacked]
result_unstacked.append(broadcast_op(operation, gathered, iter_dims=set(iter_dims) - {dim}))
result_unstacked.append(broadcast_op(operation, gathered, iter_dims=set(iter_dims) - {dim_name}))
if not no_return:
return stack(result_unstacked, Shape((size,), (dim,), (dim_type,), (item_names,)))
return stack(result_unstacked, Dim(dim_name, size, dim_type, item_names))
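For context, here is a hedged, dependency-free sketch of the unstack-apply-restack pattern that broadcast_op implements (plain dicts mapping a dim name to a list of slices stand in for tensors; this is not the phiml code):

def broadcast_apply(operation, tensors, dim_name):
    has_dim = lambda t: isinstance(t, dict) and dim_name in t
    size = next(len(t[dim_name]) for t in tensors if has_dim(t))
    results = []
    for i in range(size):
        # Inputs lacking dim_name are broadcast, i.e. passed whole to every slice call.
        results.append(operation([t[dim_name][i] if has_dim(t) else t for t in tensors]))
    return {dim_name: results}  # restack the per-slice results along dim_name

print(broadcast_apply(sum, [{'x': [1, 2, 3]}, 5], 'x'))  # -> {'x': [6, 7, 8]}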


def where(condition: Union[Tensor, float, int],
@@ -2695,7 +2697,7 @@ def boolean_mask(x, dim: DimFilter, mask: Tensor, preserve_names=False):
if is_sparse(x):
indices = nonzero(mask, list_dim=instance('_boolean_mask'))
result = x[indices]
return rename_dims(result, '_boolean_mask', mask.shape.non_channel)
return result.__replace_dims__(('_boolean_mask',), mask.shape.non_channel)
if not isinstance(x, Tensor) or is_sparse(x):
keep_slices = nonzero_slices(mask)
x_slices = [x[s] for s in keep_slices]
@@ -2986,7 +2988,7 @@ def scatter_forward(base_grid: Tensor, indices: Tensor, values: Tensor, indexed_
backend = backend_for(indices, values, base_grid)
native_grid = base_grid._reshaped_native([batches, *indexed_dims, channels])
native_values = values._reshaped_native([batches, lists, channels])
native_indices = indices._reshaped_native([batches, lists, channel])
native_indices = indices._reshaped_native([batches, lists, indices.shape.channel])
if mode != 'mean':
native_result = backend.scatter(native_grid, native_indices, native_values, mode=mode)
else: # mean
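A hedged illustration of the one-line fix above: `channel` (imported from `._shape`) is a dim filter function, not a Shape, so `_reshaped_native` was handed the filter itself; the fix passes the concrete channel dims of `indices` instead. Dim names below are made up.

from phiml.math import wrap, channel, instance
idx = wrap([[0, 1], [2, 3]], instance('points'), channel(vector='x,y'))
print(channel)             # the filter/constructor function itself, not a Shape
print(idx.shape.channel)   # concrete channel dims, e.g. (vector='x,y'); this is what the fix passes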
65 changes: 57 additions & 8 deletions phiml/math/_shape.py
@@ -676,6 +676,19 @@ def replace(self, dims: Union['Shape', str, tuple, list], new: 'Shape') -> 'Shape':
"""
...

def replace_selection(self, names: Sequence[str], new: 'Shape') -> 'Shape':
"""
Replace some of the dims of this shape.
Args:
names: Sequence of dim names.
new: Replacement dims, must have the same length as `names`.
Returns:
Copy of `self` with replaced dims.
"""
...
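A hypothetical usage sketch of the new method (dim names and sizes are made up); dims not listed in `names` are kept as-is:

from phiml.math import spatial
s = spatial(x=10, y=20)
print(s.replace_selection(['x'], spatial(u=10)))  # assumed: 'x' renamed to 'u', 'y' untouched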

@property
def volume(self) -> Union[int, None]:
"""
@@ -1006,7 +1019,9 @@ def is_compatible(self, other):
return self.dim_type == dim.dim_type

def isdisjoint(self, other) -> bool:
return self.name not in other
if isinstance(other, SHAPE_TYPES):
return self.name not in other
return self.name not in parse_dim_order(other)
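With this change, `other` may be given as a comma-separated string or a name sequence rather than a Shape; a hedged example:

from phiml.math import spatial
dim = spatial(x=8)
print(dim.isdisjoint('y,z'))   # True: 'x' is not among the parsed names
print(dim.isdisjoint(['x']))   # False: the names overlap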

def only(self, dims: 'DimFilter', reorder=False):
if dims is None: # keep none
@@ -1105,8 +1120,14 @@ def flipped(self, dims: Union[List[str], Tuple[str]]):
return Dim(self.name, self.size, self.dim_type, self.slice_names[::-1])

def replace(self, dims: Union['Shape', str, tuple, list], new: 'Shape'):
assert self.name == parse_dim_names(dims, 1)[0]
return self._replace(new)
old_names = parse_dim_order(dims)
if len(old_names) == 1 and self.name == old_names[0]:
return self._replace(new)
elif self.name in old_names:
idx = old_names.index(self.name)
return self._replace(new[idx])
else:
return self
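The broadened `replace` lets `dims` list several names: a Dim now replaces itself only if its own name appears, picking the matching entry of `new`, and returns itself unchanged otherwise. A hedged sketch (outputs are assumptions):

from phiml.math import spatial
x = spatial(x=8)
print(x.replace(['x', 'y'], spatial(u=8, v=4)))  # assumed: (u=8), the entry matching 'x'
print(x.replace(['y'], spatial(v=4)))            # assumed: (x=8), unchanged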

def _replace(self, new: 'Shape'):
if self.slice_names is None or new.slice_names is not None:
@@ -1116,6 +1137,14 @@ def _replace(self, new: 'Shape'):
else: # keep item names from self
return Dim(new.name, new.size, new.dim_type, self.slice_names)

def replace_selection(self, names: Sequence[str], new: Shape):
if self.name not in names:
return self
new = new[names.index(self.name)]
if self.slice_names is None or new.item_names[0] is not None or not _size_equal(self.size, new.size):
return new
return Dim(new.name, new.size, new.dim_type, self.slice_names)
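A hedged example of the item-name rule encoded above: renaming a dim keeps its slice names as long as the replacement brings none of its own and the sizes match.

from phiml.math import channel
vec = channel(vector='x,y,z')
print(vec.replace_selection(['vector'], channel(v=3)))        # assumed: (v=x,y,z), old item names kept
print(vec.replace_selection(['vector'], channel(w='a,b,c')))  # assumed: (w=a,b,c), new item names win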

def as_batch(self):
name = _apply_prefix(self.name, BATCH_DIM) if self.dim_type == DUAL_DIM else self.name
return Dim(name, self.size, BATCH_DIM, self.slice_names)
@@ -1391,7 +1420,9 @@ def isdisjoint(self, other) -> bool:
return other.name not in self.dims
if isinstance(other, PureShape):
return other.dim_type != self.dim_type or self.dims.keys().isdisjoint(other.dims)
return other.isdisjoint(self)
elif isinstance(other, MixedShape):
return other.isdisjoint(self)
return self.dims.keys().isdisjoint(set(parse_dim_order(other)))

def only(self, dims: 'DimFilter', reorder=False):
if not self.dims or dims is None:
@@ -1518,6 +1549,14 @@ def replace(self, dims: Union['Shape', str, tuple, list], new: 'Shape'):
dim_list.insert(i0, new)
return concat_shapes_(*dim_list)

def replace_selection(self, names: Sequence[str], new: Shape):
dim_list = list(self.dims.values())
for old_name, new_dim in zip(names, new):
if old_name in self.dims:
new_dim = self.dims[old_name]._replace(new_dim)
dim_list[self.index(old_name)] = new_dim
return concat_shapes_(*dim_list)

def as_batch(self):
dims = [dim.as_batch() for dim in self.dims.values()]
return PureShape(BATCH_DIM, {dim.name: dim for dim in dims})
@@ -1798,8 +1837,9 @@ def isdisjoint(self, other) -> bool:
return other.name not in self.dims
if isinstance(other, PureShape):
return other.isdisjoint(getattr(self, other.dim_type))
assert isinstance(other, MixedShape)
return self.dims.keys().isdisjoint(other.dims)
if isinstance(other, MixedShape):
return self.dims.keys().isdisjoint(other.dims)
return self.dims.keys().isdisjoint(set(parse_dim_order(other)))

@property
def __empty__(self):
@@ -1912,6 +1952,14 @@ def replace(self, dims: Union['Shape', str, tuple, list], new: 'Shape'):
dim_list.insert(i0, new)
return concat_shapes_(*dim_list)

def replace_selection(self, names: Sequence[str], new: Shape):
dim_list = list(self.dims.values())
for old_name, new_dim in zip(names, new):
if old_name in self.dims:
new_dim = self.dims[old_name].replace_selection(old_name, new_dim)
dim_list[self.index(old_name)] = new_dim
return concat_shapes_(*dim_list)

def as_batch(self):
dims = [dim.as_batch() for dim in self.dims.values()]
return PureShape(BATCH_DIM, {dim.name: dim for dim in dims})
@@ -2421,7 +2469,8 @@ def merge_shapes(*objs: Union[Shape, Any], allow_varying_sizes=False) -> Shape:
i = pure_merge(*[s.instance for s in shapes], allow_varying_sizes=allow_varying_sizes)
s = pure_merge(*[s.spatial for s in shapes], allow_varying_sizes=allow_varying_sizes)
c = pure_merge(*[s.channel for s in shapes], allow_varying_sizes=allow_varying_sizes)
return MixedShape(b, d, i, s, c, {**b.dims, **d.dims, **i.dims, **s.dims, **c.dims})
dims = {**b.dims, **d.dims, **i.dims, **s.dims, **c.dims}
return MixedShape(b, d, i, s, c, dims) if dims else EMPTY_SHAPE
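A hedged example of the changed return value: merging shapes that contribute no dims now yields `EMPTY_SHAPE` rather than an empty `MixedShape`.

from phiml.math import EMPTY_SHAPE, merge_shapes, spatial
print(merge_shapes(EMPTY_SHAPE, EMPTY_SHAPE) == EMPTY_SHAPE)  # True under this change
print(merge_shapes(spatial(x=4), EMPTY_SHAPE))                # assumed: (x=4), as before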


def pure_merge(*shapes: Shape, allow_varying_sizes: bool) -> Shape:
@@ -2857,7 +2906,7 @@ def transposed(self):


def to_dict(self: Shape, include_sizes=True):
result = dict(names=self.names, types=self.types, item_names=self.item_names)
result = dict(names=self.names, types=self.dim_types, item_names=self.item_names)
if include_sizes:
assert self.is_uniform, f"to_dict(Shape) only supports uniform shapes but got {self}"
result['sizes'] = self.sizes
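A hedged sketch of the serialized form for a small uniform shape (the `types` entry now comes from `Shape.dim_types`; the exact dict is an assumption):

from phiml.math._shape import spatial, to_dict
print(to_dict(spatial(x=4)))
# assumed: {'names': ('x',), 'types': ('spatial',), 'item_names': (None,), 'sizes': (4,)}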
6 changes: 3 additions & 3 deletions phiml/math/_sparse.py
@@ -360,8 +360,8 @@ def _with_shape_replaced(self, new_shape: Shape):
values = self._values._with_shape_replaced(self._values.shape.replace(self._shape, new_shape))
non_vec = self._shape.without('sparse_idx')
new_non_vec = new_shape[self._shape.indices(non_vec.names)]
indices = self._indices._with_shape_replaced(self._indices.shape.replace(non_vec, new_non_vec).with_dim_size('sparse_idx', new_item_names))
m_rank = self._matrix_rank._with_shape_replaced(self._matrix_rank.shape.replace(self._shape, new_shape))
indices = self._indices._with_shape_replaced(self._indices.shape.replace_selection(non_vec.names, new_non_vec).with_dim_size('sparse_idx', new_item_names))
m_rank = self._matrix_rank._with_shape_replaced(self._matrix_rank.shape.replace_selection(self._shape.names, new_shape))
return SparseCoordinateTensor(indices, values, dense_shape, self._can_contain_double_entries, self._indices_sorted, self._indices_constant, m_rank)

def _op1(self, native_function):
@@ -1290,7 +1290,7 @@ def stored_values(x: Tensor, list_dim=instance('entries'), invalid='discard') ->
"""
assert invalid in ['discard', 'clamp', 'keep'], f"invalid handling must be one of 'discard', 'clamp', 'keep' but got {invalid}"
if isinstance(x, Dense):
x = Dense(x._native, x._names, x._shape[x._names])
x = Dense(x._native, x._names, x._shape[x._names], x._backend)
entries_dims = x.shape.non_batch
return pack_dims(x, entries_dims, list_dim)
if isinstance(x, TensorStack):
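A hedged usage sketch of `stored_values` on a dense tensor: all non-batch values are packed into the default instance dim `entries`.

from phiml.math import wrap, spatial, stored_values
x = wrap([[1., 0.], [0., 2.]], spatial('rows,cols'))
print(stored_values(x))  # assumed: a flat (entries=4) tensor containing 1, 0, 0, 2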
17 changes: 15 additions & 2 deletions phiml/math/_tensors.py
@@ -2883,7 +2883,7 @@ def format_row(self: Tensor, options: PrintOptions) -> str: # all values in a s
is_vector = self.shape.name == 'vector' and self.shape.channel_rank == 1
is_dual_vector = self.shape.name == '~vector'
if (not is_vector and not is_dual_vector) if options.include_shape is None else options.include_shape:
content += f" along {colors.shape(f'{self.shape.name}{SUPERSCRIPT[self.shape.type]}')}"
content += f" along {colors.shape(f'{self.shape.name}{SUPERSCRIPT[self.shape.dim_type]}')}"
elif is_dual_vector:
content = "~" + content
else:
@@ -3105,7 +3105,20 @@ def unserialize_spec(spec: dict):
_BACKEND_RULES = {} # Tuple[Backend...] -> Backend


def backend_for(*values: Tensor):
def backend_for(*values: Tensor) -> Backend:
"""
Chooses an appropriate backend based on the backends of `values`.
Args:
*values: Input tensors to some operation.
Returns:
`Backend` that is compatible with all `values`.
Raises:
`NoBackendFound`: If no backend exists that can handle all `values`.
"""
backends = tuple([v.backend for v in values])
result = _BACKEND_RULES.get(backends, None)
if result is not None:
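A hedged sketch of the documented behavior, calling the private helper directly (the import path and printed result are assumptions):

from phiml.math import wrap, spatial
from phiml.math._tensors import backend_for
a = wrap([1., 2.], spatial('x'))  # NumPy-backed by default
b = wrap([3., 4.], spatial('x'))
print(backend_for(a, b))          # assumed: the shared NumPy backend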
