Skip to content

Commit

Permalink
[field] Deprecate Field subclasses
Browse files Browse the repository at this point in the history
  • Loading branch information
holl- committed Nov 13, 2023
1 parent 0d05e83 commit 772de63
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 6 deletions.
33 changes: 33 additions & 0 deletions phi/field/_field.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import inspect
import warnings
from numbers import Number
from typing import TypeVar, Callable, Union
Expand Down Expand Up @@ -152,6 +153,14 @@ def is_staggered(self):
from ._grid import StaggeredGrid
return isinstance(self, StaggeredGrid)

@property
def is_centered(self):
return not self.is_staggered

@property
def sampled_at(self):
return 'face' if self.is_staggered else 'center'


class SampledField(Field):
"""
Expand Down Expand Up @@ -481,3 +490,27 @@ def as_extrapolation(obj: Union[Extrapolation, float, Field, None]) -> Extrapola
return FieldEmbedding(obj)
else:
return math.extrapolation.as_extrapolation(obj)


def deprecated_field_class(for_class: str, parent_metaclass=None, allowed=("/phi/field/_field_math.py", "/phi/field/_grid.py")):
class _WarnOnInstanceChecks(parent_metaclass or type):

def __instancecheck__(self, instance):
caller_frame = inspect.currentframe().f_back
caller_code = caller_frame.f_code
caller_file = caller_code.co_filename
for filename in allowed:
if filename in caller_file:
break
else:
warnings.warn(f"Instance checks on {for_class} are deprecated and will be removed in version 3.0. Use the methods instance.is_grid, instance.is_point_cloud, instance.is_centered and instance.is_staggered instead.", FutureWarning, stacklevel=2)
return type.__instancecheck__(self, instance)

def __subclasscheck__(self, subclass):
warnings.warn(f"Subclass checks on {for_class} are deprecated and will be removed in version 3.0", FutureWarning, stacklevel=2)
return type.__subclasscheck__(self, subclass)

return _WarnOnInstanceChecks


# warnings.simplefilter('always', DeprecationWarning)
8 changes: 4 additions & 4 deletions phi/field/_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from phi import math, geom
from phi.geom import Box, Geometry, GridCell
from ._embed import FieldEmbedding
from ._field import SampledField, Field, sample, reduce_sample, as_extrapolation
from ._field import SampledField, Field, sample, reduce_sample, as_extrapolation, deprecated_field_class
from ..geom._stack import GeometryStack
from phiml.math import Shape, NUMPY
from phiml.math._shape import spatial, channel, parse_dim_order
Expand All @@ -14,7 +14,7 @@
from phiml.math.magic import slicing_dict


class Grid(SampledField):
class Grid(SampledField, metaclass=deprecated_field_class('Grid')):
"""
Base class for `CenteredGrid` and `StaggeredGrid`.
"""
Expand Down Expand Up @@ -144,7 +144,7 @@ def uniform_values(self):
GridType = TypeVar('GridType', bound=Grid)


class CenteredGrid(Grid):
class CenteredGrid(Grid, metaclass=deprecated_field_class('CenteredGrid', parent_metaclass=type(Grid))):
"""
N-dimensional grid with values sampled at the cell centers.
A centered grid is defined through its `CenteredGrid.values` `phiml.math.Tensor`, its `CenteredGrid.bounds` `phi.geom.Box` describing the physical size, and its `CenteredGrid.extrapolation` (`phiml.math.extrapolation.Extrapolation`).
Expand Down Expand Up @@ -282,7 +282,7 @@ def closest_values(self, points: Geometry):
return math.closest_grid_values(self.values, local_points, self.extrapolation)


class StaggeredGrid(Grid):
class StaggeredGrid(Grid, metaclass=deprecated_field_class('StaggeredGrid', parent_metaclass=type(Grid))):
"""
N-dimensional grid whose vector components are sampled at the respective face centers.
A staggered grid is defined through its values tensor, its bounds describing the physical size, and its extrapolation.
Expand Down
4 changes: 2 additions & 2 deletions phi/field/_point_cloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,15 @@

from phi import math
from phi.geom import Geometry, GridCell, Box, Point
from ._field import SampledField, resample
from ._field import SampledField, resample, deprecated_field_class
from ..geom._stack import GeometryStack
from phiml.math import Tensor, instance, Shape
from phiml.math._tensors import may_vary_along
from phiml.math.extrapolation import Extrapolation, ConstantExtrapolation, PERIODIC
from phiml.math.magic import slicing_dict


class PointCloud(SampledField):
class PointCloud(SampledField, metaclass=deprecated_field_class('PointCloud')):
"""
A `PointCloud` comprises:
Expand Down

0 comments on commit 772de63

Please sign in to comment.