diff --git a/earth2grid/__init__.py b/earth2grid/__init__.py index 6d52f27..3fe995c 100644 --- a/earth2grid/__init__.py +++ b/earth2grid/__init__.py @@ -14,13 +14,14 @@ # limitations under the License. import torch -from earth2grid import base, healpix, latlon +from earth2grid import base, healpix, latlon, lcc from earth2grid._regrid import BilinearInterpolator, Identity, KNNS2Interpolator, Regridder __all__ = [ "base", "healpix", "latlon", + "lcc", "get_regridder", "BilinearInterpolator", "KNNS2Interpolator", @@ -36,6 +37,8 @@ def get_regridder(src: base.Grid, dest: base.Grid) -> torch.nn.Module: return src.get_bilinear_regridder_to(dest.lat, dest.lon) elif isinstance(src, latlon.LatLonGrid) and isinstance(dest, healpix.Grid): return src.get_bilinear_regridder_to(dest.lat, dest.lon) + elif isinstance(src, lcc.LambertConformalConicGrid): + return src.get_bilinear_regridder_to(dest.lat, dest.lon) elif isinstance(src, healpix.Grid): return src.get_bilinear_regridder_to(dest.lat, dest.lon) elif isinstance(dest, healpix.Grid): diff --git a/earth2grid/_regrid.py b/earth2grid/_regrid.py index 5e9e5e8..daba8bb 100644 --- a/earth2grid/_regrid.py +++ b/earth2grid/_regrid.py @@ -44,7 +44,7 @@ def forward(self, z): weight = self.weight.view(-1, p) # using embedding bag is 2x faster on cpu and 4x on gpu. - output = torch.nn.functional.embedding_bag(index, zrs, per_sample_weights=weight, mode='sum') + output = torch.nn.functional.embedding_bag(index, zrs, per_sample_weights=weight, mode="sum") output = output.T.view(*shape, -1) return output.reshape(list(shape) + output_shape) @@ -173,12 +173,12 @@ def forward(self, z: torch.Tensor): *shape, y, x = z.shape zrs = z.view(-1, y * x).T # using embedding bag is 2x faster on cpu and 4x on gpu. - output = torch.nn.functional.embedding_bag(self.index, zrs, per_sample_weights=self.weights, mode='sum') + output = torch.nn.functional.embedding_bag(self.index, zrs, per_sample_weights=self.weights, mode="sum") interpolated = torch.full( [self.mask.numel(), zrs.shape[1]], fill_value=self.fill_value, dtype=z.dtype, device=z.device ) - interpolated.masked_scatter_(self.mask.unsqueeze(-1), output) - interpolated = interpolated.T.view(*shape, self.mask.numel()) + interpolated.masked_scatter_(self.mask.view(-1, 1), output) + interpolated = interpolated.T.view(*shape, *self.mask.shape) return interpolated diff --git a/earth2grid/lcc.py b/earth2grid/lcc.py new file mode 100644 index 0000000..55146d0 --- /dev/null +++ b/earth2grid/lcc.py @@ -0,0 +1,190 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import numpy as np +import torch + +from earth2grid import base +from earth2grid._regrid import BilinearInterpolator + +try: + import pyvista as pv +except ImportError: + pv = None + +__all__ = [ + "LambertConformalConicProjection", + "LambertConformalConicGrid", + "HRRR_CONUS_PROJECTION", + "HRRR_CONUS_GRID", +] + + +class LambertConformalConicProjection: + def __init__(self, lat0: float, lon0: float, lat1: float, lat2: float, radius: float): + """ + + Args: + lat0: latitude of origin (degrees) + lon0: longitude of origin (degrees) + lat1: first standard parallel (degrees) + lat2: second standard parallel (degrees) + radius: radius of sphere (m) + + """ + + self.lon0 = lon0 + self.lat0 = lat0 + self.lat1 = lat1 + self.lat2 = lat2 + self.radius = radius + + c1 = np.cos(np.deg2rad(lat1)) + c2 = np.cos(np.deg2rad(lat2)) + t1 = np.tan(np.pi / 4 + np.deg2rad(lat1) / 2) + t2 = np.tan(np.pi / 4 + np.deg2rad(lat2) / 2) + + if np.abs(lat1 - lat2) < 1e-8: + self.n = np.sin(np.deg2rad(lat1)) + else: + self.n = np.log(c1 / c2) / np.log(t2 / t1) + + self.RF = radius * c1 * np.power(t1, self.n) / self.n + self.rho0 = self._rho(lat0) + + def _rho(self, lat): + return self.RF / np.power(np.tan(np.pi / 4 + np.deg2rad(lat) / 2), self.n) + + def _theta(self, lon): + """ + Angle of deviation (in radians) of the projected grid from the regular grid, + for a given longitude (in degrees). + + To convert to U and V on the projected grid to easterly / northerly components: + UN = cos(theta) * U + sin(theta) * V + VN = - sin(theta) * U + cos(theta) * V + """ + # center about reference longitude + delta_lon = lon - self.lon0 + delta_lon = delta_lon - np.round(delta_lon / 360) * 360 # convert to [-180, 180] + return self.n * np.deg2rad(delta_lon) + + def project(self, lat, lon): + """ + Compute the projected x,y from lat,lon. + """ + rho = self._rho(lat) + theta = self._theta(lon) + + x = rho * np.sin(theta) + y = self.rho0 - rho * np.cos(theta) + return x, y + + def inverse_project(self, x, y): + """ + Compute the lat,lon from the projected x,y. + """ + rho = np.hypot(x, self.rho0 - y) + theta = np.arctan2(x, self.rho0 - y) + + lat = np.rad2deg(2 * np.arctan(np.power(self.RF / rho, 1 / self.n))) - 90 + lon = self.lon0 + np.rad2deg(theta / self.n) + return lat, lon + + +# Projection used by HRRR CONUS (Continental US) data +# https://rapidrefresh.noaa.gov/hrrr/HRRR_conus.domain.txt +HRRR_CONUS_PROJECTION = LambertConformalConicProjection(lon0=-97.5, lat0=38.5, lat1=38.5, lat2=38.5, radius=6371229.0) + + +class LambertConformalConicGrid(base.Grid): + # nothing here is specific to the projection, so could be shared by any projected rectilinear grid + def __init__(self, projection: LambertConformalConicProjection, x, y): + """ + Args: + projection: LambertConformalConicProjection object + x: range of x values + y: range of y values + + """ + self.projection = projection + + self.x = np.array(x) + self.y = np.array(y) + + @property + def lat_lon(self): + mesh_x, mesh_y = np.meshgrid(self.x, self.y) + return self.projection.inverse_project(mesh_x, mesh_y) + + @property + def lat(self): + return self.lat_lon[0] + + @property + def lon(self): + return self.lat_lon[1] + + @property + def shape(self): + return (len(self.y), len(self.x)) + + def __getitem__(self, idxs): + yidxs, xidxs = idxs + return LambertConformalConicGrid(self.projection, x=self.x[xidxs], y=self.y[yidxs]) + + def get_bilinear_regridder_to(self, lat: np.ndarray, lon: np.ndarray): + """Get regridder to the specified lat and lon points""" + + x, y = self.projection.project(lat, lon) + + return BilinearInterpolator( + x_coords=torch.from_numpy(self.x), + y_coords=torch.from_numpy(self.y), + x_query=torch.from_numpy(x), + y_query=torch.from_numpy(y), + ) + + def visualize(self, data): + raise NotImplementedError() + + def to_pyvista(self): + if pv is None: + raise ImportError("Need to install pyvista") + + lat, lon = self.lat_lon + y = np.cos(np.deg2rad(lat)) * np.sin(np.deg2rad(lon)) + x = np.cos(np.deg2rad(lat)) * np.cos(np.deg2rad(lon)) + z = np.sin(np.deg2rad(lat)) + grid = pv.StructuredGrid(x, y, z) + return grid + + +def hrrr_conus_grid(ix0=0, iy0=0, nx=1799, ny=1059): + # coordinates of point in top-left corner + lat0 = 21.138123 + lon0 = 237.280472 + # grid length (m) + scale = 3000.0 + # coordinates on projected space + x0, y0 = HRRR_CONUS_PROJECTION.project(lat0, lon0) + + x = [x0 + i * scale for i in range(ix0, ix0 + nx)] + y = [y0 + i * scale for i in range(iy0, iy0 + ny)] + + return LambertConformalConicGrid(HRRR_CONUS_PROJECTION, x, y) + + +# Grid used by HRRR CONUS (Continental US) data +HRRR_CONUS_GRID = hrrr_conus_grid() diff --git a/tests/test_lcc.py b/tests/test_lcc.py new file mode 100644 index 0000000..536068a --- /dev/null +++ b/tests/test_lcc.py @@ -0,0 +1,77 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# %% +import numpy as np +import pytest +import torch + +from earth2grid.lcc import HRRR_CONUS_GRID + + +def test_grid_shape(): + assert HRRR_CONUS_GRID.lat.shape == HRRR_CONUS_GRID.shape + assert HRRR_CONUS_GRID.lon.shape == HRRR_CONUS_GRID.shape + + +lats = np.array( + [ + [21.138123, 21.801926, 22.393631, 22.911015], + [23.636763, 24.328228, 24.944668, 25.48374], + [26.155672, 26.875362, 27.517046, 28.078257], + [28.69017, 29.438608, 30.106009, 30.68978], + ] +) + +lons = np.array( + [ + [-122.71953, -120.03195, -117.304596, -114.54146], + [-123.491356, -120.72898, -117.92319, -115.07828], + [-124.310524, -121.469505, -118.58098, -115.649574], + [-125.181404, -122.25762, -119.28173, -116.25871], + ] +) + + +def test_grid_vals(): + assert HRRR_CONUS_GRID.lat[0:400:100, 0:400:100] == pytest.approx(lats) + assert HRRR_CONUS_GRID.lon[0:400:100, 0:400:100] == pytest.approx(lons) + + +def test_grid_slice(): + slice_grid = HRRR_CONUS_GRID[0:400:100, 0:400:100] + assert slice_grid.lat == pytest.approx(lats) + assert slice_grid.lon == pytest.approx(lons) + + +def test_regrid_1d(): + src = HRRR_CONUS_GRID + dest_lat = np.linspace(25.0, 33.0, 10) + dest_lon = np.linspace(-123, -98, 10) + regrid = src.get_bilinear_regridder_to(dest_lat, dest_lon) + src_lat = torch.broadcast_to(torch.tensor(src.lat), src.shape) + out_lat = regrid(src_lat) + + assert torch.allclose(out_lat, torch.tensor(dest_lat)) + + +def test_regrid_2d(): + src = HRRR_CONUS_GRID + dest_lat, dest_lon = np.meshgrid(np.linspace(25.0, 33.0, 10), np.linspace(-123, -98, 12)) + regrid = src.get_bilinear_regridder_to(dest_lat, dest_lon) + src_lat = torch.broadcast_to(torch.tensor(src.lat), src.shape) + out_lat = regrid(src_lat) + + assert torch.allclose(out_lat, torch.tensor(dest_lat))