Fix constant extrapolation when already partially padded
holl- committed Sep 20, 2023
1 parent 593ffb5 commit c1a0774
Showing 3 changed files with 22 additions and 5 deletions.
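In short: when `math.pad` pads several dimensions in sequence, dimensions processed earlier are recorded in the `already_padded` dict. A `ConstantExtrapolation` whose value is itself a tensor varying along one of those dimensions was previously used at its original, un-padded size, so it no longer matched the partially padded tensor. This commit accumulates the pad widths per dimension and zero-extends the constant value to the already-padded size (see `_get_pad_value` below).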
phiml/math/_ops.py — 2 changes: 1 addition & 1 deletion

```diff
@@ -926,7 +926,7 @@ def inner_concat(*values):
     return result
 
 
-def pad(value: Tensor, widths: Union[dict, tuple], mode: Union['e_.Extrapolation', Tensor, Number, str] = 0, **kwargs) -> Tensor:
+def pad(value: Tensor, widths: Union[dict, tuple], mode: Union['e_.Extrapolation', Tensor, Number, str, dict] = 0, **kwargs) -> Tensor:
     """
     Pads a tensor along the specified dimensions, determining the added values using the given extrapolation.
     Unlike `Extrapolation.pad()`, this function can handle negative widths which slice off outer values.
```
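The only change here is that `mode` may now also be a `dict` mapping dimension names to per-dimension extrapolations, matching what `as_extrapolation` already accepts (see `test_as_extrapolation` below). A minimal sketch of the new call style; the imports are assumptions based on the test file:

```python
from phiml import math
from phiml.math import spatial

t = math.zeros(spatial(x=3, y=2))
# Pad x with the constant 1 and y with zero-gradient; the dict mode is resolved
# per dimension, like as_extrapolation({'x': 1, 'y': 'zero-gradient'}).
p = math.pad(t, {'x': (1, 1), 'y': (1, 1)}, {'x': 1, 'y': 'zero-gradient'})
```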
phiml/math/extrapolation.py — 20 changes: 16 additions & 4 deletions

```diff
@@ -112,7 +112,10 @@ def pad(self, value: Tensor, widths: dict, already_padded: Optional[dict] = None
             if width[True] > 0:
                 values.append(self.pad_values(value, width[True], dim, True, already_padded=already_padded, **kwargs))
             value = concat(values, dim)
-            already_padded[dim] = width
+            if dim in already_padded:
+                already_padded[dim] = tuple(i+j for i, j in zip(already_padded[dim], width))
+            else:
+                already_padded[dim] = width
         return value
 
     def pad_values(self, value: Tensor, width: int, dim: str, upper_edge: bool, already_padded: Optional[dict] = None, **kwargs) -> Tensor:
```
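The new branch matters when the same dimension reaches `pad` more than once with a shared `already_padded` dict (for instance, when different extrapolations handle different sides or dimensions of a single `math.pad` call). Instead of overwriting the record, the widths are now summed per side. A small illustration, assuming widths are `(lower, upper)` tuples (note that `width[True]` above is plain tuple indexing, since `True == 1`):

```python
already_padded = {'x': (1, 0)}  # x was already padded by 1 cell at the lower edge
width = (0, 2)                  # a later pass adds 2 cells at the upper edge
already_padded['x'] = tuple(i + j for i, j in zip(already_padded['x'], width))
assert already_padded['x'] == (1, 2)  # total padding recorded per side
```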
```diff
@@ -260,10 +263,9 @@ def is_flexible(self) -> bool:
 
     def pad(self, value: Tensor, widths: dict, already_padded: Optional[dict] = None, **kwargs) -> Tensor:
         """Pads a tensor using constant values."""
-        derivative = get_spatial_derivative_order()
-        pad_value = self.value if derivative == 0 else math.wrap(0)
         value = value._simplify()
         if isinstance(value, NativeTensor):
+            pad_value = self._get_pad_value(already_padded)
             backend = choose_backend(value._native, pad_value.native())
             for dim in pad_value.shape.non_batch.names:
                 assert dim in value.shape, f"Cannot pad tensor {value.shape} with extrapolation {pad_value.shape} because non-batch dimension '{dim}' is missing."
```
```diff
@@ -292,7 +294,17 @@ def pad(self, value: Tensor, widths: dict, already_padded: Optional[dict] = None
 
     def pad_values(self, value: Tensor, width: int, dim: str, upper_edge: bool, already_padded: Optional[dict] = None, **kwargs) -> Tensor:
         shape = value.shape.after_gather({dim: slice(0, width)})
-        return math.expand(self.value, shape)
+        pad_value = self._get_pad_value(already_padded)
+        return math.expand(pad_value, shape)
+
+    def _get_pad_value(self, already_padded: Optional[dict]):
+        if get_spatial_derivative_order() == 0:
+            if already_padded:
+                return ZERO.pad(self.value, already_padded)
+            else:
+                return self.value
+        else:
+            return math.wrap(0)
 
     def sparse_pad_values(self, value: Tensor, connectivity: Tensor, dim: str, upper_edge: bool, **kwargs) -> Tensor:
         return math.expand(self.value, dual(connectivity) & non_dual(value))
```
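The extracted `_get_pad_value` is where the fix takes effect: outside of spatial-derivative computations, a constant pad value that is itself a tensor must first be grown to match any dimensions that were already padded, which the commit does by zero-padding it over `already_padded`. A hypothetical walk-through using the values from the new test below:

```python
value = wrap([0, 1, 0], spatial('x'))           # constant pad value, varies along x
already_padded = {'x': (1, 1)}                  # x was padded by 1 cell per side earlier
padded_value = ZERO.pad(value, already_padded)  # zero-extended to [0, 0, 1, 0, 0]
```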
tests/commit/math/test_extrapolation.py — 5 changes: 5 additions & 0 deletions

```diff
@@ -235,3 +235,8 @@ def test_as_extrapolation(self):
         self.assertEqual(ZERO, as_extrapolation('zero'))
         self.assertEqual(combine_by_direction(ZERO, 1), as_extrapolation({'normal': 0, 'tangential': 1}))
         self.assertEqual(combine_sides(x=1, y=ZERO_GRADIENT), as_extrapolation({'x': wrap(1), 'y': 'zero-gradient'}))
+
+    def test_constant_already_padded(self):
+        t = math.zeros(spatial(x=3, y=2))
+        p = math.pad(t, {'x': (1, 1), 'y': (1, 1)}, {'x': ZERO_GRADIENT, 'y': wrap([0, 1, 0], spatial('x'))})
+        math.assert_close([0, 0, 1, 0, 0], p.y[0])
```

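The test exercises exactly the fixed path: `x` is padded with `ZERO_GRADIENT`, growing the tensor from 3 to 5 cells along `x` and recording `{'x': (1, 1)}` in `already_padded`; when `y` is then padded with the x-varying constant `[0, 1, 0]`, `_get_pad_value` zero-extends it to `[0, 0, 1, 0, 0]`, which is what the new boundary row `p.y[0]` is asserted to contain. Before this commit, the raw 3-entry value would not have lined up with the already 5-cell-wide `x` dimension.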