Skip to content

Commit 9d50936

Browse files
committed
Move BilinearInterpolator into earth2grid._regrid
1 parent 82506cc commit 9d50936

File tree

5 files changed

+123
-123
lines changed

5 files changed

+123
-123
lines changed

CHANGELOG.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# Changelog
22

33
## latest
4-
4+
- `earth2grid.latlon.BilinearInterpolator` moved to `earth2grid.BilinearInterpolator`
55

66
## 2024.8.1
77

earth2grid/__init__.py

+20-2
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,25 @@
1212
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
15+
import torch
16+
1517
from earth2grid import base, healpix, latlon
16-
from earth2grid._regrid import get_regridder
18+
from earth2grid._regrid import BilinearInterpolator, Identity
19+
20+
__all__ = ["base", "healpix", "latlon", "get_regridder", "BilinearInterpolator"]
21+
22+
23+
def get_regridder(src: base.Grid, dest: base.Grid) -> torch.nn.Module:
24+
"""Get a regridder from `src` to `dest`"""
25+
if src == dest:
26+
return Identity()
27+
elif isinstance(src, latlon.LatLonGrid) and isinstance(dest, latlon.LatLonGrid):
28+
return src.get_bilinear_regridder_to(dest.lat, dest.lon)
29+
elif isinstance(src, latlon.LatLonGrid) and isinstance(dest, healpix.Grid):
30+
return src.get_bilinear_regridder_to(dest.lat, dest.lon)
31+
elif isinstance(src, healpix.Grid):
32+
return src.get_bilinear_regridder_to(dest.lat, dest.lon)
33+
elif isinstance(dest, healpix.Grid):
34+
return src.get_healpix_regridder(dest) # type: ignore
1735

18-
__all__ = ["base", "healpix", "latlon", "get_regridder"]
36+
raise ValueError(src, dest, "not supported.")

earth2grid/_regrid.py

+100-19
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,12 @@
1212
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
15+
import math
16+
1517
import einops
1618
import netCDF4 as nc
1719
import torch
1820

19-
from earth2grid import base, healpix
20-
from earth2grid.latlon import LatLonGrid
21-
2221

2322
class TempestRegridder(torch.nn.Module):
2423
def __init__(self, file_path):
@@ -48,22 +47,104 @@ def forward(self, x):
4847
return y
4948

5049

