Skip to content

Commit

Permalink
Added get_time_bounds and get_next_month to public API (#2214)
Browse files Browse the repository at this point in the history
  • Loading branch information
schlunma authored Oct 4, 2023
1 parent 49d9635 commit 13a444e
Show file tree
Hide file tree
Showing 9 changed files with 146 additions and 126 deletions.
2 changes: 1 addition & 1 deletion esmvalcore/cmor/_fixes/cmip6/iitm_esm.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import numpy as np

from esmvalcore.cmor._fixes.fix import get_time_bounds
from esmvalcore.cmor.fixes import get_time_bounds

from ..common import OceanFixGrid
from ..fix import Fix
Expand Down
2 changes: 1 addition & 1 deletion esmvalcore/cmor/_fixes/cmip6/kace_1_0_g.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import numpy as np

from esmvalcore.cmor._fixes.fix import get_time_bounds
from esmvalcore.cmor.fixes import get_time_bounds

from ..common import ClFixHybridHeightCoord, OceanFixGrid
from ..fix import Fix
Expand Down
85 changes: 1 addition & 84 deletions esmvalcore/cmor/_fixes/fix.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import logging
import tempfile
from collections.abc import Sequence
from datetime import datetime
from pathlib import Path
from typing import TYPE_CHECKING, Any, Optional

Expand All @@ -25,8 +24,8 @@
_get_single_cube,
_is_unstructured_grid,
)
from esmvalcore.cmor.fixes import get_time_bounds
from esmvalcore.cmor.table import get_var_info
from esmvalcore.iris_helpers import date2num

