Skip to content

Commit

Permalink
feat: updated conformalization with None splitter
Browse files Browse the repository at this point in the history
  • Loading branch information
M-Mouhcine committed Jan 2, 2024
1 parent f2c3f90 commit 28ce8bf
Show file tree
Hide file tree
Showing 5 changed files with 78 additions and 16 deletions.
47 changes: 37 additions & 10 deletions deel/puncc/api/conformalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from copy import deepcopy
from typing import Iterable
from typing import Tuple
from typing import Union

import numpy as np

Expand All @@ -36,18 +37,22 @@
from deel.puncc.api.prediction import BasePredictor
from deel.puncc.api.prediction import DualPredictor
from deel.puncc.api.splitting import BaseSplitter
from deel.puncc.api.splitting import IdSplitter

logger = logging.getLogger(__name__)


class ConformalPredictor:
"""Conformal predictor class.
:param deel.puncc.api.prediction.BasePredictor predictor: model wrapper.
:param deel.puncc.api.prediction.BasePredictor | object predictor:
underlying model to be conformalized. The model can directly be
passed as argument if it already has `fit` and `predict` methods.
:param deel.puncc.api.prediction.BaseCalibrator calibrator: nonconformity
computation strategy and set predictor.
:param deel.puncc.api.prediction.BaseSplitter splitter: fit/calibration
split strategy.
split strategy. The splitter can be set to None if the underlying
model is pretrained.
:param str method: method to handle the ensemble prediction and calibration
in case the splitter is a K-fold-like strategy. Defaults to 'cv+' to
follow cv+ procedure.
Expand Down Expand Up @@ -87,7 +92,8 @@ class ConformalPredictor:
# Regression linear model
model = linear_model.LinearRegression()
# Definition of a predictor
# Definition of a predictor. Note that it is not required to wrap
# the model here because it already implements fit and predict methods
predictor = BasePredictor(model)
# Definition of a calibrator, built for a given nonconformity scores
Expand Down Expand Up @@ -119,7 +125,7 @@ class ConformalPredictor:
def __init__(
self,
calibrator: BaseCalibrator,
predictor: BasePredictor,
predictor: Union[BasePredictor, object],
splitter: BaseSplitter,
method: str = "cv+",
train: bool = True,
Expand All @@ -129,23 +135,35 @@ def __init__(
if isinstance(predictor, (BasePredictor, DualPredictor)):
self.predictor = predictor

elif not hasattr(predictor, "fit"):
elif not hasattr(predictor, "predict"):
raise RuntimeError(
"Provided model has no fit method. "
"Provided model has no predict method. "
+ "Use a BasePredictor or a DualPredictor to build "
+ "a compatible predictor."
)

elif not hasattr(predictor, "predict"):
elif train and not hasattr(predictor, "fit"):
raise RuntimeError(
"Provided model has no predict method. "
"Provided model is not trained and has no fit method. "
+ "Use a BasePredictor or a DualPredictor to build "
+ "a compatible predictor."
)

else:
self.predictor = BasePredictor(predictor, is_trained=not train)

if train and splitter is None:
raise RuntimeError(
"The splitter argument is None but train is set to True. "
+ "Please provide a correct splitter to train the underlying "
+ "model."
)

if method != "cv+":
raise RuntimeError(
f"Method {method} is not implemented." + "Please choose 'cv+'."
)

self.splitter = splitter
self.method = method
self.train = train
Expand Down Expand Up @@ -208,7 +226,10 @@ def fit(
"""
# Get split folds. Each fold split is a iterable of a quadruple that
# contains fit and calibration data.
splits = self.splitter(X, y)
if self.splitter is None:
splits = IdSplitter(X, y, X, y)()
else:
splits = self.splitter(X, y)