50+
class BilinearInterpolator(torch.nn.Module):
51+
"""Bilinear interpolation for a non-uniform grid"""
52+
53+
def __init__(
54+
self,
55+
x_coords: torch.Tensor,
56+
y_coords: torch.Tensor,
57+
x_query: torch.Tensor,
58+
y_query: torch.Tensor,
59+
fill_value=math.nan,
60+
) -> None:
61+
"""
62+
63+
Args:
64+
x_coords (Tensor): X-coordinates of the input grid, shape [W]. Must be in increasing sorted order.
65+
y_coords (Tensor): Y-coordinates of the input grid, shape [H]. Must be in increasing sorted order.
66+
x_query (Tensor): X-coordinates for query points, shape [N].
67+
y_query (Tensor): Y-coordinates for query points, shape [N].
68+
"""
69+
super().__init__()
70+
self.fill_value = fill_value
71+
72+
# Ensure input coordinates are float for interpolation
73+
x_coords, y_coords = x_coords.double(), y_coords.double()
74+
x_query = x_query.double()
75+
y_query = y_query.double()
76+
77+
if torch.any(x_coords[1:] < x_coords[:-1]):
78+
raise ValueError("x_coords must be in non-decreasing order.")
79+
80+
if torch.any(y_coords[1:] < y_coords[:-1]):
81+
raise ValueError("y_coords must be in non-decreasing order.")
82+
83+
# Find indices for the closest lower and upper bounds in x and y directions
84+
x_l_idx = torch.searchsorted(x_coords, x_query, right=True) - 1
85+
x_u_idx = x_l_idx + 1
86+
y_l_idx = torch.searchsorted(y_coords, y_query, right=True) - 1
87+
y_u_idx = y_l_idx + 1
88+
89+
# fill in nan outside mask
90+
def isin(x, a, b):
91+
return (x <= b) & (x >= a)
92+
93+
mask = (
94+
isin(x_l_idx, 0, x_coords.size(0) - 2)
95+
& isin(x_u_idx, 1, x_coords.size(0) - 1)
96+
& isin(y_l_idx, 0, y_coords.size(0) - 2)
97+
& isin(y_u_idx, 1, y_coords.size(0) - 1)
98+
)
99+
x_u_idx = x_u_idx[mask]
100+
x_l_idx = x_l_idx[mask]
101+
y_u_idx = y_u_idx[mask]
102+
y_l_idx = y_l_idx[mask]
103+
x_query = x_query[mask]
104+
y_query = y_query[mask]
105+
106+
# Compute weights
107+
x_l_weight = (x_coords[x_u_idx] - x_query) / (x_coords[x_u_idx] - x_coords[x_l_idx])
108+
x_u_weight = (x_query - x_coords[x_l_idx]) / (x_coords[x_u_idx] - x_coords[x_l_idx])
109+
y_l_weight = (y_coords[y_u_idx] - y_query) / (y_coords[y_u_idx] - y_coords[y_l_idx])
110+
y_u_weight = (y_query - y_coords[y_l_idx]) / (y_coords[y_u_idx] - y_coords[y_l_idx])
111+
weights = torch.stack(
112+
[x_l_weight * y_l_weight, x_u_weight * y_l_weight, x_l_weight * y_u_weight, x_u_weight * y_u_weight], dim=-1
113+
)
114+
115+
stride = x_coords.size(-1)
116+
index = torch.stack(
117+
[
118+
x_l_idx + stride * y_l_idx,
119+
x_u_idx + stride * y_l_idx,
120+
x_l_idx + stride * y_u_idx,
121+
x_u_idx + stride * y_u_idx,
122+
],
123+
dim=-1,
124+
)
125+
self.register_buffer("weights", weights)
126+
self.register_buffer("mask", mask)
127+
self.register_buffer("index", index)
128+
129+
def forward(self, z: torch.Tensor):
130+
"""
131+
Interpolate the field
132+
133+
Args:
134+
z: shape [Y, X]
135+
"""
136+
*shape, y, x = z.shape
137+
zrs = z.view(-1, y * x).T
138+
# using embedding bag is 2x faster on cpu and 4x on gpu.
139+
output = torch.nn.functional.embedding_bag(self.index, zrs, per_sample_weights=self.weights, mode='sum')
140+
interpolated = torch.full(
141+
[self.mask.numel(), zrs.shape[1]], fill_value=self.fill_value, dtype=z.dtype, device=z.device
142+
)
143+
interpolated.masked_scatter_(self.mask.unsqueeze(-1), output)
144+
interpolated = interpolated.T.view(*shape, self.mask.numel())
145+
return interpolated
146+
147+
51148
class Identity(torch.nn.Module):
52149
def forward(self, x):
53150
return x
54-
55-
56-
def get_regridder(src: base.Grid, dest: base.Grid) -> torch.nn.Module:
57-
"""Get a regridder from `src` to `dest`"""
58-
if src == dest:
59-
return Identity()
60-
elif isinstance(src, LatLonGrid) and isinstance(dest, LatLonGrid):
61-
return src.get_bilinear_regridder_to(dest.lat, dest.lon)
62-
elif isinstance(src, LatLonGrid) and isinstance(dest, healpix.Grid):
63-
return src.get_bilinear_regridder_to(dest.lat, dest.lon)
64-
elif isinstance(src, healpix.Grid):
65-
return src.get_bilinear_regridder_to(dest.lat, dest.lon)
66-
elif isinstance(dest, healpix.Grid):
67-
return src.get_healpix_regridder(dest) # type: ignore
68-
69-
raise ValueError(src, dest, "not supported.")

earth2grid/latlon.py

+1-100
Original file line numberDiff line numberDiff line change
@@ -12,117 +12,18 @@
1212
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
15-
import math
16-
1715
import numpy as np
1816
import torch
1917

2018
from earth2grid import base
19+
from earth2grid._regrid import BilinearInterpolator
2120

2221
try:
2322
import pyvista as pv
2423
except ImportError:
2524
pv = None
2625

2726

