diff --git a/earth2grid/latlon.py b/earth2grid/latlon.py index 8d55e7f..5ccd45d 100644 --- a/earth2grid/latlon.py +++ b/earth2grid/latlon.py @@ -12,6 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import math + import numpy as np import torch @@ -27,7 +29,12 @@ class BilinearInterpolator(torch.nn.Module): """Bilinear interpolation for a non-uniform grid""" def __init__( - self, x_coords: torch.Tensor, y_coords: torch.Tensor, x_query: torch.Tensor, y_query: torch.Tensor + self, + x_coords: torch.Tensor, + y_coords: torch.Tensor, + x_query: torch.Tensor, + y_query: torch.Tensor, + fill_value=math.nan, ) -> None: """ @@ -38,9 +45,12 @@ def __init__( y_query (Tensor): Y-coordinates for query points, shape [N]. """ super().__init__() + self.fill_value = fill_value # Ensure input coordinates are float for interpolation - x_coords, y_coords = x_coords.float(), y_coords.float() + x_coords, y_coords = x_coords.double(), y_coords.double() + x_query = x_query.double() + y_query = y_query.double() if torch.any(x_coords[1:] < x_coords[:-1]): raise ValueError("x_coords must be in non-decreasing order.") @@ -54,11 +64,22 @@ def __init__( y_l_idx = torch.searchsorted(y_coords, y_query, right=True) - 1 y_u_idx = y_l_idx + 1 - # Clip indices to ensure they are within the bounds of the input grid - x_l_idx = x_l_idx.clamp(0, x_coords.size(0) - 2) - x_u_idx = x_u_idx.clamp(1, x_coords.size(0) - 1) - y_l_idx = y_l_idx.clamp(0, y_coords.size(0) - 2) - y_u_idx = y_u_idx.clamp(1, y_coords.size(0) - 1) + # fill in nan outside mask + def isin(x, a, b): + return (x <= b) & (x >= a) + + mask = ( + isin(x_l_idx, 0, x_coords.size(0) - 2) + & isin(x_u_idx, 1, x_coords.size(0) - 1) + & isin(y_l_idx, 0, y_coords.size(0) - 2) + & isin(y_u_idx, 1, y_coords.size(0) - 1) + ) + x_u_idx = x_u_idx[mask] + x_l_idx = x_l_idx[mask] + y_u_idx = y_u_idx[mask] + y_l_idx = y_l_idx[mask] + x_query = x_query[mask] + y_query = y_query[mask] # Compute weights x_l_weight = (x_coords[x_u_idx] - x_query) / (x_coords[x_u_idx] - x_coords[x_l_idx]) @@ -69,8 +90,6 @@ def __init__( [x_l_weight * y_l_weight, x_u_weight * y_l_weight, x_l_weight * y_u_weight, x_u_weight * y_u_weight], dim=-1 ) - self.register_buffer("weights", weights) - stride = x_coords.size(-1) index = torch.stack( [ @@ -81,6 +100,8 @@ def __init__( ], dim=-1, ) + self.register_buffer("weights", weights) + self.register_buffer("mask", mask) self.register_buffer("index", index) def forward(self, z: torch.Tensor): @@ -93,8 +114,12 @@ def forward(self, z: torch.Tensor): *shape, y, x = z.shape zrs = z.view(-1, y * x).T # using embedding bag is 2x faster on cpu and 4x on gpu. - interpolated = torch.nn.functional.embedding_bag(self.index, zrs, per_sample_weights=self.weights, mode='sum') - interpolated = interpolated.T.view(*shape, self.weights.size(0)) + output = torch.nn.functional.embedding_bag(self.index, zrs, per_sample_weights=self.weights, mode='sum') + interpolated = torch.full( + [self.mask.numel(), zrs.shape[1]], fill_value=self.fill_value, dtype=z.dtype, device=z.device + ) + interpolated.masked_scatter_(self.mask.unsqueeze(-1), output) + interpolated = interpolated.T.view(*shape, self.mask.numel()) return interpolated diff --git a/tests/test_regrid.py b/tests/test_regrid.py index 56e47d5..29bf264 100644 --- a/tests/test_regrid.py +++ b/tests/test_regrid.py @@ -133,6 +133,7 @@ def test_interpolation(self): # Execute interpolator = BilinearInterpolator(x_coords, y_coords, x_query, y_query) + interpolator.to(input_tensor) result = interpolator(input_tensor) # Verify @@ -142,13 +143,13 @@ def test_raises_error_when_coordinates_not_increasing_x(self): x_coords = torch.linspace(1, -1, steps=32) # Example non-uniform x-coordinates y_coords = torch.linspace(-1, 1, steps=32) # Example non-uniform y-coordinates with self.assertRaises(ValueError): - BilinearInterpolator(x_coords, y_coords, [0], [0]) + BilinearInterpolator(x_coords, y_coords, torch.tensor([0]), torch.tensor([0])) def test_raises_error_when_coordinates_not_increasing_y(self): x_coords = torch.linspace(-1, 1, steps=32) # Example non-uniform x-coordinates y_coords = torch.linspace(1, -1, steps=32) # Example non-uniform y-coordinates with self.assertRaises(ValueError): - BilinearInterpolator(x_coords, y_coords, [0], [0]) + BilinearInterpolator(x_coords, y_coords, torch.tensor([0]), torch.tensor([0])) def test_interpolation_func(self): # Setup @@ -170,6 +171,7 @@ def func(x, y): # Execute interpolator = BilinearInterpolator(x_coords, y_coords, x_query, y_query) + interpolator.to(input_tensor) result = interpolator(input_tensor) # Verify @@ -178,3 +180,18 @@ def func(x, y): if torch.cuda.is_available() and torch.cuda.device_count() > 0: interpolator.cuda() interpolator(input_tensor.cuda()) + + +def test_out_of_bounds(): + x_coords = torch.tensor([0, 1, 2]).float() + y_coords = torch.tensor([0, 1, 2]).float() + + x_query = torch.tensor([-1, 3]).float() + y_query = torch.tensor([-1, 3]).float() + regrid = BilinearInterpolator(x_coords, y_coords, x_query, y_query) + + data = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]).float() + regrid.to(data) + output = regrid(data) + + assert torch.all(torch.isnan(output))