diff --git a/.github/workflows/beam.yaml b/.github/workflows/beam.yaml
new file mode 100644
index 0000000..032188d
--- /dev/null
+++ b/.github/workflows/beam.yaml
@@ -0,0 +1,40 @@
+name: Test beam
+
+on:
+  push:
+    branches: [ "main" ]
+  pull_request:
+    branches: [ "main" ]
+
+env:
+  PYTEST_ADDOPTS: "--color=yes"
+
+jobs:
+  test:
+    name: py${{ matrix.python-version }}
+    runs-on: "ubuntu-latest"
+    strategy:
+      fail-fast: false
+      matrix:
+        python-version: ["3.10"]
+    steps:
+      - name: 🛍️ Checkout
+        uses: actions/checkout@v4
+      - name: 🔁 Set up Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v3
+        with:
+          python-version: ${{ matrix.python-version }}
+      - name: 🐝 Install deps
+        shell: bash -l {0}
+        run: |
+          python -m pip install --upgrade pip
+          python -m pip install -r ci/requirements-beam.txt
+          python -m pip install -e . --no-deps
+          python -m pip install pytest
+      - name: 🐍 List env
+        shell: bash -l {0}
+        run: |
+          python -m pip list
+      - name: 🏃 Run tests
+        shell: bash -l {0}
+        run: pytest -vv tests/test_beam.py
diff --git a/ci/requirements-beam.txt b/ci/requirements-beam.txt
new file mode 100644
index 0000000..ae8205e
--- /dev/null
+++ b/ci/requirements-beam.txt
@@ -0,0 +1,3 @@
+apache-beam==2.51.0
+xarray==2023.10.0
+xarray-beam==0.6.2
diff --git a/scale_aware_air_sea/beam.py b/scale_aware_air_sea/beam.py
new file mode 100644
index 0000000..5d589f9
--- /dev/null
+++ b/scale_aware_air_sea/beam.py
@@ -0,0 +1,63 @@
+from dataclasses import dataclass
+import apache_beam as beam
+
+input_data = [("cesm", "cesm_ds"), ("cm26", "cm26_ds")]
+
+
+def append_val(t: tuple, val: str) -> tuple:
+    k, v = t
+    return (k, v + val)
+
+
+class PangeoForgeRecipe(beam.PTransform):
+    def expand(self, pcoll: beam.PCollection[tuple]):
+        return pcoll | beam.Map(append_val, val="_arco")
+
+
+class XBeamFilter(beam.PTransform):
+    def expand(self, pcoll: beam.PCollection[tuple]):
+        return pcoll | beam.Map(append_val, val="_filtered")
+
+
+def add_spec(t: tuple, spec: str):
+    k, v = t
+    merged = "+".join(list(v))
+    return (k, f"${merged}$_{spec}")
+
+
+@dataclass
+class MixVariables(beam.PTransform):
+    spec: str
+
+    def expand(self, pcoll: beam.PCollection[tuple]):
+        return pcoll | beam.Map(add_spec, spec=self.spec)
+
+
+def flatten_tuple(t: tuple):
+    k, v = t
+    arco, filtered = v
+    return (k, (arco[0], filtered[0]))
+
+
+class XBeamComputeFluxes(beam.PTransform):
+    def expand(self, pcoll: beam.PCollection[tuple]):
+        # return pcoll | xbeam.Something()
+        return pcoll | beam.Map(append_val, val="_flux")
+
+
+class AirSeaPaper(beam.PTransform):
+    def expand(self, pcoll):
+        arco = pcoll | PangeoForgeRecipe()  # -> Zarr(data_vars={a, b, c})
+        filtered = arco | XBeamFilter()  # -> Zarr(data_vars={a_f, b_f, c_f})
+        nested = (arco, filtered) | beam.CoGroupByKey() | beam.Map(flatten_tuple)
+
+        a_b_c = nested | "mix 0" >> MixVariables(spec="a,b,c")
+        a_b_cf = nested | "mix 1" >> MixVariables(spec="a,b,cf")
+        a_bf_cf = nested | "mix 2" >> MixVariables(spec="a,bf,cf")
+        fluxes = (a_b_c, a_b_cf, a_bf_cf) | beam.Flatten() | XBeamComputeFluxes()
+        return fluxes
+
+
+if __name__ == "__main__":
+    with beam.Pipeline() as p:
+        p | beam.Create(input_data) | AirSeaPaper() | beam.Map(print)
diff --git a/tests/test_beam.py b/tests/test_beam.py
new file mode 100644
index 0000000..e71a957
--- /dev/null
+++ b/tests/test_beam.py
@@ -0,0 +1,28 @@
+import apache_beam as beam
+import pytest
+from apache_beam.testing import test_pipeline
+from apache_beam.testing.util import assert_that, equal_to
+from scale_aware_air_sea.beam import AirSeaPaper, input_data
+
+
+runners = ["DirectRunner"]
+runner_ids = ["DirectRunner"]
+
+
+@pytest.fixture
+def expected():
+    return [
+        ("cesm", "$cesm_ds_arco+cesm_ds_arco_filtered$_a,b,c_flux"),
+        ("cm26", "$cm26_ds_arco+cm26_ds_arco_filtered$_a,b,c_flux"),
+        ("cesm", "$cesm_ds_arco+cesm_ds_arco_filtered$_a,bf,cf_flux"),
+        ("cm26", "$cm26_ds_arco+cm26_ds_arco_filtered$_a,bf,cf_flux"),
+        ("cesm", "$cesm_ds_arco+cesm_ds_arco_filtered$_a,b,cf_flux"),
+        ("cm26", "$cm26_ds_arco+cm26_ds_arco_filtered$_a,b,cf_flux"),
+    ]
+
+
+@pytest.mark.parametrize("runner", runners, ids=runner_ids)
+def test_pipeline_output(runner, expected):
+    with test_pipeline.TestPipeline(runner=runner) as p:
+        pcoll = p | beam.Create(input_data) | AirSeaPaper()
+        assert_that(pcoll, equal_to(expected))