Skip to content

Commit

Permalink
Shape refactor 2
Browse files Browse the repository at this point in the history
  • Loading branch information
holl- committed Dec 20, 2024
1 parent 5f0c4ec commit 52b7db6
Show file tree
Hide file tree
Showing 5 changed files with 160 additions and 33 deletions.
2 changes: 1 addition & 1 deletion phiml/math/_magic_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,7 +477,7 @@ def expand(value, *dims: Union[Shape, str], **kwargs):
if not dims.without(shape(value)): # no new dims to add
if set(dims) == set(shape(value).only(dims)): # sizes and item names might differ, though
return value
dims &= combined.shape.without('dims') # add missing non-uniform dims
dims &= combined.non_uniform_shape # add missing non-uniform dims
# --- First try __expand__
if hasattr(value, '__expand__'):
result = value.__expand__(dims, **kwargs)
Expand Down
2 changes: 1 addition & 1 deletion phiml/math/_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -3234,7 +3234,7 @@ def close(*tensors, rel_tolerance: Union[float, Tensor] = 1e-5, abs_tolerance: U
if all(t is tensors[0] for t in tensors):
return True
tensors = [wrap(t) for t in tensors]
if any(not tensors[0].shape.is_compatible(t.shape) for t in tensors[1:]):
if any([not tensors[0].shape.is_compatible(t.shape) for t in tensors[1:]]):
return False
c = True
for other in tensors[1:]:
Expand Down
157 changes: 141 additions & 16 deletions phiml/math/_shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -766,12 +766,6 @@ def meshgrid(self, names=False):
"""
...

def first_index(self, names=False):
...

def are_adjacent(self, dims: Union[str, tuple, list, set, 'Shape']):
...


DimFilter = Union[str, Sequence, set, Shape, Callable, None]
try:
Expand Down Expand Up @@ -806,6 +800,7 @@ class Dim:
def __post_init__(self):
if DEBUG_CHECKS:
assert isinstance(self.name, str)
assert self.dim_type in DIM_TYPES
from ._tensors import Tensor
if isinstance(self.size, Tensor):
assert self.size.rank > 0
Expand All @@ -831,6 +826,12 @@ def __len__(self):
def rank(self):
return 1

def __bool__(self):
    # A Dim always represents exactly one dimension, so it is always truthy.
    return True
@property
def is_empty(self) -> bool:
    # Mirrors PureShape.is_empty / MixedShape.is_empty; a single dim is never empty.
    return False

@property
def volume(self) -> Union[int, None]:
if self.size is None or isinstance(self.size, int):
Expand Down Expand Up @@ -874,6 +875,10 @@ def non_uniform(self):
def non_uniform_shape(self):
return EMPTY_SHAPE if isinstance(self.size, int) else shape(self.size)

@property
def singleton(self):
    """The sub-shape of all dims with size 1; for a single `Dim`, either this dim or `EMPTY_SHAPE`."""
    if _size_equal(self.size, 1):
        return self
    return EMPTY_SHAPE

@property
def well_defined(self):
return self.size is not None
Expand Down Expand Up @@ -947,7 +952,7 @@ def __contains__(self, item):
return item == self.name
if isinstance(item, (tuple, list)):
return len(item) == 1 and item[0] == self.name
return len(item) == 1 and item.name == self.name
return not item or (len(item) == 1 and item.name == self.name)

def index(self, dim: Union[str, 'Shape', None]) -> Optional[int]:
if dim is None:
Expand All @@ -960,6 +965,10 @@ def index(self, dim: Union[str, 'Shape', None]) -> Optional[int]:
return 0
raise ValueError(f"index() requires a single dimension as input but got {dim}")

def indices(self, dims: Union[tuple, list, 'Shape']) -> Tuple[int, ...]:
    """Return the index of each of `dims` within this shape (always 0 for a single `Dim`)."""
    if isinstance(dims, Shape):
        names = dims.names
    else:
        names = dims
    result = []
    for name in names:
        result.append(self.index(name))
    return tuple(result)

def __getitem__(self, selection):
if isinstance(selection, Shape):
selection = selection.names
Expand Down Expand Up @@ -1000,15 +1009,27 @@ def get_item_names(self, dim: Union[str, 'Shape'], fallback_spatial=False) -> Un
def __and__(self, other):
    """Merge this dim with `other`, returning a combined `Shape`."""
    if other is dual:
        # `dim & dual` appends the dual counterparts of this dim's primal dims.
        return self & self.primal.as_dual()
    if not isinstance(other, Shape):
        other = shape(other)  # coerce non-Shape operands via shape()
    if isinstance(other, (Dim, PureShape)) and other.dim_type == self.dim_type:
        # Same dim type: merge within the type group.
        return pure_merge(self, other, allow_varying_sizes=False)
    elif isinstance(other, (Dim, PureShape)):
        if not other:
            return self  # merging with an empty shape is a no-op
        # Different dim types: assemble a MixedShape with one group per type.
        by_type = [EMPTY_SHAPE] * len(DIM_TYPES)
        by_type[TYPE_INDEX[self.dim_type]] = self
        by_type[TYPE_INDEX[other.dim_type]] = other
        return MixedShape(*by_type, dims={self.name: self, **other.dims})
    return NotImplemented  # e.g. MixedShape operand: Python falls back to other.__rand__

def is_compatible(self, other):
    """Whether this dim could coexist with `other`: absent from `other`, or matching in size and type."""
    if self.name in other:
        counterpart = other[self.name]
        return _size_equal(self.size, counterpart.size) and self.dim_type == counterpart.dim_type
    return True

def only(self, dims: 'DimFilter', reorder=False):
if dims is None: # keep none
return EMPTY_SHAPE
Expand All @@ -1035,6 +1056,16 @@ def only(self, dims: 'DimFilter', reorder=False):
raise ValueError(f"Format not understood for Shape.only(): {dims}")
return EMPTY_SHAPE

def __add__(self, other):
    """Add `other` to this dim's size (int operand) or concatenate shapes (Shape operand).

    Args:
        other: `int` to increase the size, or a `Shape` to concatenate.

    Returns:
        A new `Dim` with enlarged size, or the concatenated `Shape`.
    """
    if isinstance(other, int):
        # Consistency fix: PureShape.__add__ / MixedShape.__add__ forbid size
        # arithmetic on batch dims; the single-dim path should agree.
        assert self.dim_type != BATCH_DIM, f"Shape arithmetic not allowed for batch dims {self}"
        # 4th field reset to None — presumably item names, which no longer apply after resizing; TODO confirm.
        return Dim(self.name, self.size + other, self.dim_type, None)
    return concat_shapes(self, other)

def __sub__(self, other):
    """Subtract `other` from this dim's size (int operand) or remove dims (Shape operand).

    Args:
        other: `int` to decrease the size, or dims to remove via `without`.

    Returns:
        A new `Dim` with reduced size, or `self.without(other)`.
    """
    if isinstance(other, int):
        # Consistency fix: PureShape.__sub__ / MixedShape.__sub__ forbid size
        # arithmetic on batch dims; the single-dim path should agree.
        assert self.dim_type != BATCH_DIM, f"Shape arithmetic not allowed for batch dims {self}"
        # 4th field reset to None — presumably item names, which no longer apply after resizing; TODO confirm.
        return Dim(self.name, self.size - other, self.dim_type, None)
    return self.without(other)

def without(self, dims: 'DimFilter'):
if dims is None: # subtract none
return self
Expand Down Expand Up @@ -1111,14 +1142,23 @@ class PureShape:
dims: Dict[str, Dim]

def __post_init__(self):
assert len(self.dims) != 1
if DEBUG_CHECKS:
assert len(self.dims) != 1
assert self.dim_type in DIM_TYPES
for n, dim in self.dims.items():
assert n == dim.name
assert dim.dim_type == self.dim_type

def __len__(self): # this is also used for bool(self)
    # Number of dims; an empty PureShape is falsy.
    return len(self.dims)
@property
def rank(self):
    # Same as len(self): the number of dims in this shape.
    return len(self.dims)

@property
def is_empty(self) -> bool:
    # True iff this shape contains no dims.
    return not self.dims

@property
def volume(self) -> Union[int, None]:
result = 1
Expand All @@ -1140,7 +1180,7 @@ def name_list(self):
return list(self.dims)
@property
def sizes(self):
return [d.size for d in self.dims.values()]
return tuple([d.size for d in self.dims.values()])
@property
def types(self):
return [d.dim_type for d in self.dims.values()]
Expand Down Expand Up @@ -1179,6 +1219,11 @@ def non_uniform_shape(self):
result &= size.shape
return result

@property
def singleton(self):
    """Sub-shape of all dims with size 1; a single match is returned as `Dim`, otherwise a `PureShape`."""
    unit_dims = {}
    for name, dim in self.dims.items():
        if _size_equal(dim.size, 1):
            unit_dims[name] = dim
    if len(unit_dims) == 1:
        return next(iter(unit_dims.values()))
    return PureShape(self.dim_type, unit_dims)

@property
def well_defined(self):
for size in self.sizes:
Expand Down Expand Up @@ -1264,6 +1309,10 @@ def index(self, dim: Union[str, 'Shape', None]) -> Optional[int]:
return self.names.index(dim.name)
raise ValueError(f"index() requires a single dimension as input but got {dim}")

def indices(self, dims: Union[tuple, list, 'Shape']) -> Tuple[int, ...]:
    """Return the position of each of `dims` within this shape's dim order."""
    names = dims.names if isinstance(dims, Shape) else dims
    return tuple(self.index(name) for name in names)

def __getitem__(self, selection):
if isinstance(selection, int):
return list(self.dims.values())[selection]
Expand Down Expand Up @@ -1301,15 +1350,22 @@ def get_item_names(self, dim: Union[str, 'Shape'], fallback_spatial=False) -> Un
def __and__(self, other):
    """Merge this shape with `other`, returning a combined `Shape`."""
    if other is dual:
        # `shape & dual` appends the dual counterparts of this shape's primal dims.
        return concat_shapes(self, self.primal.as_dual())
    if not isinstance(other, Shape):
        other = shape(other)  # coerce non-Shape operands via shape()
    if isinstance(other, (Dim, PureShape)) and other.dim_type == self.dim_type:
        # Same dim type: merge within the type group.
        return pure_merge(self, other, allow_varying_sizes=False)
    elif isinstance(other, (Dim, PureShape)):
        if not self:
            return other  # merging into an empty shape yields the other operand
        # Different dim types: assemble a MixedShape with one group per type.
        by_type = [EMPTY_SHAPE] * len(DIM_TYPES)
        by_type[TYPE_INDEX[self.dim_type]] = self
        by_type[TYPE_INDEX[other.dim_type]] = other
        return MixedShape(*by_type, dims={**self.dims, **other.dims})
    return NotImplemented  # e.g. MixedShape operand: Python falls back to other.__rand__

def is_compatible(self, other: Shape):
    """Whether every dim of this shape is compatible with `other` (no size/type conflicts)."""
    for dim in self.dims.values():
        if not dim.is_compatible(other):
            return False
    return True

def only(self, dims: 'DimFilter', reorder=False):
if not self.dims or dims is None:
return EMPTY_SHAPE
Expand Down Expand Up @@ -1341,6 +1397,18 @@ def only(self, dims: 'DimFilter', reorder=False):
else:
return PureShape(self.dim_type, {n: dim for n, dim in self.dims.items() if n in names})

def __add__(self, other):
    """Add an int to every dim's size, or concatenate with another shape."""
    if not isinstance(other, int):
        return concat_shapes(self, other)
    assert self.dim_type != BATCH_DIM, f"Shape arithmetic not allowed for batch dims {self}"
    enlarged = {name: dim + other for name, dim in self.dims.items()}
    return PureShape(self.dim_type, enlarged)

def __sub__(self, other):
    """Subtract an int from every dim's size, or remove the dims contained in `other`."""
    if not isinstance(other, int):
        return self.without(other)
    assert self.dim_type != BATCH_DIM, f"Shape arithmetic not allowed for batch dims {self}"
    shrunk = {name: dim - other for name, dim in self.dims.items()}
    return PureShape(self.dim_type, shrunk)

def without(self, dims: 'DimFilter'):
if dims is None or not self.dims: # subtract none
return self
Expand Down Expand Up @@ -1383,6 +1451,9 @@ def with_dim_size(self, dim: Union[str, 'Shape'], size: Union[int, 'math.Tensor'
raise NotImplementedError

def with_sizes(self, sizes: Union[Sequence[int], Sequence[Tuple[str, ...]], 'Shape', int], keep_item_names=True):
    # Returns a copy of this shape with each dim's size replaced by the paired entry of `sizes`.
    if not self.dims:
        assert not sizes  # nothing to resize; a non-empty `sizes` would otherwise be silently dropped
        return self
    # NOTE(review): the signature admits a plain int for `sizes`, but zip() would raise on it — confirm callers never pass int here.
    dims = {dim.name: dim.with_size(size, keep_item_names) for dim, size in zip(self.dims.values(), sizes)}
    return PureShape(self.dim_type, dims)

Expand Down Expand Up @@ -1417,12 +1488,19 @@ class MixedShape:
channel: Union[PureShape, Dim]
dims: Dict[str, Dim] # dim order

def __post_init__(self):
    # A MixedShape must contain at least one dim; empty or single-type shapes
    # are represented by EMPTY_SHAPE / Dim / PureShape instead.
    assert self

def __len__(self):
    # Number of dims across all type groups.
    return len(self.dims)
@property
def rank(self):
    # Same as len(self): the number of dims in this shape.
    return len(self.dims)

@property
def is_empty(self) -> bool:
    # Always False in practice: __post_init__ asserts at least one dim.
    return not self.dims

@property
def volume(self) -> Union[int, None]:
result = 1
Expand Down Expand Up @@ -1494,6 +1572,11 @@ def non_uniform_shape(self):
result &= size.shape
return result

@property
def singleton(self):
    # Sub-shape of all size-1 dims; a single match is returned as a `Dim`.
    # NOTE(review): passes a dict to merge_shapes, while PureShape.singleton builds a PureShape — confirm merge_shapes accepts a dict of Dims.
    dims = {n: dim for n, dim in self.dims.items() if _size_equal(dim.size, 1)}
    return next(iter(dims.values())) if len(dims) == 1 else merge_shapes(dims)

@property
def well_defined(self):
for size in self.sizes:
Expand Down Expand Up @@ -1570,6 +1653,10 @@ def index(self, dim: Union[str, 'Shape', None]) -> Optional[int]:
return self.names.index(dim.name)
raise ValueError(f"index() requires a single dimension as input but got {dim}")

def indices(self, dims: Union[tuple, list, 'Shape']) -> Tuple[int, ...]:
    """Return the position of each of `dims` within this shape's dim order."""
    names = dims.names if isinstance(dims, Shape) else dims
    return tuple(map(self.index, names))

def __getitem__(self, selection):
if isinstance(selection, int):
return list(self.dims.values())[selection]
Expand Down Expand Up @@ -1612,13 +1699,22 @@ def get_item_names(self, dim: Union[str, 'Shape'], fallback_spatial=False) -> Un
def __and__(self, other):
    """Merge `other` into this mixed shape, returning a combined `Shape`."""
    if other is dual:
        # `shape & dual` appends the dual counterparts of this shape's primal dims.
        return self & self.primal.as_dual()
    if not isinstance(other, Shape):
        other = shape(other)  # coerce non-Shape operands via shape()
    if isinstance(other, (Dim, PureShape)):
        if not other:
            return self  # merging with an empty shape is a no-op
        # Merge `other` into the group of matching dim type; other groups are kept.
        group = getattr(self, other.dim_type)
        merged = pure_merge(group, other, allow_varying_sizes=False)
        dims = {**self.dims, **merged.dims}
        return replace(self, dims=dims, **{other.dim_type: merged})
    return merge_shapes(self, other)

__rand__ = __and__  # also handles the reflected case when Dim/PureShape return NotImplemented

def is_compatible(self, other: Shape):
    """Whether every dim of this shape is compatible with `other` (no size/type conflicts)."""
    conflicts = [dim for dim in self.dims.values() if not dim.is_compatible(other)]
    return not conflicts

def only(self, dims: 'DimFilter', reorder=False):
if isinstance(dims, (Dim, PureShape)):
return getattr(self, dims.dim_type).only(dims, reorder=reorder)
Expand All @@ -1631,7 +1727,10 @@ def only(self, dims: 'DimFilter', reorder=False):
i = self.instance.only(dims, reorder=reorder)
s = self.spatial.only(dims, reorder=reorder)
c = self.channel.only(dims, reorder=reorder)
if bool(b) + bool(d) + bool(i) + bool(s) + bool(c) == 1:
type_count = bool(b) + bool(d) + bool(i) + bool(s) + bool(c)
if type_count == 0:
return EMPTY_SHAPE
if type_count == 1:
return b if b else (d if d else (i if i else (s if s else c))) # if only one has entries, return it
order = {**b.dims, **d.dims, **i.dims, **s.dims, **c.dims}
if reorder:
Expand All @@ -1653,14 +1752,29 @@ def only(self, dims: 'DimFilter', reorder=False):
# return self.dims[names[0]]
# order = {d: order[d] for d in names}
return MixedShape(b, d, i, s, c, order)

def __add__(self, other):
    # int size arithmetic on mixed shapes is not implemented yet; Shape operands concatenate.
    if isinstance(other, int):
        assert not self.batch, f"Shape arithmetic not allowed for batch dims {self}"
        raise NotImplementedError
    return concat_shapes(self, other)

def __sub__(self, other):
    # int size arithmetic on mixed shapes is not implemented yet; Shape operands are removed via without().
    if isinstance(other, int):
        assert not self.batch, f"Shape arithmetic not allowed for batch dims {self}"
        raise NotImplementedError
    return self.without(other)

def without(self, dims: 'DimFilter'):
    """Return this shape with all dims matching `dims` removed, collapsing to the simplest representation."""
    groups = [grp.without(dims) for grp in (self.batch, self.dual, self.instance, self.spatial, self.channel)]
    non_empty = [grp for grp in groups if grp]
    if not non_empty:
        return EMPTY_SHAPE
    if len(non_empty) == 1:
        return non_empty[0]  # only one type group left: return it directly
    # Preserve the original dim order, keeping only dims not fully removed.
    kept = {name: dim for name, dim in self.dims.items() if dim.without(dims)}
    b, d, i, s, c = groups
    return MixedShape(b, d, i, s, c, kept)
Expand Down Expand Up @@ -1688,7 +1802,8 @@ def with_dim_size(self, dim: Union[str, 'Shape'], size: Union[int, 'math.Tensor'
raise NotImplementedError

def with_sizes(self, sizes: Union[Sequence[int], Sequence[Tuple[str, ...]], 'Shape', int], keep_item_names=True):
    """Return a copy of this shape with each dim's size replaced by the paired entry of `sizes`.

    Bug fix: the previous body returned `PureShape(self.dim_type, )`, but
    `MixedShape` has no `dim_type` attribute and the computed `dims` were
    discarded entirely.

    Args:
        sizes: New sizes, paired with this shape's dims in order.
        keep_item_names: Passed through to `Dim.with_size`.

    Returns:
        A `Shape` with the same dims and types but the new sizes.
    """
    resized = {dim.name: dim.with_size(size, keep_item_names) for dim, size in zip(self.dims.values(), sizes)}
    # Rebuild through concat_shapes, which preserves dim order and regroups by type.
    return concat_shapes(*resized.values())

def without_sizes(self):
raise NotImplementedError
Expand Down Expand Up @@ -2437,6 +2552,15 @@ def shape_stack(stack_dim: Shape, *shapes: Shape, stack_dim_first=False):
""" Returns the shape of a tensor created by stacking tensors with `shapes`. """
if stack_dim.rank > 1:
assert stack_dim.volume == len(shapes), f"stack_dim {stack_dim} does not match number of shapes: {len(shapes)}"
if not shapes:
return stack_dim
if len(shapes) == 1:
return stack_dim & shapes[0]
# for each dim: if new name -> add else -> merge item names, note conflicting
# delete conflicting item names
# for each merged dim: gather sizes, filter present/None, if conflicting: stack replacing missing/None by 1


names = list(stack_dim.names)
types = list(stack_dim.types)
item_names = list(stack_dim.item_names)
Expand Down Expand Up @@ -2583,15 +2707,16 @@ def after_gather(self, selection: dict) -> 'Shape':
new_size = math.to_int64(math.ceil(math.wrap((stop - start) / step)))
if new_size.rank == 0:
new_size = int(new_size) # NumPy array not allowed because not hashable
result = result._replace_single_size(sel_dim, new_size, keep_item_names=True)
result = result.with_dim_size(sel_dim, new_size, keep_item_names=True)
if step < 0:
result = result.flipped([sel_dim])
if self.get_item_names(sel_dim) is not None:
result = result._with_item_name(sel_dim, tuple(self.get_item_names(sel_dim)[sel]))
result = result.with_dim_size(sel_dim, tuple(self.get_item_names(sel_dim)[sel]))
elif isinstance(sel, (tuple, list)):
result = result._replace_single_size(sel_dim, len(sel))
if self.get_item_names(sel_dim) is not None:
result = result._with_item_name(sel_dim, tuple([self.get_item_names(sel_dim)[i] for i in sel]))
result = result.with_dim_size(sel_dim, tuple([self.get_item_names(sel_dim)[i] for i in sel]))
else:
result = result.with_dim_size(sel_dim, len(sel))
elif isinstance(sel, Tensor):
if sel.dtype.kind == bool:
raise NotImplementedError("Shape.after_gather(Tensor[bool]) not yet implemented")
Expand Down
Loading

0 comments on commit 52b7db6

Please sign in to comment.