From e5ac8c29b84567f54e8e62004ba7c75d5575fd19 Mon Sep 17 00:00:00 2001 From: Philipp Holl Date: Sun, 14 Jan 2024 13:03:14 +0100 Subject: [PATCH] [geom] Fix Point --- phi/geom/_geom.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/phi/geom/_geom.py b/phi/geom/_geom.py index 8f1cbc0fe..8ec6bea28 100644 --- a/phi/geom/_geom.py +++ b/phi/geom/_geom.py @@ -4,7 +4,7 @@ from phi import math from phi.math import Tensor, Shape, EMPTY_SHAPE, non_channel, wrap, shape, Extrapolation -from phiml.math._magic_ops import variable_attributes, expand +from phiml.math._magic_ops import variable_attributes, expand, stack from phi.math.magic import BoundDim, slicing_dict @@ -20,6 +20,9 @@ class Geometry: All geometry objects support batching. Thereby any parameter defining the geometry can be varied along arbitrary batch dims. All batch dimensions are listed in Geometry.shape. + + Property getters (`@property`, such as `shape`), save for getters, must not depend on any variables marked as *variable* via `__variable_attrs__()` as these may be `None` during tracing. + Equality checks must also take this into account. """ @property @@ -636,6 +639,7 @@ def __init__(self, location: math.Tensor): assert 'vector' in location.shape, "location must have a vector dimension" assert location.shape.get_item_names('vector') is not None, "Vector dimension needs to list spatial dimension as item names." self._location = location + self._shape = self._location.shape @property def center(self) -> Tensor: @@ -643,7 +647,7 @@ def center(self) -> Tensor: @property def shape(self) -> Shape: - return self._location.shape + return self._shape @property def faces(self) -> 'Geometry': @@ -676,6 +680,12 @@ def __hash__(self): def __variable_attrs__(self): return '_location', + def __with_attrs__(self, **updates): + if '_location' in updates: + return Point(updates['_location']) + else: + return self + @property def volume(self) -> Tensor: return math.wrap(0)