Skip to content

Commit

Permalink
wip upates
Browse files Browse the repository at this point in the history
  • Loading branch information
norlandrhagen committed Jun 10, 2024
1 parent 7d8ecbc commit 72e639a
Show file tree
Hide file tree
Showing 5 changed files with 133 additions and 32 deletions.
8 changes: 4 additions & 4 deletions pangeo_forge_ndpyramid/common.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Literal, Optional

import xarray as xr
import zarr
import zarr # type: ignore


def create_pyramid(
Expand All @@ -13,7 +13,7 @@ def create_pyramid(
pyramid_method: Literal["coarsen", "regrid", "reproject", "resample"] = "reproject",
pyramid_kwargs: Optional[dict] = {},
) -> zarr.storage.FSStore:
from ndpyramid.utils import set_zarr_encoding
from ndpyramid.utils import set_zarr_encoding # type: ignore

if pyramid_method not in ["reproject", "resample"]:
raise NotImplementedError(
Expand All @@ -32,12 +32,12 @@ def create_pyramid(
ds = ds.rename(rename_spatial_dims)

if pyramid_method == "reproject":
from ndpyramid.reproject import level_reproject
from ndpyramid.reproject import level_reproject # type: ignore

level_ds = level_reproject(ds, level=level, **pyramid_kwargs)

elif pyramid_method == "resample":
from ndpyramid.resample import level_resample
from ndpyramid.resample import level_resample # type: ignore

# Should we have resample specific kwargs we pass here?
level_ds = level_resample(
Expand Down
9 changes: 6 additions & 3 deletions pangeo_forge_ndpyramid/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import apache_beam as beam
import xarray as xr
import zarr
import zarr # type: ignore
from pangeo_forge_recipes.patterns import Dimension, Index
from pangeo_forge_recipes.storage import FSSpecTarget
from pangeo_forge_recipes.transforms import (
Expand Down Expand Up @@ -83,7 +83,7 @@ def expand(
datasets: beam.PCollection[Tuple[Index, xr.Dataset]],
) -> beam.PCollection[zarr.storage.FSStore]:
# Add multiscales metadata to the root of the target store
from ndpyramid.utils import get_version, multiscales_template
from ndpyramid.utils import get_version, multiscales_template # type: ignore

save_kwargs = {"levels": self.levels, "pixels_per_tile": self.pixels_per_tile}
attrs = {
Expand All @@ -98,13 +98,16 @@ def expand(
kwargs=save_kwargs,
)
}
# from StoreToZarr
# target_chunks: 'Dict[str, int]' = <factory>,
chunks = {"x": self.pixels_per_tile, "y": self.pixels_per_tile}
if self.other_chunks is not None:
chunks |= self.other_chunks

ds = xr.Dataset(attrs=attrs)

target_path = (self.target_root / self.store_name).get_mapper()
# Note: mypy typing in not happy here.
target_path = (self.target_root / self.store_name).get_mapper() # type: ignore
ds.to_zarr(store=target_path, compute=False) # noqa

# generate all pyramid levels
Expand Down
20 changes: 17 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from pangeo_forge_recipes.patterns import pattern_from_file_sequence
from pangeo_forge_recipes.storage import FSSpecTarget

from .data_generation import make_pyramid
from .data_generation import make_pyramid_resample, make_pyramid_reproject


@pytest.fixture
Expand All @@ -23,8 +23,13 @@ def tmp_target(tmpdir_factory):


@pytest.fixture(scope="session")
def pyramid_datatree(levels: int = 2):
return make_pyramid(levels=levels)
def pyramid_datatree_reproject(levels: int = 2):
return make_pyramid_reproject(levels=levels)


@pytest.fixture(scope="session")
def pyramid_datatree_resample(levels: int = 2):
return make_pyramid_resample(levels=levels)


@pytest.fixture(scope="session")
Expand All @@ -34,3 +39,12 @@ def create_file_pattern():
concat_dim="time",
nitems_per_file=1,
)


@pytest.fixture(scope="session")
def create_file_pattern_gpm_imerg():
return pattern_from_file_sequence(
[str(path) for path in ["tests/data/3B-DAY.MS.MRG.3IMERG.20230729-S000000-E235959.V07B.nc4","tests/data/3B-DAY.MS.MRG.3IMERG.20230730-S000000-E235959.V07B.nc4"]],
concat_dim="time",
nitems_per_file=1,
)
33 changes: 29 additions & 4 deletions tests/data_generation.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,38 @@
import rioxarray # noqa
import xarray as xr
from ndpyramid import pyramid_reproject
from ndpyramid import pyramid_reproject, pyramid_resample


def make_pyramid(levels: int):
def make_pyramid_resample(levels: int):
# ds = xr.tutorial.open_dataset("air_temperature")

files = [
'tests/data/3B-DAY.MS.MRG.3IMERG.20230729-S000000-E235959.V07B.nc4',
'tests/data/3B-DAY.MS.MRG.3IMERG.20230730-S000000-E235959.V07B.nc4'
]
ds = xr.open_mfdataset(files, engine='netcdf4', drop_variables=["time_bnds"],decode_coords="all")

ds = ds.chunk({'lat':10, 'lon':10})
ds = ds[['precipitation']]


ds = ds.rio.write_crs('EPSG:4326')
ds = ds.drop_vars('spatial_ref')

ds = ds.transpose('time', 'lat', 'lon')

# ds = ds.rename({"lon": "longitude", "lat": "latitude"})
# ds = ds.rio.write_crs("EPSG:4326")
ds = ds.isel(time=slice(0, 2))
# import pdb; pdb.set_trace()
return pyramid_resample(ds, levels=levels, x='lon', y='lat')


def make_pyramid_reproject(levels: int):
ds = xr.tutorial.open_dataset("air_temperature")

ds = ds.rename({"lon": "longitude", "lat": "latitude"})
ds = ds.rio.write_crs("EPSG:4326")
ds = ds.isel(time=slice(0, 2))

# other_chunks added to e2e pass of pyramid b/c target_chunks invert_meshgrid error
return pyramid_reproject(ds, levels=levels, other_chunks={"time": 1})
return pyramid_reproject(ds, levels=levels)
95 changes: 77 additions & 18 deletions tests/test_end_to_end.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,71 @@
import os
from dataclasses import dataclass

import apache_beam as beam
import datatree as dt
import xarray as xr
from datatree.testing import assert_isomorphic
from pangeo_forge_recipes.transforms import OpenWithXarray, StoreToZarr

from pangeo_forge_ndpyramid.transforms import StoreToPyramid


# TODO: We should parameterize the reprojection methods available in ndpyramid
# TODO: Test names and attrs
def test_pyramid(
pyramid_datatree,
def test_pyramid_resample(
pyramid_datatree_resample, create_file_pattern_gpm_imerg, pipeline, tmp_target
):
pattern = create_file_pattern_gpm_imerg
@dataclass
class Transpose(beam.PTransform):
"""Transpose dim order for pyresample"""

def _transpose(self, ds: xr.Dataset) -> xr.Dataset:
ds = ds[['precipitation']]
ds = ds.transpose("time", "lat", "lon")

return ds

def expand(self, pcoll):
return pcoll | "Transpose" >> beam.MapTuple(
lambda k, v: (k, self._transpose(v))
)

with pipeline as p:
(
p
| beam.Create(pattern.items())
| OpenWithXarray(file_type=pattern.file_type)
| Transpose()
| "Write Pyramid Levels"
>> StoreToPyramid(
target_root=tmp_target,
store_name="pyramid",
levels=2,
epsg_code="4326",
pyramid_method="resample",
pyramid_kwargs={"x": "lon", "y": "lat"},
combine_dims=pattern.combine_dim_keys,
)
)
pgf_dt = dt.open_datatree(
os.path.join(tmp_target.root_path, "pyramid"),
engine="zarr",
consolidated=False,
chunks={},
)
assert_isomorphic(
pgf_dt, pyramid_datatree_resample
) # every node has same # of children
xr.testing.assert_allclose(
pgf_dt["0"].to_dataset(), pyramid_datatree_resample["0"].to_dataset()
)
xr.testing.assert_allclose(
pgf_dt["1"].to_dataset(), pyramid_datatree_resample["1"].to_dataset()
)


# TODO: Test names and attrs
def test_pyramid_reproject(
pyramid_datatree_reproject,
create_file_pattern,
pipeline,
tmp_target,
Expand All @@ -29,17 +84,20 @@ def test_pyramid(
store_name="store",
combine_dims=pattern.combine_dim_keys,
)
process | "Write Pyramid Levels" >> StoreToPyramid(
target_root=tmp_target,
store_name="pyramid",
levels=2,
epsg_code="4326",
rename_spatial_dims={"lon": "longitude", "lat": "latitude"},
combine_dims=pattern.combine_dim_keys,
)

import datatree as dt
from datatree.testing import assert_isomorphic
(
process
| "Write Pyramid Levels"
>> StoreToPyramid(
target_root=tmp_target,
store_name="pyramid",
levels=2,
epsg_code="4326",
pyramid_method="reproject",
rename_spatial_dims={"lon": "x", "lat": "y"},
combine_dims=pattern.combine_dim_keys,
)
)

assert xr.open_dataset(
os.path.join(tmp_target.root_path, "store"), engine="zarr", chunks={}
Expand All @@ -51,11 +109,12 @@ def test_pyramid(
consolidated=False,
chunks={},
)

assert_isomorphic(pgf_dt, pyramid_datatree) # every node has same # of children
assert_isomorphic(
pgf_dt, pyramid_datatree_reproject
) # every node has same # of children
xr.testing.assert_allclose(
pgf_dt["0"].to_dataset(), pyramid_datatree["0"].to_dataset()
pgf_dt["0"].to_dataset(), pyramid_datatree_reproject["0"].to_dataset()
)
xr.testing.assert_allclose(
pgf_dt["1"].to_dataset(), pyramid_datatree["1"].to_dataset()
pgf_dt["1"].to_dataset(), pyramid_datatree_reproject["1"].to_dataset()
)

0 comments on commit 72e639a

Please sign in to comment.