Skip to content

Commit

Permalink
Fix Shape.replace many->1
Browse files Browse the repository at this point in the history
  • Loading branch information
holl- committed Dec 26, 2024
1 parent 53da26c commit f38445a
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 9 deletions.
2 changes: 1 addition & 1 deletion phiml/math/_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -1163,7 +1163,7 @@ def forward_retype(obj, input_types: Dict[str, Callable]):
originals = t.shape.only(dims)
new_dims = originals.as_type(dim_type)
for o, n in zip(originals, new_dims):
input_types[n.name] = o.dim_type
input_types[n.name] = o.type
retyped.append(rename_dims(t, originals, new_dims))
return assemble_tree(tree, retyped), input_types

Expand Down
39 changes: 33 additions & 6 deletions phiml/math/_shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,11 @@ def sizes(self) -> Sequence:
...

@property
def types(self) -> Sequence[str]:
def types(self) -> Sequence[Callable]:
...

@property
def dim_types(self) -> Sequence[str]:
...

@property
Expand Down Expand Up @@ -398,7 +402,7 @@ def size(self):
...

@property
def type(self) -> str:
def type(self) -> Callable:
"""
Only for Shapes containing exactly one single dimension.
Returns the type of the dimension.
Expand Down Expand Up @@ -809,6 +813,9 @@ def sizes(self):
def types(self):
return self.type,
@property
def dim_types(self):
return self.dim_type,
@property
def item_names(self):
return self.slice_names,

Expand All @@ -818,7 +825,7 @@ def untyped_dict(self):

@property
def type(self) -> str:
return self.dim_type
return DIM_FUNCTIONS[self.dim_type]

@property
def is_uniform(self) -> bool:
Expand Down Expand Up @@ -1175,6 +1182,9 @@ def sizes(self):
return tuple([d.size for d in self.dims.values()])
@property
def types(self):
return [d.type for d in self.dims.values()]
@property
def dim_types(self):
return [d.dim_type for d in self.dims.values()]
@property
def item_names(self):
Expand Down Expand Up @@ -1490,14 +1500,22 @@ def flipped(self, dims: Union[List[str], Tuple[str]]):

def replace(self, dims: Union['Shape', str, tuple, list], new: 'Shape'):
dims = parse_dim_order(dims)
dim_list = list(self.dims.values())
if len(dims) == len(new):
dim_list = list(self.dims.values())
for old, new_dim in zip(dims, new):
new_dim = self.dims[old]._replace(new_dim)
dim_list[self.index(old)] = new_dim
elif len(new) > 1 and len(dims) == 1:
dim_list = list(self.dims.values())
i = self.index(dims[0])
dim_list[i:i+1] = new
else:
assert len(new) == 1
if len(dims) == len(self.dims):
return new
i0 = self.index(dims[0])
dim_list = [d for n, d in self.dims.items() if n not in dims]
dim_list.insert(i0, new)
return concat_shapes_(*dim_list)

def as_batch(self):
Expand Down Expand Up @@ -1565,7 +1583,10 @@ def sizes(self) -> tuple:
return sum([dim.sizes for dim in self.dims.values()], ())
@property
def types(self):
return sum([dim.types for dim in self.dims.values()], ())
return [d.type for d in self.dims.values()]
@property
def dim_types(self):
return [d.dim_type for d in self.dims.values()]
@property
def item_names(self):
return sum([dim.item_names for dim in self.dims.values()], ())
Expand All @@ -1585,7 +1606,7 @@ def size(self):
@property
def type(self) -> str:
assert len(self.dims) == 1, f"Shape.type is only defined for shapes of rank 1 but has dims {self}"
return next(iter(self.dims.values())).dim_type
return next(iter(self.dims.values())).type
@property
def dim_type(self):
assert len(self.dims) == 1, f"Shape.dim_type is only defined for shapes of rank 1 but has dims {self}"
Expand Down Expand Up @@ -1883,6 +1904,12 @@ def replace(self, dims: Union['Shape', str, tuple, list], new: 'Shape'):
elif len(new) > 1 and len(dims) == 1:
i = self.index(dims[0])
dim_list[i:i+1] = new
else:
if len(dims) == len(self.dims):
return new
i0 = self.index(dims[0])
dim_list = [d for n, d in self.dims.items() if n not in dims]
dim_list.insert(i0, new)
return concat_shapes_(*dim_list)

def as_batch(self):
Expand Down
4 changes: 2 additions & 2 deletions phiml/math/_tensors.py
Original file line number Diff line number Diff line change
Expand Up @@ -826,8 +826,8 @@ def __matmul__(self, other):
assert non_batch(other).non_dual.size == match_primal.volume, f"Cannot multiply {self.shape} @ {other.shape} because dual dims of arg1 have no match"
match_primal = non_batch(other).non_dual
match_dual = self.shape.dual.only(match_primal.as_dual(), reorder=True)
left_arg = pack_dims(self, match_dual, dual('_reduce'))
right_arg = pack_dims(other, match_primal, channel('_reduce'))
left_arg = self.__pack_dims__(match_dual, dual('_reduce'), None)
right_arg = other.__pack_dims__(match_primal, channel('_reduce'), None)
return dot(left_arg, '~_reduce', right_arg, '_reduce')

# def __rmatmul__(self, other):
Expand Down
4 changes: 4 additions & 0 deletions phiml/math/_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,10 @@ def _with_shape_replaced(self, new_shape: Shape):
def _is_tracer(self) -> bool:
return True

@property
def backend(self) -> Backend:
return backend_for(self._bias, *self.val.values())

def _getitem(self, selection: dict):
starts = {dim: (item.start or 0) if isinstance(item, slice) else item for dim, item in selection.items()}
new_shape = after_gather(self._shape, selection)
Expand Down

0 comments on commit f38445a

Please sign in to comment.