diff --git a/esmvalcore/cmor/_fixes/cmip6/iitm_esm.py b/esmvalcore/cmor/_fixes/cmip6/iitm_esm.py index a6dee710c9..6ed2108ff7 100644 --- a/esmvalcore/cmor/_fixes/cmip6/iitm_esm.py +++ b/esmvalcore/cmor/_fixes/cmip6/iitm_esm.py @@ -3,7 +3,7 @@ import numpy as np -from esmvalcore.cmor._fixes.fix import get_time_bounds +from esmvalcore.cmor.fixes import get_time_bounds from ..common import OceanFixGrid from ..fix import Fix diff --git a/esmvalcore/cmor/_fixes/cmip6/kace_1_0_g.py b/esmvalcore/cmor/_fixes/cmip6/kace_1_0_g.py index 5e27377b0d..e4c2cc420e 100644 --- a/esmvalcore/cmor/_fixes/cmip6/kace_1_0_g.py +++ b/esmvalcore/cmor/_fixes/cmip6/kace_1_0_g.py @@ -3,7 +3,7 @@ import numpy as np -from esmvalcore.cmor._fixes.fix import get_time_bounds +from esmvalcore.cmor.fixes import get_time_bounds from ..common import ClFixHybridHeightCoord, OceanFixGrid from ..fix import Fix diff --git a/esmvalcore/cmor/_fixes/fix.py b/esmvalcore/cmor/_fixes/fix.py index 4b5163d4a8..a6156a231c 100644 --- a/esmvalcore/cmor/_fixes/fix.py +++ b/esmvalcore/cmor/_fixes/fix.py @@ -6,7 +6,6 @@ import logging import tempfile from collections.abc import Sequence -from datetime import datetime from pathlib import Path from typing import TYPE_CHECKING, Any, Optional @@ -25,8 +24,8 @@ _get_single_cube, _is_unstructured_grid, ) +from esmvalcore.cmor.fixes import get_time_bounds from esmvalcore.cmor.table import get_var_info -from esmvalcore.iris_helpers import date2num if TYPE_CHECKING: from esmvalcore.cmor.table import CoordinateInfo, VariableInfo @@ -325,88 +324,6 @@ def get_fixed_filepath( return output_dir / Path(filepath).name -def get_next_month(month: int, year: int) -> tuple[int, int]: - """Get next month and year. - - Parameters - ---------- - month: - Current month. - year: - Current year. - - Returns - ------- - tuple[int, int] - Next month and next year. - - """ - if month != 12: - return month + 1, year - return 1, year + 1 - - -def get_time_bounds(time: Coord, freq: str): - """Get bounds for time coordinate. - - Parameters - ---------- - time: - Time coordinate. - freq: - Frequency. - - Returns - ------- - np.ndarray - Time bounds - - Raises - ------ - NotImplementedError - Non-supported frequency is given. - - """ - bounds = [] - dates = time.units.num2date(time.points) - for step, date in enumerate(dates): - month = date.month - year = date.year - if freq in ['mon', 'mo']: - next_month, next_year = get_next_month(month, year) - min_bound = date2num(datetime(year, month, 1, 0, 0), - time.units, time.dtype) - max_bound = date2num(datetime(next_year, next_month, 1, 0, 0), - time.units, time.dtype) - elif freq == 'yr': - min_bound = date2num(datetime(year, 1, 1, 0, 0), - time.units, time.dtype) - max_bound = date2num(datetime(year + 1, 1, 1, 0, 0), - time.units, time.dtype) - elif freq == 'dec': - min_bound = date2num(datetime(year, 1, 1, 0, 0), - time.units, time.dtype) - max_bound = date2num(datetime(year + 10, 1, 1, 0, 0), - time.units, time.dtype) - else: - delta = { - 'day': 12.0 / 24, - '6hr': 3.0 / 24, - '3hr': 1.5 / 24, - '1hr': 0.5 / 24, - } - if freq not in delta: - raise NotImplementedError( - f"Cannot guess time bounds for frequency '{freq}'" - ) - point = time.points[step] - min_bound = point - delta[freq] - max_bound = point + delta[freq] - bounds.append([min_bound, max_bound]) - - return np.array(bounds) - - class GenericFix(Fix): """Class providing generic fixes for all datasets.""" diff --git a/esmvalcore/cmor/_fixes/shared.py b/esmvalcore/cmor/_fixes/shared.py index 0e4dfead08..55ccc8697f 100644 --- a/esmvalcore/cmor/_fixes/shared.py +++ b/esmvalcore/cmor/_fixes/shared.py @@ -1,6 +1,7 @@ """Shared functions for fixes.""" import logging import os +from datetime import datetime from functools import lru_cache import dask.array as da @@ -9,8 +10,11 @@ import pandas as pd from cf_units import Unit from iris import NameConstraint +from iris.coords import Coord from scipy.interpolate import interp1d +from esmvalcore.iris_helpers import date2num + logger = logging.getLogger(__name__) @@ -411,3 +415,92 @@ def fix_ocean_depth_coord(cube): depth_coord.units = 'm' depth_coord.long_name = 'ocean depth coordinate' depth_coord.attributes = {'positive': 'down'} + + +def get_next_month(month: int, year: int) -> tuple[int, int]: + """Get next month and year. + + Parameters + ---------- + month: + Current month. + year: + Current year. + + Returns + ------- + tuple[int, int] + Next month and next year. + + """ + if month != 12: + return month + 1, year + return 1, year + 1 + + +def get_time_bounds(time: Coord, freq: str) -> np.ndarray: + """Get bounds for time coordinate. + + For monthly data, use the first day of the current month and the first day + of the next month. For yearly or decadal data, use 1 January of the current + year and 1 January of the next year or 10 years from the current year. For + other frequencies (daily, 6-hourly, 3-hourly, hourly), half of the + frequency is subtracted/added from the current point in time to get the + bounds. + + Parameters + ---------- + time: + Time coordinate. + freq: + Frequency. + + Returns + ------- + np.ndarray + Time bounds + + Raises + ------ + NotImplementedError + Non-supported frequency is given. + + """ + bounds = [] + dates = time.units.num2date(time.points) + for step, date in enumerate(dates): + month = date.month + year = date.year + if freq in ['mon', 'mo']: + next_month, next_year = get_next_month(month, year) + min_bound = date2num(datetime(year, month, 1, 0, 0), + time.units, time.dtype) + max_bound = date2num(datetime(next_year, next_month, 1, 0, 0), + time.units, time.dtype) + elif freq == 'yr': + min_bound = date2num(datetime(year, 1, 1, 0, 0), + time.units, time.dtype) + max_bound = date2num(datetime(year + 1, 1, 1, 0, 0), + time.units, time.dtype) + elif freq == 'dec': + min_bound = date2num(datetime(year, 1, 1, 0, 0), + time.units, time.dtype) + max_bound = date2num(datetime(year + 10, 1, 1, 0, 0), + time.units, time.dtype) + else: + delta = { + 'day': 12.0 / 24, + '6hr': 3.0 / 24, + '3hr': 1.5 / 24, + '1hr': 0.5 / 24, + } + if freq not in delta: + raise NotImplementedError( + f"Cannot guess time bounds for frequency '{freq}'" + ) + point = time.points[step] + min_bound = point - delta[freq] + max_bound = point + delta[freq] + bounds.append([min_bound, max_bound]) + + return np.array(bounds) diff --git a/esmvalcore/cmor/fixes.py b/esmvalcore/cmor/fixes.py index e5931b0f0f..534aa3bd94 100644 --- a/esmvalcore/cmor/fixes.py +++ b/esmvalcore/cmor/fixes.py @@ -1,8 +1,15 @@ """Functions for fixing specific issues with datasets.""" -from ._fixes.shared import add_altitude_from_plev, add_plev_from_altitude +from ._fixes.shared import ( + add_altitude_from_plev, + add_plev_from_altitude, + get_next_month, + get_time_bounds, +) __all__ = [ 'add_altitude_from_plev', 'add_plev_from_altitude', + 'get_time_bounds', + 'get_next_month', ] diff --git a/esmvalcore/preprocessor/_time.py b/esmvalcore/preprocessor/_time.py index 6c363379e5..f39566e51f 100644 --- a/esmvalcore/preprocessor/_time.py +++ b/esmvalcore/preprocessor/_time.py @@ -23,7 +23,7 @@ from iris.cube import Cube, CubeList from iris.time import PartialDateTime -from esmvalcore.cmor._fixes.fix import get_next_month, get_time_bounds +from esmvalcore.cmor.fixes import get_next_month, get_time_bounds from esmvalcore.iris_helpers import date2num from ._shared import get_iris_analysis_operation, operator_accept_weights diff --git a/tests/integration/cmor/_fixes/test_shared.py b/tests/integration/cmor/_fixes/test_shared.py index 92d2cfba2a..114dec7e0a 100644 --- a/tests/integration/cmor/_fixes/test_shared.py +++ b/tests/integration/cmor/_fixes/test_shared.py @@ -7,6 +7,7 @@ import pytest from cf_units import Unit from iris import NameConstraint +from iris.coords import AuxCoord from esmvalcore.cmor._fixes.shared import ( _map_on_filled, @@ -25,6 +26,7 @@ get_altitude_to_pressure_func, get_bounds_cube, get_pressure_to_altitude_func, + get_time_bounds, round_coordinates, ) @@ -620,3 +622,39 @@ def test_fix_ocean_depth_coord(): assert depth_coord.units == 'm' assert depth_coord.long_name == 'ocean depth coordinate' assert depth_coord.attributes == {'positive': 'down'} + + +@pytest.fixture +def time_coord(): + """Time coordinate.""" + time_coord = AuxCoord( + [15, 350], + standard_name='time', + units='days since 1850-01-01' + ) + return time_coord + + +@pytest.mark.parametrize( + 'freq,expected_bounds', + [ + ('mon', [[0, 31], [334, 365]]), + ('mo', [[0, 31], [334, 365]]), + ('yr', [[0, 365], [0, 365]]), + ('dec', [[0, 3652], [0, 3652]]), + ('day', [[14.5, 15.5], [349.5, 350.5]]), + ('6hr', [[14.875, 15.125], [349.875, 350.125]]), + ('3hr', [[14.9375, 15.0625], [349.9375, 350.0625]]), + ('1hr', [[14.97916666, 15.020833333], [349.97916666, 350.020833333]]), + ] +) +def test_get_time_bounds(time_coord, freq, expected_bounds): + """Test ``get_time_bounds`.""" + bounds = get_time_bounds(time_coord, freq) + np.testing.assert_allclose(bounds, expected_bounds) + + +def test_get_time_bounds_invalid_freq_fail(time_coord): + """Test ``get_time_bounds`.""" + with pytest.raises(NotImplementedError): + get_time_bounds(time_coord, 'invalid_freq') diff --git a/tests/unit/cmor/test_fixes.py b/tests/unit/cmor/test_fixes.py index 24e4acde52..c1a87e1bc1 100644 --- a/tests/unit/cmor/test_fixes.py +++ b/tests/unit/cmor/test_fixes.py @@ -8,6 +8,8 @@ @pytest.mark.parametrize('func', [ 'add_altitude_from_plev', 'add_plev_from_altitude', + 'get_next_month', + 'get_time_bounds', ]) def test_imports(func): assert func in fixes.__all__ diff --git a/tests/unit/cmor/test_generic_fix.py b/tests/unit/cmor/test_generic_fix.py index 54703b93f5..4294fe0752 100644 --- a/tests/unit/cmor/test_generic_fix.py +++ b/tests/unit/cmor/test_generic_fix.py @@ -2,26 +2,14 @@ from unittest.mock import sentinel -import numpy as np import pytest from iris.coords import AuxCoord from iris.cube import Cube, CubeList -from esmvalcore.cmor._fixes.fix import GenericFix, get_time_bounds +from esmvalcore.cmor._fixes.fix import GenericFix from esmvalcore.cmor.table import get_var_info -@pytest.fixture -def time_coord(): - """Time coordinate.""" - time_coord = AuxCoord( - [15, 350], - standard_name='time', - units='days since 1850-01-01' - ) - return time_coord - - @pytest.fixture def generic_fix(): """Generic fix object.""" @@ -30,31 +18,6 @@ def generic_fix(): return GenericFix(vardef, extra_facets=extra_facets) -@pytest.mark.parametrize( - 'freq,expected_bounds', - [ - ('mon', [[0, 31], [334, 365]]), - ('mo', [[0, 31], [334, 365]]), - ('yr', [[0, 365], [0, 365]]), - ('dec', [[0, 3652], [0, 3652]]), - ('day', [[14.5, 15.5], [349.5, 350.5]]), - ('6hr', [[14.875, 15.125], [349.875, 350.125]]), - ('3hr', [[14.9375, 15.0625], [349.9375, 350.0625]]), - ('1hr', [[14.97916666, 15.020833333], [349.97916666, 350.020833333]]), - ] -) -def test_get_time_bounds(time_coord, freq, expected_bounds): - """Test ``get_time_bounds`.""" - bounds = get_time_bounds(time_coord, freq) - np.testing.assert_allclose(bounds, expected_bounds) - - -def test_get_time_bounds_invalid_freq_fail(time_coord): - """Test ``get_time_bounds`.""" - with pytest.raises(NotImplementedError): - get_time_bounds(time_coord, 'invalid_freq') - - def test_generic_fix_empty_long_name(generic_fix, monkeypatch): """Test ``GenericFix``.""" # Artificially set long_name to empty string for test