Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ below:
- Ben Fitzpatrick (Met Office, UK)
- Tom Gale (Bureau of Meteorology, Australia)
- Sam Griffiths (Met Office, UK)
- Luke Hoffmann (Bureau of Meteorology, Australia)
- Ben Hooper (Met Office, UK)
- Aaron Hopkinson (Met Office, UK)
- Kathryn Howard (Met Office, UK)
Expand Down
19 changes: 19 additions & 0 deletions improver/calibration/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,25 @@ def __init__(self):
)


def treelite_packages_available():
    """Report whether both tl2cgen and treelite can be imported.

    Returns:
        bool:
            True if both imports succeed, False if either raises
            ModuleNotFoundError.
    """
    try:
        import tl2cgen  # noqa: F401
        import treelite  # noqa: F401
    except ModuleNotFoundError:
        found = False
    else:
        found = True
    return found


def lightgbm_package_available():
    """Report whether the lightgbm package can be imported.

    Returns:
        bool:
            True if the import succeeds, False if it raises
            ModuleNotFoundError.
    """
    try:
        import lightgbm  # noqa: F401
    except ModuleNotFoundError:
        found = False
    else:
        found = True
    return found


def split_forecasts_and_truth(
cubes: List[Cube], truth_attribute: str
) -> Tuple[Cube, Cube, Optional[Cube]]:
Expand Down
23 changes: 4 additions & 19 deletions improver/calibration/rainforest_calibration.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@
from numpy import ndarray

