Skip to content

Commit

Permalink
[geom] Add Cuboid.is_size_variable
Browse files Browse the repository at this point in the history
  • Loading branch information
holl- committed Apr 21, 2024
1 parent 97e7530 commit 769703a
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 15 deletions.
42 changes: 30 additions & 12 deletions phi/geom/_box.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,6 @@ def shape(self):
def center(self) -> Tensor:
raise NotImplementedError()

def at(self, center: Tensor) -> 'BaseBox':
return Cuboid(center, self.half_size, self.rotation_matrix)

@property
def size(self) -> Tensor:
raise NotImplementedError(self)
Expand All @@ -57,6 +54,10 @@ def upper(self) -> Tensor:
def rotation_matrix(self) -> Optional[Tensor]:
raise NotImplementedError(self)

@property
def is_size_variable(self):
raise NotImplementedError(self)

@property
def volume(self) -> Tensor:
return math.prod(self.size, 'vector')
Expand Down Expand Up @@ -151,15 +152,15 @@ def corner_representation(self) -> 'Box':

box = corner_representation

def center_representation(self) -> 'Cuboid':
return Cuboid(self.center, self.half_size)
def center_representation(self, size_variable=True) -> 'Cuboid':
return Cuboid(self.center, self.half_size, size_variable=size_variable)

def contains(self, other: 'BaseBox'):
""" Tests if the other box lies fully inside this box. """
return np.all(other.lower >= self.lower) and np.all(other.upper <= self.upper)

def scaled(self, factor: Union[float, Tensor]) -> 'Geometry':
return Cuboid(self.center, self.half_size * factor)
return Cuboid(self.center, self.half_size * factor, size_variable=True)

@property
def boundary_elements(self) -> Dict[Any, Dict[str, slice]]:
Expand All @@ -171,7 +172,7 @@ def boundary_faces(self) -> Dict[Any, Dict[str, slice]]:

@property
def faces(self) -> 'Geometry':
return Cuboid(self.face_centers, self._half_size, self._rotation_matrix)
return Cuboid(self.face_centers, self._half_size, self._rotation_matrix, size_variable=False)

@property
def face_centers(self) -> Tensor:
Expand Down Expand Up @@ -352,6 +353,13 @@ def half_size(self):
def rotation_matrix(self) -> Optional[Tensor]:
return None

@property
def is_size_variable(self):
raise False

def at(self, center: Tensor) -> 'BaseBox':
return Cuboid(center, self.half_size, self.rotation_matrix)

def shifted(self, delta, **delta_by_dim):
return Box(self.lower + delta, self.upper + delta)

Expand Down Expand Up @@ -389,6 +397,7 @@ def __init__(self,
center: Tensor = 0,
half_size: Union[float, Tensor] = None,
rotation: Optional[Tensor] = None,
size_variable=True,
**size: Union[float, Tensor]):
"""
Args:
Expand All @@ -409,6 +418,7 @@ def __init__(self,
center = math.expand(center, channel(self._half_size))
self._center = center
self._rotation_matrix = None if rotation is None else math.rotation_matrix(rotation)
self._size_variable = size_variable

def __eq__(self, other):
if self._center is None and self._half_size is None:
Expand All @@ -426,17 +436,18 @@ def __repr__(self):

def __getitem__(self, item):
item = _keep_vector(slicing_dict(self, item))
return Cuboid(self._center[item], self._half_size[item])
return Cuboid(self._center[item], self._half_size[item], size_variable=self._size_variable)

@staticmethod
def __stack__(values: tuple, dim: Shape, **kwargs) -> 'Geometry':
if all(isinstance(v, Cuboid) for v in values):
return Cuboid(math.stack([v.center for v in values], dim, **kwargs), math.stack([v.half_size for v in values], dim, **kwargs))
size_variable = any([c._size_variable for c in values])
return Cuboid(math.stack([v.center for v in values], dim, **kwargs), math.stack([v.half_size for v in values], dim, **kwargs), size_variable=size_variable)
else:
return Geometry.__stack__(values, dim, **kwargs)

def __variable_attrs__(self):
return '_center', '_half_size'
return ('_center', '_half_size') if self._size_variable else ('_center',)

def __value_attrs__(self):
return '_center',
Expand Down Expand Up @@ -471,12 +482,19 @@ def upper(self):
def rotation_matrix(self) -> Optional[Tensor]:
return self._rotation_matrix

@property
def is_size_variable(self):
return self._size_variable

def at(self, center: Tensor) -> 'BaseBox':
return Cuboid(center, self.half_size, self.rotation_matrix, size_variable=self._size_variable)

def rotated(self, angle) -> Geometry:
if self._rotation_matrix is None:
return Cuboid(self._center, self._half_size, angle)
return Cuboid(self._center, self._half_size, angle, size_variable=self._size_variable)
else:
matrix = self._rotation_matrix @ (angle if dual(angle) else math.rotation_matrix(angle))
return Cuboid(self._center, self._half_size, matrix)
return Cuboid(self._center, self._half_size, matrix, size_variable=self._size_variable)

def bounding_half_extent(self):
if self._rotation_matrix is not None:
Expand Down
6 changes: 3 additions & 3 deletions phi/geom/_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def __getitem__(self, item):
return UniformGrid(resolution, bounds)

def __pack_dims__(self, dims: Tuple[str, ...], packed_dim: Shape, pos: Optional[int], **kwargs) -> 'Cuboid':
return math.pack_dims(self.center_representation(), dims, packed_dim, pos, **kwargs)
return math.pack_dims(self.center_representation(size_variable=False), dims, packed_dim, pos, **kwargs)

@staticmethod
def __stack__(values: tuple, dim: Shape, **kwargs) -> 'Geometry':
Expand All @@ -158,7 +158,7 @@ def __stack__(values: tuple, dim: Shape, **kwargs) -> 'Geometry':

def list_cells(self, dim_name):
center = math.pack_dims(self.center, self._shape.spatial.names, dim_name)
return Cuboid(center, self.half_size)
return Cuboid(center, self.half_size, size_variable=False)

def stagger(self, dim: str, lower: bool, upper: bool):
dim_mask = np.array(self.resolution.mask(dim))
Expand Down Expand Up @@ -191,7 +191,7 @@ def shifted(self, delta: Tensor, **delta_by_dim) -> BaseBox:
return UniformGrid(self.resolution, self.bounds.shifted(delta))
else:
center = self.center + delta
return Cuboid(center, self.half_size)
return Cuboid(center, self.half_size, size_variable=False)

def rotated(self, angle) -> Geometry:
raise NotImplementedError("Grids cannot be rotated. Use center_representation() to convert it to Cuboids first.")
Expand Down

0 comments on commit 769703a

Please sign in to comment.