Skip to content

Commit

Permalink
[field] Fix deprecated Hard/SoftGeometryMask
Browse files Browse the repository at this point in the history
  • Loading branch information
Philipp Holl committed Dec 2, 2024
1 parent 539cdc3 commit 5a61173
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 9 deletions.
17 changes: 8 additions & 9 deletions phi/field/_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,23 @@

from phi import math
from phi.geom import Geometry
from ._field import Field
from phiml.math import Tensor
from ._field import FieldInitializer
from phiml.math import Tensor, Extrapolation


class HardGeometryMask(Field):
class HardGeometryMask(FieldInitializer):
"""
Deprecated since version 1.3. Use `phi.field.mask()` or `phi.field.resample()` instead.
Deprecated since version 2.3. Use `phi.field.mask()` or `phi.field.resample()` instead.
"""

def __init__(self, geometry: Geometry):
super().__init__(geometry, 1, 0)
self.geometry = geometry
warnings.warn("HardGeometryMask and SoftGeometryMask are deprecated. Use field.mask or field.resample instead.", DeprecationWarning, stacklevel=2)

@property
def shape(self):
return self.geometry.shape.non_channel

def _sample(self, geometry: Geometry, **kwargs) -> Tensor:
def _sample(self, geometry: Geometry, at: str, boundaries: Extrapolation, **kwargs) -> math.Tensor:
return math.to_float(self.geometry.lies_inside(geometry.center))

def __getitem__(self, item: dict):
Expand All @@ -29,14 +28,14 @@ def __getitem__(self, item: dict):

class SoftGeometryMask(HardGeometryMask):
"""
Deprecated since version 1.3. Use `phi.field.mask()` or `phi.field.resample()` instead.
Deprecated since version 2.3. Use `phi.field.mask()` or `phi.field.resample()` instead.
"""
def __init__(self, geometry: Geometry, balance: Union[Tensor, float] = 0.5):
warnings.warn("HardGeometryMask and SoftGeometryMask are deprecated. Use field.mask or field.resample instead.", DeprecationWarning, stacklevel=2)
super().__init__(geometry)
self.balance = balance

def _sample(self, geometry: Geometry, **kwargs) -> Tensor:
def _sample(self, geometry: Geometry, at: str, boundaries: Extrapolation, **kwargs) -> math.Tensor:
return self.geometry.approximate_fraction_inside(geometry, self.balance)

def __getitem__(self, item: dict):
Expand Down
20 changes: 20 additions & 0 deletions tests/commit/field/test__mask.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from unittest import TestCase

from phiml import math
from phi.geom import Sphere
from phi.physics._boundaries import Domain
from phi.field import *


class TestNoise(TestCase):

def test_masks(self):
domain = Domain(x=10, y=10)
sphere = Sphere(x=5, y=5, radius=2)
hard_v = domain.staggered_grid(HardGeometryMask(sphere))
hard_s = domain.grid(HardGeometryMask(sphere))
soft_v = domain.staggered_grid(SoftGeometryMask(sphere))
soft_s = domain.grid(SoftGeometryMask(sphere))
for f in [hard_v, hard_s, soft_v, soft_s]:
math.assert_close(1, f.values.max)
math.assert_close(0, f.values.min)

0 comments on commit 5a61173

Please sign in to comment.