# The Cross validation aggregator will aggregate the predictors and
# calibrators fitted on each of the K splits.
Expand Down Expand Up @@ -246,6 +267,12 @@ def fit(
calibrator = self.calibrator

if self.train:
if self.splitter is None:
raise RuntimeError(
"The splitter argument is None but train is set to "
+ "True. Please provide a correct splitter to train "
+ "the underlying model."
)
logger.info(f"Fitting model on fold {i+cached_len}")
predictor.fit(X_fit, y_fit, **kwargs) # Fit K-fold predictor

Expand Down Expand Up @@ -316,7 +343,7 @@ def load(path):
saved_dict = pickle.load(input_file)

loaded_cp = ConformalPredictor(
calibrator=None, predictor=BasePredictor(None), splitter=None
calibrator=None, predictor=BasePredictor(None), splitter=object()
)
loaded_cp.__dict__.update(saved_dict)
return loaded_cp
Expand Down
2 changes: 1 addition & 1 deletion deel/puncc/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def __init__(
self.conformal_predictor = ConformalPredictor(
predictor=self.predictor,
calibrator=self.calibrator,
splitter=None,
splitter=object(),
train=self.train,
)

Expand Down
6 changes: 3 additions & 3 deletions deel/puncc/regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def __init__(
self.conformal_predictor = ConformalPredictor(
predictor=self.predictor,
calibrator=self.calibrator,
splitter=None,
splitter=object(),
train=train,
)

Expand Down Expand Up @@ -343,7 +343,7 @@ def __init__(
self.conformal_predictor = ConformalPredictor(
predictor=self.predictor,
calibrator=self.calibrator,
splitter=None,
splitter=object(),
train=train,
)

Expand Down Expand Up @@ -432,7 +432,7 @@ def __init__(self, predictor, *, train=True, weight_func=None):
self.conformal_predictor = ConformalPredictor(
predictor=self.predictor,
calibrator=self.calibrator,
splitter=None,
splitter=object(),
train=train,
)

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@

setuptools.setup(
name="puncc",
version="0.7.4",
version="0.7.5",
author=", ".join(["Mouhcine Mendil", "Luca Mossina", "Joseba Dalmau"]),
author_email=", ".join(
[
Expand Down
37 changes: 36 additions & 1 deletion tests/api/test_conformalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,11 @@ def setUp(self):
X, y, test_size=0.2, random_state=0
)

# Split train data in proper training and calibration
self.X_fit, self.X_calib, self.y_fit, self.y_calib = train_test_split(
self.X_train, self.y_train, test_size=0.5, random_state=0
)

# Regression linear model
model = linear_model.LinearRegression()

Expand Down Expand Up @@ -106,7 +111,7 @@ def test_bad_fit_predict(self):
def test_pretrained_predictor(self):
# Predictor initialized with trained model
model = linear_model.LinearRegression()
model.fit(self.X_train, self.y_train)
model.fit(self.X_fit, self.y_fit)
trained_predictor = BasePredictor(model, is_trained=True)
notrained_predictor = BasePredictor(model, is_trained=False)

Expand All @@ -131,6 +136,36 @@ def test_pretrained_predictor(self):
# Compute nonconformity scores
conformal_predictor.fit(self.X_train, self.y_train)

# Conformalization with no splitter (good)
conformal_predictor = ConformalPredictor(
predictor=trained_predictor,
calibrator=self.calibrator,
splitter=None,
train=False,
)
conformal_predictor.fit(self.X_calib, self.y_calib)
conformal_predictor.predict(self.X_test, alpha=0.1)

# Conformalization with no splitter (bad)
with self.assertRaises(RuntimeError):
conformal_predictor = ConformalPredictor(
predictor=notrained_predictor,
calibrator=self.calibrator,
splitter=None,
train=False,
)
conformal_predictor.fit(self.X_calib, self.y_calib)

# Conformalization with no splitter and train set to True (bad)
with self.assertRaises(RuntimeError):
conformal_predictor = ConformalPredictor(
predictor=notrained_predictor,
calibrator=self.calibrator,
splitter=None,
train=True,
)
conformal_predictor.fit(self.X_calib, self.y_calib)

def test_get_nconf_scores_split(self):
# Conformal predictor
conformal_predictor = ConformalPredictor(
Expand Down

0 comments on commit 28ce8bf

Please sign in to comment.