diff --git a/phiml/math/_functional.py b/phiml/math/_functional.py index 50d9573..fb6857f 100644 --- a/phiml/math/_functional.py +++ b/phiml/math/_functional.py @@ -1163,7 +1163,7 @@ def forward_retype(obj, input_types: Dict[str, Callable]): originals = t.shape.only(dims) new_dims = originals.as_type(dim_type) for o, n in zip(originals, new_dims): - input_types[n.name] = o.dim_type + input_types[n.name] = o.type retyped.append(rename_dims(t, originals, new_dims)) return assemble_tree(tree, retyped), input_types diff --git a/phiml/math/_shape.py b/phiml/math/_shape.py index 26043e6..6768413 100644 --- a/phiml/math/_shape.py +++ b/phiml/math/_shape.py @@ -67,7 +67,11 @@ def sizes(self) -> Sequence: ... @property - def types(self) -> Sequence[str]: + def types(self) -> Sequence[Callable]: + ... + + @property + def dim_types(self) -> Sequence[str]: ... @property @@ -398,7 +402,7 @@ def size(self): ... @property - def type(self) -> str: + def type(self) -> Callable: """ Only for Shapes containing exactly one single dimension. Returns the type of the dimension. @@ -809,6 +813,9 @@ def sizes(self): def types(self): return self.type, @property + def dim_types(self): + return self.dim_type, + @property def item_names(self): return self.slice_names, @@ -818,7 +825,7 @@ def untyped_dict(self): @property def type(self) -> str: - return self.dim_type + return DIM_FUNCTIONS[self.dim_type] @property def is_uniform(self) -> bool: @@ -1175,6 +1182,9 @@ def sizes(self): return tuple([d.size for d in self.dims.values()]) @property def types(self): + return [d.type for d in self.dims.values()] + @property + def dim_types(self): return [d.dim_type for d in self.dims.values()] @property def item_names(self): @@ -1490,14 +1500,22 @@ def flipped(self, dims: Union[List[str], Tuple[str]]): def replace(self, dims: Union['Shape', str, tuple, list], new: 'Shape'): dims = parse_dim_order(dims) - dim_list = list(self.dims.values()) if len(dims) == len(new): + dim_list = list(self.dims.values()) for old, new_dim in zip(dims, new): new_dim = self.dims[old]._replace(new_dim) dim_list[self.index(old)] = new_dim elif len(new) > 1 and len(dims) == 1: + dim_list = list(self.dims.values()) i = self.index(dims[0]) dim_list[i:i+1] = new + else: + assert len(new) == 1 + if len(dims) == len(self.dims): + return new + i0 = self.index(dims[0]) + dim_list = [d for n, d in self.dims.items() if n not in dims] + dim_list.insert(i0, new) return concat_shapes_(*dim_list) def as_batch(self): @@ -1565,7 +1583,10 @@ def sizes(self) -> tuple: return sum([dim.sizes for dim in self.dims.values()], ()) @property def types(self): - return sum([dim.types for dim in self.dims.values()], ()) + return [d.type for d in self.dims.values()] + @property + def dim_types(self): + return [d.dim_type for d in self.dims.values()] @property def item_names(self): return sum([dim.item_names for dim in self.dims.values()], ()) @@ -1585,7 +1606,7 @@ def size(self): @property def type(self) -> str: assert len(self.dims) == 1, f"Shape.type is only defined for shapes of rank 1 but has dims {self}" - return next(iter(self.dims.values())).dim_type + return next(iter(self.dims.values())).type @property def dim_type(self): assert len(self.dims) == 1, f"Shape.dim_type is only defined for shapes of rank 1 but has dims {self}" @@ -1883,6 +1904,12 @@ def replace(self, dims: Union['Shape', str, tuple, list], new: 'Shape'): elif len(new) > 1 and len(dims) == 1: i = self.index(dims[0]) dim_list[i:i+1] = new + else: + if len(dims) == len(self.dims): + return new + i0 = self.index(dims[0]) + dim_list = [d for n, d in self.dims.items() if n not in dims] + dim_list.insert(i0, new) return concat_shapes_(*dim_list) def as_batch(self): diff --git a/phiml/math/_tensors.py b/phiml/math/_tensors.py index 52bca59..4c7458b 100644 --- a/phiml/math/_tensors.py +++ b/phiml/math/_tensors.py @@ -826,8 +826,8 @@ def __matmul__(self, other): assert non_batch(other).non_dual.size == match_primal.volume, f"Cannot multiply {self.shape} @ {other.shape} because dual dims of arg1 have no match" match_primal = non_batch(other).non_dual match_dual = self.shape.dual.only(match_primal.as_dual(), reorder=True) - left_arg = pack_dims(self, match_dual, dual('_reduce')) - right_arg = pack_dims(other, match_primal, channel('_reduce')) + left_arg = self.__pack_dims__(match_dual, dual('_reduce'), None) + right_arg = other.__pack_dims__(match_primal, channel('_reduce'), None) return dot(left_arg, '~_reduce', right_arg, '_reduce') # def __rmatmul__(self, other): diff --git a/phiml/math/_trace.py b/phiml/math/_trace.py index 1ed21bc..76aca29 100644 --- a/phiml/math/_trace.py +++ b/phiml/math/_trace.py @@ -121,6 +121,10 @@ def _with_shape_replaced(self, new_shape: Shape): def _is_tracer(self) -> bool: return True + @property + def backend(self) -> Backend: + return backend_for(self._bias, *self.val.values()) + def _getitem(self, selection: dict): starts = {dim: (item.start or 0) if isinstance(item, slice) else item for dim, item in selection.items()} new_shape = after_gather(self._shape, selection)