From 72e639a28bd2a1bb109f1be3da12ec3eac5903b2 Mon Sep 17 00:00:00 2001 From: Raphael Hagen Date: Mon, 10 Jun 2024 13:33:59 -0600 Subject: [PATCH] wip updates --- pangeo_forge_ndpyramid/common.py | 8 +-- pangeo_forge_ndpyramid/transforms.py | 9 ++- tests/conftest.py | 20 +++++- tests/data_generation.py | 33 ++++++++-- tests/test_end_to_end.py | 95 ++++++++++++++++++++++------ 5 files changed, 133 insertions(+), 32 deletions(-) diff --git a/pangeo_forge_ndpyramid/common.py b/pangeo_forge_ndpyramid/common.py index fd360e2..b1273d7 100644 --- a/pangeo_forge_ndpyramid/common.py +++ b/pangeo_forge_ndpyramid/common.py @@ -1,7 +1,7 @@ from typing import Literal, Optional import xarray as xr -import zarr +import zarr # type: ignore def create_pyramid( @@ -13,7 +13,7 @@ def create_pyramid( pyramid_method: Literal["coarsen", "regrid", "reproject", "resample"] = "reproject", pyramid_kwargs: Optional[dict] = {}, ) -> zarr.storage.FSStore: - from ndpyramid.utils import set_zarr_encoding + from ndpyramid.utils import set_zarr_encoding # type: ignore if pyramid_method not in ["reproject", "resample"]: raise NotImplementedError( @@ -32,12 +32,12 @@ def create_pyramid( ds = ds.rename(rename_spatial_dims) if pyramid_method == "reproject": - from ndpyramid.reproject import level_reproject + from ndpyramid.reproject import level_reproject # type: ignore level_ds = level_reproject(ds, level=level, **pyramid_kwargs) elif pyramid_method == "resample": - from ndpyramid.resample import level_resample + from ndpyramid.resample import level_resample # type: ignore # Should we have resample specific kwargs we pass here? 
level_ds = level_resample( diff --git a/pangeo_forge_ndpyramid/transforms.py b/pangeo_forge_ndpyramid/transforms.py index a2eee2e..44d079f 100644 --- a/pangeo_forge_ndpyramid/transforms.py +++ b/pangeo_forge_ndpyramid/transforms.py @@ -3,7 +3,7 @@ import apache_beam as beam import xarray as xr -import zarr +import zarr # type: ignore from pangeo_forge_recipes.patterns import Dimension, Index from pangeo_forge_recipes.storage import FSSpecTarget from pangeo_forge_recipes.transforms import ( @@ -83,7 +83,7 @@ def expand( datasets: beam.PCollection[Tuple[Index, xr.Dataset]], ) -> beam.PCollection[zarr.storage.FSStore]: # Add multiscales metadata to the root of the target store - from ndpyramid.utils import get_version, multiscales_template + from ndpyramid.utils import get_version, multiscales_template # type: ignore save_kwargs = {"levels": self.levels, "pixels_per_tile": self.pixels_per_tile} attrs = { @@ -98,13 +98,16 @@ def expand( kwargs=save_kwargs, ) } + # from StoreToZarr + # target_chunks: 'Dict[str, int]' = , chunks = {"x": self.pixels_per_tile, "y": self.pixels_per_tile} if self.other_chunks is not None: chunks |= self.other_chunks ds = xr.Dataset(attrs=attrs) - target_path = (self.target_root / self.store_name).get_mapper() + # Note: mypy typing is not happy here. 
+ target_path = (self.target_root / self.store_name).get_mapper() # type: ignore ds.to_zarr(store=target_path, compute=False) # noqa # generate all pyramid levels diff --git a/tests/conftest.py b/tests/conftest.py index 6ca70af..25023cb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,7 +5,7 @@ from pangeo_forge_recipes.patterns import pattern_from_file_sequence from pangeo_forge_recipes.storage import FSSpecTarget -from .data_generation import make_pyramid +from .data_generation import make_pyramid_resample, make_pyramid_reproject @pytest.fixture @@ -23,8 +23,13 @@ def tmp_target(tmpdir_factory): @pytest.fixture(scope="session") -def pyramid_datatree(levels: int = 2): - return make_pyramid(levels=levels) +def pyramid_datatree_reproject(levels: int = 2): + return make_pyramid_reproject(levels=levels) + + +@pytest.fixture(scope="session") +def pyramid_datatree_resample(levels: int = 2): + return make_pyramid_resample(levels=levels) @pytest.fixture(scope="session") @@ -34,3 +39,12 @@ def create_file_pattern(): concat_dim="time", nitems_per_file=1, ) + + +@pytest.fixture(scope="session") +def create_file_pattern_gpm_imerg(): + return pattern_from_file_sequence( + [str(path) for path in ["tests/data/3B-DAY.MS.MRG.3IMERG.20230729-S000000-E235959.V07B.nc4","tests/data/3B-DAY.MS.MRG.3IMERG.20230730-S000000-E235959.V07B.nc4"]], + concat_dim="time", + nitems_per_file=1, + ) diff --git a/tests/data_generation.py b/tests/data_generation.py index ab1d6c7..7534361 100644 --- a/tests/data_generation.py +++ b/tests/data_generation.py @@ -1,13 +1,38 @@ import rioxarray # noqa import xarray as xr -from ndpyramid import pyramid_reproject +from ndpyramid import pyramid_reproject, pyramid_resample -def make_pyramid(levels: int): +def make_pyramid_resample(levels: int): + # ds = xr.tutorial.open_dataset("air_temperature") + + files = [ + 'tests/data/3B-DAY.MS.MRG.3IMERG.20230729-S000000-E235959.V07B.nc4', + 'tests/data/3B-DAY.MS.MRG.3IMERG.20230730-S000000-E235959.V07B.nc4' 
+ ] + ds = xr.open_mfdataset(files, engine='netcdf4', drop_variables=["time_bnds"],decode_coords="all") + + ds = ds.chunk({'lat':10, 'lon':10}) + ds = ds[['precipitation']] + + + ds = ds.rio.write_crs('EPSG:4326') + ds = ds.drop_vars('spatial_ref') + + ds = ds.transpose('time', 'lat', 'lon') + + # ds = ds.rename({"lon": "longitude", "lat": "latitude"}) + # ds = ds.rio.write_crs("EPSG:4326") + ds = ds.isel(time=slice(0, 2)) + # import pdb; pdb.set_trace() + return pyramid_resample(ds, levels=levels, x='lon', y='lat') + + +def make_pyramid_reproject(levels: int): ds = xr.tutorial.open_dataset("air_temperature") + ds = ds.rename({"lon": "longitude", "lat": "latitude"}) ds = ds.rio.write_crs("EPSG:4326") ds = ds.isel(time=slice(0, 2)) - # other_chunks added to e2e pass of pyramid b/c target_chunks invert_meshgrid error - return pyramid_reproject(ds, levels=levels, other_chunks={"time": 1}) + return pyramid_reproject(ds, levels=levels) diff --git a/tests/test_end_to_end.py b/tests/test_end_to_end.py index 95d87d5..f97c74a 100644 --- a/tests/test_end_to_end.py +++ b/tests/test_end_to_end.py @@ -1,16 +1,71 @@ import os +from dataclasses import dataclass import apache_beam as beam +import datatree as dt import xarray as xr +from datatree.testing import assert_isomorphic from pangeo_forge_recipes.transforms import OpenWithXarray, StoreToZarr from pangeo_forge_ndpyramid.transforms import StoreToPyramid -# TODO: We should parameterize the reprojection methods available in ndpyramid -# TODO: Test names and attrs -def test_pyramid( - pyramid_datatree, +def test_pyramid_resample( + pyramid_datatree_resample, create_file_pattern_gpm_imerg, pipeline, tmp_target +): + pattern = create_file_pattern_gpm_imerg + @dataclass + class Transpose(beam.PTransform): + """Transpose dim order for pyresample""" + + def _transpose(self, ds: xr.Dataset) -> xr.Dataset: + ds = ds[['precipitation']] + ds = ds.transpose("time", "lat", "lon") + + return ds + + def expand(self, pcoll): + return pcoll | 
"Transpose" >> beam.MapTuple( + lambda k, v: (k, self._transpose(v)) + ) + + with pipeline as p: + ( + p + | beam.Create(pattern.items()) + | OpenWithXarray(file_type=pattern.file_type) + | Transpose() + | "Write Pyramid Levels" + >> StoreToPyramid( + target_root=tmp_target, + store_name="pyramid", + levels=2, + epsg_code="4326", + pyramid_method="resample", + pyramid_kwargs={"x": "lon", "y": "lat"}, + combine_dims=pattern.combine_dim_keys, + ) + ) + pgf_dt = dt.open_datatree( + os.path.join(tmp_target.root_path, "pyramid"), + engine="zarr", + consolidated=False, + chunks={}, + ) + assert_isomorphic( + pgf_dt, pyramid_datatree_resample + ) # every node has same # of children + xr.testing.assert_allclose( + pgf_dt["0"].to_dataset(), pyramid_datatree_resample["0"].to_dataset() + ) + xr.testing.assert_allclose( + pgf_dt["1"].to_dataset(), pyramid_datatree_resample["1"].to_dataset() + ) + + +# TODO: Test names and attrs +def test_pyramid_reproject( + pyramid_datatree_reproject, create_file_pattern, pipeline, tmp_target, @@ -29,17 +84,20 @@ def test_pyramid( store_name="store", combine_dims=pattern.combine_dim_keys, ) - process | "Write Pyramid Levels" >> StoreToPyramid( - target_root=tmp_target, - store_name="pyramid", - levels=2, - epsg_code="4326", - rename_spatial_dims={"lon": "longitude", "lat": "latitude"}, - combine_dims=pattern.combine_dim_keys, - ) - import datatree as dt - from datatree.testing import assert_isomorphic + ( + process + | "Write Pyramid Levels" + >> StoreToPyramid( + target_root=tmp_target, + store_name="pyramid", + levels=2, + epsg_code="4326", + pyramid_method="reproject", + rename_spatial_dims={"lon": "x", "lat": "y"}, + combine_dims=pattern.combine_dim_keys, + ) + ) assert xr.open_dataset( os.path.join(tmp_target.root_path, "store"), engine="zarr", chunks={} @@ -51,11 +109,12 @@ def test_pyramid( consolidated=False, chunks={}, ) - - assert_isomorphic(pgf_dt, pyramid_datatree) # every node has same # of children + assert_isomorphic( + pgf_dt, 
pyramid_datatree_reproject + ) # every node has same # of children xr.testing.assert_allclose( - pgf_dt["0"].to_dataset(), pyramid_datatree["0"].to_dataset() + pgf_dt["0"].to_dataset(), pyramid_datatree_reproject["0"].to_dataset() ) xr.testing.assert_allclose( - pgf_dt["1"].to_dataset(), pyramid_datatree["1"].to_dataset() + pgf_dt["1"].to_dataset(), pyramid_datatree_reproject["1"].to_dataset() )