if TYPE_CHECKING:
from esmvalcore.cmor.table import CoordinateInfo, VariableInfo
Expand Down Expand Up @@ -325,88 +324,6 @@ def get_fixed_filepath(
return output_dir / Path(filepath).name


def get_next_month(month: int, year: int) -> tuple[int, int]:
"""Get next month and year.
Parameters
----------
month:
Current month.
year:
Current year.
Returns
-------
tuple[int, int]
Next month and next year.
"""
if month != 12:
return month + 1, year
return 1, year + 1


def get_time_bounds(time: Coord, freq: str):
"""Get bounds for time coordinate.
Parameters
----------
time:
Time coordinate.
freq:
Frequency.
Returns
-------
np.ndarray
Time bounds
Raises
------
NotImplementedError
Non-supported frequency is given.
"""
bounds = []
dates = time.units.num2date(time.points)
for step, date in enumerate(dates):
month = date.month
year = date.year
if freq in ['mon', 'mo']:
next_month, next_year = get_next_month(month, year)
min_bound = date2num(datetime(year, month, 1, 0, 0),
time.units, time.dtype)
max_bound = date2num(datetime(next_year, next_month, 1, 0, 0),
time.units, time.dtype)
elif freq == 'yr':
min_bound = date2num(datetime(year, 1, 1, 0, 0),
time.units, time.dtype)
max_bound = date2num(datetime(year + 1, 1, 1, 0, 0),
time.units, time.dtype)
elif freq == 'dec':
min_bound = date2num(datetime(year, 1, 1, 0, 0),
time.units, time.dtype)
max_bound = date2num(datetime(year + 10, 1, 1, 0, 0),
time.units, time.dtype)
else:
delta = {
'day': 12.0 / 24,
'6hr': 3.0 / 24,
'3hr': 1.5 / 24,
'1hr': 0.5 / 24,
}
if freq not in delta:
raise NotImplementedError(
f"Cannot guess time bounds for frequency '{freq}'"
)
point = time.points[step]
min_bound = point - delta[freq]
max_bound = point + delta[freq]
bounds.append([min_bound, max_bound])

return np.array(bounds)


class GenericFix(Fix):
"""Class providing generic fixes for all datasets."""

Expand Down
93 changes: 93 additions & 0 deletions esmvalcore/cmor/_fixes/shared.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Shared functions for fixes."""
import logging
import os
from datetime import datetime
from functools import lru_cache

import dask.array as da
Expand All @@ -9,8 +10,11 @@
import pandas as pd
from cf_units import Unit
from iris import NameConstraint
from iris.coords import Coord
from scipy.interpolate import interp1d

from esmvalcore.iris_helpers import date2num

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -411,3 +415,92 @@ def fix_ocean_depth_coord(cube):
depth_coord.units = 'm'
depth_coord.long_name = 'ocean depth coordinate'
depth_coord.attributes = {'positive': 'down'}


def get_next_month(month: int, year: int) -> tuple[int, int]:
"""Get next month and year.
Parameters
----------
month:
Current month.
year:
Current year.
Returns
-------
tuple[int, int]
Next month and next year.
"""
if month != 12:
return month + 1, year
return 1, year + 1


def get_time_bounds(time: Coord, freq: str) -> np.ndarray:
"""Get bounds for time coordinate.
For monthly data, use the first day of the current month and the first day
of the next month. For yearly or decadal data, use 1 January of the current
year and 1 January of the next year or 10 years from the current year. For
other frequencies (daily, 6-hourly, 3-hourly, hourly), half of the
frequency is subtracted/added from the current point in time to get the
bounds.
Parameters
----------
time:
Time coordinate.
freq:
Frequency.
Returns
-------
np.ndarray
Time bounds
Raises
------
NotImplementedError
Non-supported frequency is given.
"""
bounds = []
dates = time.units.num2date(time.points)
for step, date in enumerate(dates):
month = date.month
year = date.year
if freq in ['mon', 'mo']:
next_month, next_year = get_next_month(month, year)
min_bound = date2num(datetime(year, month, 1, 0, 0),
time.units, time.dtype)
max_bound = date2num(datetime(next_year, next_month, 1, 0, 0),
time.units, time.dtype)
elif freq == 'yr':
min_bound = date2num(datetime(year, 1, 1, 0, 0),
time.units, time.dtype)
max_bound = date2num(datetime(year + 1, 1, 1, 0, 0),
time.units, time.dtype)
elif freq == 'dec':
min_bound = date2num(datetime(year, 1, 1, 0, 0),
time.units, time.dtype)
max_bound = date2num(datetime(year + 10, 1, 1, 0, 0),
time.units, time.dtype)
else:
delta = {
'day': 12.0 / 24,
'6hr': 3.0 / 24,
'3hr': 1.5 / 24,
'1hr': 0.5 / 24,
}
if freq not in delta:
raise NotImplementedError(
f"Cannot guess time bounds for frequency '{freq}'"
)
point = time.points[step]
min_bound = point - delta[freq]
max_bound = point + delta[freq]
bounds.append([min_bound, max_bound])

return np.array(bounds)
9 changes: 8 additions & 1 deletion esmvalcore/cmor/fixes.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
"""Functions for fixing specific issues with datasets."""

from ._fixes.shared import add_altitude_from_plev, add_plev_from_altitude
from ._fixes.shared import (
add_altitude_from_plev,
add_plev_from_altitude,
get_next_month,
get_time_bounds,
)

__all__ = [
'add_altitude_from_plev',
'add_plev_from_altitude',
'get_time_bounds',
'get_next_month',
]
2 changes: 1 addition & 1 deletion esmvalcore/preprocessor/_time.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from iris.cube import Cube, CubeList
from iris.time import PartialDateTime

from esmvalcore.cmor._fixes.fix import get_next_month, get_time_bounds
from esmvalcore.cmor.fixes import get_next_month, get_time_bounds
from esmvalcore.iris_helpers import date2num

from ._shared import get_iris_analysis_operation, operator_accept_weights
Expand Down
38 changes: 38 additions & 0 deletions tests/integration/cmor/_fixes/test_shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import pytest
from cf_units import Unit
from iris import NameConstraint
from iris.coords import AuxCoord

from esmvalcore.cmor._fixes.shared import (
_map_on_filled,
Expand All @@ -25,6 +26,7 @@
get_altitude_to_pressure_func,
get_bounds_cube,
get_pressure_to_altitude_func,
get_time_bounds,
round_coordinates,
)

Expand Down Expand Up @@ -620,3 +622,39 @@ def test_fix_ocean_depth_coord():
assert depth_coord.units == 'm'
assert depth_coord.long_name == 'ocean depth coordinate'
assert depth_coord.attributes == {'positive': 'down'}


@pytest.fixture
def time_coord():
"""Time coordinate."""
time_coord = AuxCoord(
[15, 350],
standard_name='time',
units='days since 1850-01-01'
)
return time_coord


@pytest.mark.parametrize(
'freq,expected_bounds',
[
('mon', [[0, 31], [334, 365]]),
('mo', [[0, 31], [334, 365]]),
('yr', [[0, 365], [0, 365]]),
('dec', [[0, 3652], [0, 3652]]),
('day', [[14.5, 15.5], [349.5, 350.5]]),
('6hr', [[14.875, 15.125], [349.875, 350.125]]),
('3hr', [[14.9375, 15.0625], [349.9375, 350.0625]]),
('1hr', [[14.97916666, 15.020833333], [349.97916666, 350.020833333]]),
]
)
def test_get_time_bounds(time_coord, freq, expected_bounds):
"""Test ``get_time_bounds`."""
bounds = get_time_bounds(time_coord, freq)
np.testing.assert_allclose(bounds, expected_bounds)


def test_get_time_bounds_invalid_freq_fail(time_coord):
"""Test ``get_time_bounds`."""
with pytest.raises(NotImplementedError):
get_time_bounds(time_coord, 'invalid_freq')
2 changes: 2 additions & 0 deletions tests/unit/cmor/test_fixes.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
@pytest.mark.parametrize('func', [
'add_altitude_from_plev',
'add_plev_from_altitude',
'get_next_month',
'get_time_bounds',
])
def test_imports(func):
assert func in fixes.__all__
Expand Down
39 changes: 1 addition & 38 deletions tests/unit/cmor/test_generic_fix.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,14 @@

from unittest.mock import sentinel

import numpy as np
import pytest
from iris.coords import AuxCoord
from iris.cube import Cube, CubeList

from esmvalcore.cmor._fixes.fix import GenericFix, get_time_bounds
from esmvalcore.cmor._fixes.fix import GenericFix
from esmvalcore.cmor.table import get_var_info


@pytest.fixture
def time_coord():
"""Time coordinate."""
time_coord = AuxCoord(
[15, 350],
standard_name='time',
units='days since 1850-01-01'
)
return time_coord


@pytest.fixture
def generic_fix():
"""Generic fix object."""
Expand All @@ -30,31 +18,6 @@ def generic_fix():
return GenericFix(vardef, extra_facets=extra_facets)


@pytest.mark.parametrize(
'freq,expected_bounds',
[
('mon', [[0, 31], [334, 365]]),
('mo', [[0, 31], [334, 365]]),
('yr', [[0, 365], [0, 365]]),
('dec', [[0, 3652], [0, 3652]]),
('day', [[14.5, 15.5], [349.5, 350.5]]),
('6hr', [[14.875, 15.125], [349.875, 350.125]]),
('3hr', [[14.9375, 15.0625], [349.9375, 350.0625]]),
('1hr', [[14.97916666, 15.020833333], [349.97916666, 350.020833333]]),
]
)
def test_get_time_bounds(time_coord, freq, expected_bounds):
"""Test ``get_time_bounds`."""
bounds = get_time_bounds(time_coord, freq)
np.testing.assert_allclose(bounds, expected_bounds)


def test_get_time_bounds_invalid_freq_fail(time_coord):
"""Test ``get_time_bounds`."""
with pytest.raises(NotImplementedError):
get_time_bounds(time_coord, 'invalid_freq')


def test_generic_fix_empty_long_name(generic_fix, monkeypatch):
"""Test ``GenericFix``."""
# Artificially set long_name to empty string for test
Expand Down

0 comments on commit 13a444e

Please sign in to comment.