28-
class BilinearInterpolator(torch.nn.Module):
29-
"""Bilinear interpolation for a non-uniform grid"""
30-
31-
def __init__(
32-
self,
33-
x_coords: torch.Tensor,
34-
y_coords: torch.Tensor,
35-
x_query: torch.Tensor,
36-
y_query: torch.Tensor,
37-
fill_value=math.nan,
38-
) -> None:
39-
"""
40-
41-
Args:
42-
x_coords (Tensor): X-coordinates of the input grid, shape [W]. Must be in increasing sorted order.
43-
y_coords (Tensor): Y-coordinates of the input grid, shape [H]. Must be in increasing sorted order.
44-
x_query (Tensor): X-coordinates for query points, shape [N].
45-
y_query (Tensor): Y-coordinates for query points, shape [N].
46-
"""
47-
super().__init__()
48-
self.fill_value = fill_value
49-
50-
# Ensure input coordinates are float for interpolation
51-
x_coords, y_coords = x_coords.double(), y_coords.double()
52-
x_query = x_query.double()
53-
y_query = y_query.double()
54-
55-
if torch.any(x_coords[1:] < x_coords[:-1]):
56-
raise ValueError("x_coords must be in non-decreasing order.")
57-
58-
if torch.any(y_coords[1:] < y_coords[:-1]):
59-
raise ValueError("y_coords must be in non-decreasing order.")
60-
61-
# Find indices for the closest lower and upper bounds in x and y directions
62-
x_l_idx = torch.searchsorted(x_coords, x_query, right=True) - 1
63-
x_u_idx = x_l_idx + 1
64-
y_l_idx = torch.searchsorted(y_coords, y_query, right=True) - 1
65-
y_u_idx = y_l_idx + 1
66-
67-
# fill in nan outside mask
68-
def isin(x, a, b):
69-
return (x <= b) & (x >= a)
70-
71-
mask = (
72-
isin(x_l_idx, 0, x_coords.size(0) - 2)
73-
& isin(x_u_idx, 1, x_coords.size(0) - 1)
74-
& isin(y_l_idx, 0, y_coords.size(0) - 2)
75-
& isin(y_u_idx, 1, y_coords.size(0) - 1)
76-
)
77-
x_u_idx = x_u_idx[mask]
78-
x_l_idx = x_l_idx[mask]
79-
y_u_idx = y_u_idx[mask]
80-
y_l_idx = y_l_idx[mask]
81-
x_query = x_query[mask]
82-
y_query = y_query[mask]
83-
84-
# Compute weights
85-
x_l_weight = (x_coords[x_u_idx] - x_query) / (x_coords[x_u_idx] - x_coords[x_l_idx])
86-
x_u_weight = (x_query - x_coords[x_l_idx]) / (x_coords[x_u_idx] - x_coords[x_l_idx])
87-
y_l_weight = (y_coords[y_u_idx] - y_query) / (y_coords[y_u_idx] - y_coords[y_l_idx])
88-
y_u_weight = (y_query - y_coords[y_l_idx]) / (y_coords[y_u_idx] - y_coords[y_l_idx])
89-
weights = torch.stack(
90-
[x_l_weight * y_l_weight, x_u_weight * y_l_weight, x_l_weight * y_u_weight, x_u_weight * y_u_weight], dim=-1
91-
)
92-
93-
stride = x_coords.size(-1)
94-
index = torch.stack(
95-
[
96-
x_l_idx + stride * y_l_idx,
97-
x_u_idx + stride * y_l_idx,
98-
x_l_idx + stride * y_u_idx,
99-
x_u_idx + stride * y_u_idx,
100-
],
101-
dim=-1,
102-
)
103-
self.register_buffer("weights", weights)
104-
self.register_buffer("mask", mask)
105-
self.register_buffer("index", index)
106-
107-
def forward(self, z: torch.Tensor):
108-
"""
109-
Interpolate the field
110-
111-
Args:
112-
z: shape [Y, X]
113-
"""
114-
*shape, y, x = z.shape
115-
zrs = z.view(-1, y * x).T
116-
# using embedding bag is 2x faster on cpu and 4x on gpu.
117-
output = torch.nn.functional.embedding_bag(self.index, zrs, per_sample_weights=self.weights, mode='sum')
118-
interpolated = torch.full(
119-
[self.mask.numel(), zrs.shape[1]], fill_value=self.fill_value, dtype=z.dtype, device=z.device
120-
)
121-
interpolated.masked_scatter_(self.mask.unsqueeze(-1), output)
122-
interpolated = interpolated.T.view(*shape, self.mask.numel())
123-
return interpolated
124-
125-
12627
class LatLonGrid(base.Grid):
12728
def __init__(self, lat: list[float], lon: list[float]):
12829
"""

tests/test_regrid.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import torch
2121

2222
import earth2grid
23-
from earth2grid.latlon import BilinearInterpolator
23+
from earth2grid import BilinearInterpolator
2424

2525

2626
@pytest.mark.parametrize("with_channels", [True, False])

0 commit comments

Comments
 (0)