|
12 | 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 | 13 | # See the License for the specific language governing permissions and
|
14 | 14 | # limitations under the License.
|
| 15 | +import math |
| 16 | + |
15 | 17 | import einops
|
16 | 18 | import netCDF4 as nc
|
17 | 19 | import torch
|
18 | 20 |
|
19 |
| -from earth2grid import base, healpix |
20 |
| -from earth2grid.latlon import LatLonGrid |
21 |
| - |
22 | 21 |
|
23 | 22 | class TempestRegridder(torch.nn.Module):
|
24 | 23 | def __init__(self, file_path):
|
@@ -48,22 +47,104 @@ def forward(self, x):
|
48 | 47 | return y
|
49 | 48 |
|
50 | 49 |
|
| 50 | +class BilinearInterpolator(torch.nn.Module): |
| 51 | + """Bilinear interpolation for a non-uniform grid""" |
| 52 | + |
| 53 | + def __init__( |
| 54 | + self, |
| 55 | + x_coords: torch.Tensor, |
| 56 | + y_coords: torch.Tensor, |
| 57 | + x_query: torch.Tensor, |
| 58 | + y_query: torch.Tensor, |
| 59 | + fill_value=math.nan, |
| 60 | + ) -> None: |
| 61 | + """ |
| 62 | +
|
| 63 | + Args: |
| 64 | + x_coords (Tensor): X-coordinates of the input grid, shape [W]. Must be in increasing sorted order. |
| 65 | + y_coords (Tensor): Y-coordinates of the input grid, shape [H]. Must be in increasing sorted order. |
| 66 | + x_query (Tensor): X-coordinates for query points, shape [N]. |
| 67 | + y_query (Tensor): Y-coordinates for query points, shape [N]. |
| 68 | + """ |
| 69 | + super().__init__() |
| 70 | + self.fill_value = fill_value |
| 71 | + |
| 72 | + # Ensure input coordinates are float for interpolation |
| 73 | + x_coords, y_coords = x_coords.double(), y_coords.double() |
| 74 | + x_query = x_query.double() |
| 75 | + y_query = y_query.double() |
| 76 | + |
| 77 | + if torch.any(x_coords[1:] < x_coords[:-1]): |
| 78 | + raise ValueError("x_coords must be in non-decreasing order.") |
| 79 | + |
| 80 | + if torch.any(y_coords[1:] < y_coords[:-1]): |
| 81 | + raise ValueError("y_coords must be in non-decreasing order.") |
| 82 | + |
| 83 | + # Find indices for the closest lower and upper bounds in x and y directions |
| 84 | + x_l_idx = torch.searchsorted(x_coords, x_query, right=True) - 1 |
| 85 | + x_u_idx = x_l_idx + 1 |
| 86 | + y_l_idx = torch.searchsorted(y_coords, y_query, right=True) - 1 |
| 87 | + y_u_idx = y_l_idx + 1 |
| 88 | + |
| 89 | + # fill in nan outside mask |
| 90 | + def isin(x, a, b): |
| 91 | + return (x <= b) & (x >= a) |
| 92 | + |
| 93 | + mask = ( |
| 94 | + isin(x_l_idx, 0, x_coords.size(0) - 2) |
| 95 | + & isin(x_u_idx, 1, x_coords.size(0) - 1) |
| 96 | + & isin(y_l_idx, 0, y_coords.size(0) - 2) |
| 97 | + & isin(y_u_idx, 1, y_coords.size(0) - 1) |
| 98 | + ) |
| 99 | + x_u_idx = x_u_idx[mask] |
| 100 | + x_l_idx = x_l_idx[mask] |
| 101 | + y_u_idx = y_u_idx[mask] |
| 102 | + y_l_idx = y_l_idx[mask] |
| 103 | + x_query = x_query[mask] |
| 104 | + y_query = y_query[mask] |
| 105 | + |
| 106 | + # Compute weights |
| 107 | + x_l_weight = (x_coords[x_u_idx] - x_query) / (x_coords[x_u_idx] - x_coords[x_l_idx]) |
| 108 | + x_u_weight = (x_query - x_coords[x_l_idx]) / (x_coords[x_u_idx] - x_coords[x_l_idx]) |
| 109 | + y_l_weight = (y_coords[y_u_idx] - y_query) / (y_coords[y_u_idx] - y_coords[y_l_idx]) |
| 110 | + y_u_weight = (y_query - y_coords[y_l_idx]) / (y_coords[y_u_idx] - y_coords[y_l_idx]) |
| 111 | + weights = torch.stack( |
| 112 | + [x_l_weight * y_l_weight, x_u_weight * y_l_weight, x_l_weight * y_u_weight, x_u_weight * y_u_weight], dim=-1 |
| 113 | + ) |
| 114 | + |
| 115 | + stride = x_coords.size(-1) |
| 116 | + index = torch.stack( |
| 117 | + [ |
| 118 | + x_l_idx + stride * y_l_idx, |
| 119 | + x_u_idx + stride * y_l_idx, |
| 120 | + x_l_idx + stride * y_u_idx, |
| 121 | + x_u_idx + stride * y_u_idx, |
| 122 | + ], |
| 123 | + dim=-1, |
| 124 | + ) |
| 125 | + self.register_buffer("weights", weights) |
| 126 | + self.register_buffer("mask", mask) |
| 127 | + self.register_buffer("index", index) |
| 128 | + |
| 129 | + def forward(self, z: torch.Tensor): |
| 130 | + """ |
| 131 | + Interpolate the field |
| 132 | +
|
| 133 | + Args: |
| 134 | + z: shape [Y, X] |
| 135 | + """ |
| 136 | + *shape, y, x = z.shape |
| 137 | + zrs = z.view(-1, y * x).T |
| 138 | + # using embedding bag is 2x faster on cpu and 4x on gpu. |
| 139 | + output = torch.nn.functional.embedding_bag(self.index, zrs, per_sample_weights=self.weights, mode='sum') |
| 140 | + interpolated = torch.full( |
| 141 | + [self.mask.numel(), zrs.shape[1]], fill_value=self.fill_value, dtype=z.dtype, device=z.device |
| 142 | + ) |
| 143 | + interpolated.masked_scatter_(self.mask.unsqueeze(-1), output) |
| 144 | + interpolated = interpolated.T.view(*shape, self.mask.numel()) |
| 145 | + return interpolated |
| 146 | + |
| 147 | + |
51 | 148 | class Identity(torch.nn.Module):
|
52 | 149 | def forward(self, x):
|
53 | 150 | return x
|
54 |
| - |
55 |
| - |
56 |
| -def get_regridder(src: base.Grid, dest: base.Grid) -> torch.nn.Module: |
57 |
| - """Get a regridder from `src` to `dest`""" |
58 |
| - if src == dest: |
59 |
| - return Identity() |
60 |
| - elif isinstance(src, LatLonGrid) and isinstance(dest, LatLonGrid): |
61 |
| - return src.get_bilinear_regridder_to(dest.lat, dest.lon) |
62 |
| - elif isinstance(src, LatLonGrid) and isinstance(dest, healpix.Grid): |
63 |
| - return src.get_bilinear_regridder_to(dest.lat, dest.lon) |
64 |
| - elif isinstance(src, healpix.Grid): |
65 |
| - return src.get_bilinear_regridder_to(dest.lat, dest.lon) |
66 |
| - elif isinstance(dest, healpix.Grid): |
67 |
| - return src.get_healpix_regridder(dest) # type: ignore |
68 |
| - |
69 |
| - raise ValueError(src, dest, "not supported.") |
|
0 commit comments