from improver import PostProcessingPlugin
from improver.calibration import (
lightgbm_package_available,
treelite_packages_available,
)
from improver.constants import MINUTES_IN_HOUR, SECONDS_IN_MINUTE
from improver.ensemble_copula_coupling.utilities import (
get_bounds_of_distribution,
Expand All @@ -39,25 +43,6 @@
Model = Literal["lightgbm_model", "treelite_model"]


def treelite_packages_available():
    """Report whether both tl2cgen and treelite can be imported.

    Returns:
        bool:
            True if both imports succeed, False if either raises
            ModuleNotFoundError.
    """
    try:
        import tl2cgen  # noqa: F401
        import treelite  # noqa: F401
    except ModuleNotFoundError:
        found = False
    else:
        found = True
    return found


def lightgbm_package_available():
    """Report whether the lightgbm package can be imported.

    Returns:
        bool:
            True if the import succeeds, False if it raises
            ModuleNotFoundError.
    """
    try:
        import lightgbm  # noqa: F401
    except ModuleNotFoundError:
        found = False
    else:
        found = True
    return found


class ModelFileNotFoundError(Exception):
    """Exception used when a treelite/LightGBM model object path is invalid."""

Expand Down
77 changes: 77 additions & 0 deletions improver/calibration/rainforest_training.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
# (C) Crown Copyright, Met Office. All rights reserved.
#
# This file is part of 'IMPROVER' and is released under the BSD 3-Clause license.
# See LICENSE in the root of the repository for full licensing details.

# from pathlib import Path

from improver import BasePlugin
from improver.calibration import (
lightgbm_package_available,
treelite_packages_available,
)


class TrainRainForestsCalibration(BasePlugin):
    """Plugin that trains one LightGBM model per threshold from tabular data."""

    # Parameters handed unchanged to lightgbm.train for every model.
    lightgbm_params = {
        "objective": "binary",
        "num_leaves": 5,
        "num_boost_round": 10,
        "verbose": -1,
        "seed": 0,
    }

    def __init__(self, training_data):
        """Store the training data after checking LightGBM can be imported.

        Args:
            training_data:
                Table indexable by column name; the training CLI passes the
                result of pandas.read_parquet.

        Raises:
            ModuleNotFoundError: If the lightgbm package is not available.
        """
        found = lightgbm_package_available()
        self.lightgbm_available = found
        if not found:
            raise ModuleNotFoundError("Could not find LightGBM module")

        self.training_data = training_data

    def process(
        self, threshold, observation_column, training_columns, output_path=None
    ):
        """Train a model for a single threshold.

        Args:
            threshold:
                Value the observation column is compared against (>=) to
                produce the 0/1 training label.
            observation_column:
                Name of the column holding the observations.
            training_columns:
                Names of the columns used as model inputs.
            output_path:
                If truthy, location the trained model is saved to.

        Returns:
            The trained model as returned by its model_to_string method.
        """
        import lightgbm

        observations = self.training_data[observation_column]
        labels = (observations >= threshold).astype(int)
        features = self.training_data[training_columns]

        booster = lightgbm.train(
            self.lightgbm_params, lightgbm.Dataset(features, label=labels)
        )
        if output_path:
            booster.save_model(output_path)

        return booster.model_to_string()


class CompileRainForestsCalibration(BasePlugin):
    """Plugin that compiles a saved LightGBM model using treelite and tl2cgen."""

    # Options passed as ``params`` to tl2cgen.export_lib.
    treelight_params = {"parallel_comp": 8, "quantize": 1}

    def __init__(self):
        """Check that the treelite packages and LightGBM can be imported.

        Raises:
            ModuleNotFoundError: If the treelite packages or lightgbm are
                not available.
        """
        treelite_found = treelite_packages_available()
        self.treelite_available = treelite_found
        if not treelite_found:
            raise ModuleNotFoundError("Could not find TreeLite module")

        # LightGBM is required as well, because it reads the model file.
        lightgbm_found = lightgbm_package_available()
        self.lightgbm_available = lightgbm_found
        if not lightgbm_found:
            raise ModuleNotFoundError("Could not find LightGBM module")

    def process(self, lightgbm_filepath, output_filepath):
        """Compile a lightgbm model.

        Args:
            lightgbm_filepath:
                Location of the saved LightGBM model file to load.
            output_filepath:
                Location passed to tl2cgen.export_lib as ``libpath``.
        """
        import tl2cgen
        import treelite
        from lightgbm import Booster

        booster = Booster(model_file=lightgbm_filepath)
        tl2cgen.export_lib(
            treelite.Model.from_lightgbm(booster),
            toolchain="gcc",
            libpath=output_filepath,
            verbose=False,
            params=self.treelight_params,
        )
28 changes: 28 additions & 0 deletions improver/cli/compile_rainforests_calibration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# (C) Crown Copyright, Met Office. All rights reserved.
#
# This file is part of 'IMPROVER' and is released under the BSD 3-Clause license.
# See LICENSE in the root of the repository for full licensing details.
"""CLI to compile a Rainforests calibration model."""

from pathlib import Path

from improver import cli


@cli.clizefy
def process(lightgbm_model: cli.inputpath):
    """Compile a Rainforests LightGBM model.

    The compiled output is written alongside the input model, using the same
    file name with its suffix replaced by ".o".

    Args:
        lightgbm_model (pathlib.Path):
            Path to an existing LightGBM model file.

    Raises:
        ValueError: If lightgbm_model is not an existing file.
    """

    from improver.calibration.rainforest_training import CompileRainForestsCalibration

    # Coerce first so the check also works if a plain string is supplied.
    input_path = Path(lightgbm_model)
    if not input_path.is_file():
        raise ValueError("lightgbm_model must be an existing file")

    plugin = CompileRainForestsCalibration()

    output_path = input_path.with_suffix(".o")
    plugin.process(input_path, output_path)
41 changes: 41 additions & 0 deletions improver/cli/train_rainforests_calibration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# (C) Crown Copyright, Met Office. All rights reserved.
#
# This file is part of 'IMPROVER' and is released under the BSD 3-Clause license.
# See LICENSE in the root of the repository for full licensing details.
"""CLI to train Rainforests calibration models."""

from pathlib import Path

from improver import cli


@cli.clizefy
def process(
    *,
    training_data: cli.inputpath,
    training_columns: cli.comma_separated_list,
    observation_column: str,
    thresholds: cli.comma_separated_list_of_float,
    output_dir: cli.inputpath,
):
    """Train Rainforests models for a particular leadtime, for a set of thresholds.

    One model is trained per threshold and saved in output_dir as
    "model_<threshold>.txt".

    Args:
        training_data (pathlib.Path):
            Path to a parquet file containing the training data.
        training_columns (list of str):
            Names of the columns used as model inputs.
        observation_column (str):
            Name of the column containing the observations.
        thresholds (list of float):
            Thresholds to train a model for.
        output_dir (pathlib.Path):
            Existing directory in which the trained models are saved.

    Raises:
        ValueError: If output_dir is not an existing directory.
    """
    import pandas as pd

    from improver.calibration.rainforest_training import TrainRainForestsCalibration

    # Coerce once so the check and the path joins below share a Path object,
    # and so a plain string is handled as well.
    output_dir = Path(output_dir)
    if not output_dir.is_dir():
        raise ValueError("--output_dir must be a directory")

    plugin = TrainRainForestsCalibration(pd.read_parquet(training_data))

    for threshold in thresholds:
        plugin.process(
            training_columns=training_columns,
            observation_column=observation_column,
            threshold=threshold,
            output_path=output_dir / f"model_{threshold}.txt",
        )
Empty file.
58 changes: 58 additions & 0 deletions improver_tests/calibration/rainforests_training/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# (C) Crown Copyright, Met Office. All rights reserved.
#
# This file is part of 'IMPROVER' and is released under the BSD 3-Clause license.
# See LICENSE in the root of the repository for full licensing details.
import sys

import pytest

from improver.calibration import lightgbm_package_available, treelite_packages_available

from ..rainforests_calibration.conftest import (
deterministic_features,
deterministic_forecast,
dummy_lightgbm_models,
ensemble_features,
ensemble_forecast,
lead_times,
prepare_dummy_training_data,
thresholds,
)

# Reference every name imported from the rainforests_calibration conftest so
# that each one is used in this module.
# NOTE(review): presumably this is to stop linters reporting the imports as
# unused while keeping the fixtures visible to pytest here — confirm.
_ = (
    deterministic_features,
    deterministic_forecast,
    dummy_lightgbm_models,
    ensemble_features,
    ensemble_forecast,
    lead_times,
    prepare_dummy_training_data,
    thresholds,
)

# Rebinds the imported name to itself.
# NOTE(review): looks redundant given the tuple above — confirm before removing.
dummy_lightgbm_models = dummy_lightgbm_models


@pytest.fixture(params=[True, False])
def lightgbm_available(request, monkeypatch):
    """Parametrised flag saying whether lightgbm should be importable.

    When the flag is False, the "lightgbm" entry in sys.modules is set to
    None for the duration of the test.
    """
    if request.param and lightgbm_package_available():
        return True
    monkeypatch.setitem(sys.modules, "lightgbm", None)
    return False


@pytest.fixture(params=[True, False])
def treelite_available(request, monkeypatch):
    """Parametrised flag saying whether treelite should be importable.

    When the flag is False, the "treelite" entry in sys.modules is set to
    None for the duration of the test.
    """
    if request.param and treelite_packages_available():
        return True
    monkeypatch.setitem(sys.modules, "treelite", None)
    return False


@pytest.fixture
def deterministic_training_data(
    deterministic_features, deterministic_forecast, lead_times
):
    """Training data built from the deterministic features and forecast."""
    training_data = prepare_dummy_training_data(
        deterministic_features, deterministic_forecast, lead_times
    )
    return training_data
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# (C) Crown Copyright, Met Office. All rights reserved.
#
# This file is part of 'IMPROVER' and is released under the BSD 3-Clause license.
# See LICENSE in the root of the repository for full licensing details.

from pathlib import Path

import pytest

from improver.calibration.rainforest_training import (
CompileRainForestsCalibration,
)

lightgbm = pytest.importorskip("lightgbm")
tl2cgen = pytest.importorskip("tl2cgen")
treelite = pytest.importorskip("treelite")


def test__init__(lightgbm_available, treelite_available, tmp_path):
    """Test the class is created only if the treelite and LightGBM libraries
    are available, and that ModuleNotFoundError is raised otherwise."""

    if not (treelite_available and lightgbm_available):
        with pytest.raises(ModuleNotFoundError):
            CompileRainForestsCalibration()
        return

    plugin = CompileRainForestsCalibration()
    assert type(plugin).__name__ == "CompileRainForestsCalibration"


def test_process(dummy_lightgbm_models, tmp_path):
    """Test that compiling a saved model creates the output file."""

    tree_models, lead_times, thresholds = dummy_lightgbm_models
    lead_time, threshold = lead_times[0], thresholds[0]

    compiler = CompileRainForestsCalibration()

    # Save the first model to disk so there is a file to compile.
    model_path = tmp_path / f"model{lead_time}{threshold}.txt"
    tree_models[lead_time, threshold].save_model(model_path)

    compiled_path = tmp_path / f"compiled{lead_time}{threshold}.o"
    compiler.process(model_path, compiled_path)

    assert compiled_path.exists()
Loading