Skip to content

Commit

Permalink
[geom] Fix Point
Browse files Browse the repository at this point in the history
  • Loading branch information
holl- committed Jan 14, 2024
1 parent 2a2057f commit a9eeb08
Showing 1 changed file with 19 additions and 5 deletions.
24 changes: 19 additions & 5 deletions phi/geom/_geom.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import copy
import warnings
from numbers import Number
from typing import Union, Dict, Any, Tuple

from phi import math
from phi.math import Tensor, Shape, EMPTY_SHAPE, non_channel, wrap, shape, Extrapolation
from phiml.math._magic_ops import variable_attributes, expand
from phiml.math._magic_ops import variable_attributes, expand, stack
from phi.math.magic import BoundDim, slicing_dict


Expand All @@ -20,6 +21,9 @@ class Geometry:
All geometry objects support batching.
Thereby any parameter defining the geometry can be varied along arbitrary batch dims.
All batch dimensions are listed in Geometry.shape.
Property getters (`@property`, such as `shape`), save for getters, must not depend on any variables marked as *variable* via `__variable_attrs__()` as these may be `None` during tracing.
Equality checks must also take this into account.
"""

@property
Expand Down Expand Up @@ -636,14 +640,27 @@ def __init__(self, location: math.Tensor):
assert 'vector' in location.shape, "location must have a vector dimension"
assert location.shape.get_item_names('vector') is not None, "Vector dimension needs to list spatial dimension as item names."
self._location = location
self._shape = self._location.shape

def __variable_attrs__(self):
return '_location',

def __with_attrs__(self, **updates):
if '_location' in updates:
result = Point.__new__(Point)
result._location = updates['_location']
result._shape = result._location.shape if result._location is not None else self._shape
return result
else:
return self

@property
def center(self) -> Tensor:
return self._location

@property
def shape(self) -> Shape:
return self._location.shape
return self._shape

@property
def faces(self) -> 'Geometry':
Expand Down Expand Up @@ -673,9 +690,6 @@ def rotated(self, angle) -> 'Geometry':
def __hash__(self):
return hash(self._location)

def __variable_attrs__(self):
return '_location',

@property
def volume(self) -> Tensor:
return math.wrap(0)
Expand Down

0 comments on commit a9eeb08

Please sign in to comment.