From 769703a0373a9aeb5c6b84a6fe8e875e3e3ce047 Mon Sep 17 00:00:00 2001 From: Philipp Holl Date: Sun, 21 Apr 2024 17:18:49 +0200 Subject: [PATCH] [geom] Add Cuboid.is_size_variable --- phi/geom/_box.py | 42 ++++++++++++++++++++++++++++++------------ phi/geom/_grid.py | 6 +++--- 2 files changed, 33 insertions(+), 15 deletions(-) diff --git a/phi/geom/_box.py b/phi/geom/_box.py index 335c59fd1..a9bd28fe7 100644 --- a/phi/geom/_box.py +++ b/phi/geom/_box.py @@ -34,9 +34,6 @@ def shape(self): def center(self) -> Tensor: raise NotImplementedError() - def at(self, center: Tensor) -> 'BaseBox': - return Cuboid(center, self.half_size, self.rotation_matrix) - @property def size(self) -> Tensor: raise NotImplementedError(self) @@ -57,6 +54,10 @@ def upper(self) -> Tensor: def rotation_matrix(self) -> Optional[Tensor]: raise NotImplementedError(self) + @property + def is_size_variable(self): + raise NotImplementedError(self) + @property def volume(self) -> Tensor: return math.prod(self.size, 'vector') @@ -151,15 +152,15 @@ def corner_representation(self) -> 'Box': box = corner_representation - def center_representation(self) -> 'Cuboid': - return Cuboid(self.center, self.half_size) + def center_representation(self, size_variable=True) -> 'Cuboid': + return Cuboid(self.center, self.half_size, size_variable=size_variable) def contains(self, other: 'BaseBox'): """ Tests if the other box lies fully inside this box. """ return np.all(other.lower >= self.lower) and np.all(other.upper <= self.upper) def scaled(self, factor: Union[float, Tensor]) -> 'Geometry': - return Cuboid(self.center, self.half_size * factor) + return Cuboid(self.center, self.half_size * factor, size_variable=True) @property def boundary_elements(self) -> Dict[Any, Dict[str, slice]]: @@ -171,7 +172,7 @@ def boundary_faces(self) -> Dict[Any, Dict[str, slice]]: @property def faces(self) -> 'Geometry': - return Cuboid(self.face_centers, self._half_size, self._rotation_matrix) + return Cuboid(self.face_centers, self._half_size, self._rotation_matrix, size_variable=False) @property def face_centers(self) -> Tensor: @@ -352,6 +353,13 @@ def half_size(self): def rotation_matrix(self) -> Optional[Tensor]: return None + @property + def is_size_variable(self): + raise False + + def at(self, center: Tensor) -> 'BaseBox': + return Cuboid(center, self.half_size, self.rotation_matrix) + def shifted(self, delta, **delta_by_dim): return Box(self.lower + delta, self.upper + delta) @@ -389,6 +397,7 @@ def __init__(self, center: Tensor = 0, half_size: Union[float, Tensor] = None, rotation: Optional[Tensor] = None, + size_variable=True, **size: Union[float, Tensor]): """ Args: @@ -409,6 +418,7 @@ def __init__(self, center = math.expand(center, channel(self._half_size)) self._center = center self._rotation_matrix = None if rotation is None else math.rotation_matrix(rotation) + self._size_variable = size_variable def __eq__(self, other): if self._center is None and self._half_size is None: @@ -426,17 +436,18 @@ def __repr__(self): def __getitem__(self, item): item = _keep_vector(slicing_dict(self, item)) - return Cuboid(self._center[item], self._half_size[item]) + return Cuboid(self._center[item], self._half_size[item], size_variable=self._size_variable) @staticmethod def __stack__(values: tuple, dim: Shape, **kwargs) -> 'Geometry': if all(isinstance(v, Cuboid) for v in values): - return Cuboid(math.stack([v.center for v in values], dim, **kwargs), math.stack([v.half_size for v in values], dim, **kwargs)) + size_variable = any([c._size_variable for c in values]) + return Cuboid(math.stack([v.center for v in values], dim, **kwargs), math.stack([v.half_size for v in values], dim, **kwargs), size_variable=size_variable) else: return Geometry.__stack__(values, dim, **kwargs) def __variable_attrs__(self): - return '_center', '_half_size' + return ('_center', '_half_size') if self._size_variable else ('_center',) def __value_attrs__(self): return '_center', @@ -471,12 +482,19 @@ def upper(self): def rotation_matrix(self) -> Optional[Tensor]: return self._rotation_matrix + @property + def is_size_variable(self): + return self._size_variable + + def at(self, center: Tensor) -> 'BaseBox': + return Cuboid(center, self.half_size, self.rotation_matrix, size_variable=self._size_variable) + def rotated(self, angle) -> Geometry: if self._rotation_matrix is None: - return Cuboid(self._center, self._half_size, angle) + return Cuboid(self._center, self._half_size, angle, size_variable=self._size_variable) else: matrix = self._rotation_matrix @ (angle if dual(angle) else math.rotation_matrix(angle)) - return Cuboid(self._center, self._half_size, matrix) + return Cuboid(self._center, self._half_size, matrix, size_variable=self._size_variable) def bounding_half_extent(self): if self._rotation_matrix is not None: diff --git a/phi/geom/_grid.py b/phi/geom/_grid.py index dc24816da..638d5234d 100644 --- a/phi/geom/_grid.py +++ b/phi/geom/_grid.py @@ -149,7 +149,7 @@ def __getitem__(self, item): return UniformGrid(resolution, bounds) def __pack_dims__(self, dims: Tuple[str, ...], packed_dim: Shape, pos: Optional[int], **kwargs) -> 'Cuboid': - return math.pack_dims(self.center_representation(), dims, packed_dim, pos, **kwargs) + return math.pack_dims(self.center_representation(size_variable=False), dims, packed_dim, pos, **kwargs) @staticmethod def __stack__(values: tuple, dim: Shape, **kwargs) -> 'Geometry': @@ -158,7 +158,7 @@ def __stack__(values: tuple, dim: Shape, **kwargs) -> 'Geometry': def list_cells(self, dim_name): center = math.pack_dims(self.center, self._shape.spatial.names, dim_name) - return Cuboid(center, self.half_size) + return Cuboid(center, self.half_size, size_variable=False) def stagger(self, dim: str, lower: bool, upper: bool): dim_mask = np.array(self.resolution.mask(dim)) @@ -191,7 +191,7 @@ def shifted(self, delta: Tensor, **delta_by_dim) -> BaseBox: return UniformGrid(self.resolution, self.bounds.shifted(delta)) else: center = self.center + delta - return Cuboid(center, self.half_size) + return Cuboid(center, self.half_size, size_variable=False) def rotated(self, angle) -> Geometry: raise NotImplementedError("Grids cannot be rotated. Use center_representation() to convert it to Cuboids first.")