diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml
index a75da4f..24c8f8c 100644
--- a/.github/workflows/publish.yml
+++ b/.github/workflows/publish.yml
@@ -24,4 +24,4 @@ jobs:
TWINE_PASSWORD: ${{ secrets.PUNCC_TO_PYPI }}
run: |
python setup.py sdist bdist_wheel
- twine upload dist/*
\ No newline at end of file
+ twine upload dist/*
diff --git a/README.md b/README.md
index 7015b21..b488130 100644
--- a/README.md
+++ b/README.md
@@ -94,45 +94,41 @@ Conformal prediction enables to transform point predictions into interval predic
-Many conformal prediction algorithms can easily be applied using *puncc*. The code snippet below shows the example of split conformal prediction wrapping a linear model, done in few lines of code:
+Many conformal prediction algorithms can easily be applied using *puncc*. The code snippet below shows an example of split conformal prediction with a pretrained linear model:
-```python
-from sklearn import linear_model
-from deel.puncc.api.prediction import BasePredictor
+```python
+from deel.puncc.api.prediction import BasePredictor
+from deel.puncc.regression import SplitCP
-# Load training data and test data
+# Load calibration and test data
# ...
-# Instanciate a linear regression model
-# linear_model = ...
+# Pretrained regression model
+# trained_linear_model = ...
+# Wrap the model to enable interoperability with different ML libraries
+trained_predictor = BasePredictor(trained_linear_model)
-# Create a predictor to wrap the linear regression model defined earlier.
-# This enables interoperability with different ML libraries.
-# The argument `is_trained` is set to False to tell that the the linear model
-# needs to be trained before the calibration.
-lin_reg_predictor = BasePredictor(linear_model, is_trained=False)
+# Instantiate the split conformal wrapper for the linear model.
+# The `train` argument is set to False since the model is already trained.
+split_cp = SplitCP(trained_predictor, train=False)
-# Instanciate the split cp wrapper around the linear predictor.
-split_cp = SplitCP(lin_reg_predictor)
+# With a calibration dataset, compute (and store) nonconformity scores
+split_cp.fit(X_calib=X_calib, y_calib=y_calib)
-# Fit model (as is_trained` is False) on the fit dataset and
-# compute the residuals on the calibration dataset.
-# The fit (resp. calibration) subset is randomly sampled from the training
-# data and constitutes 80% (resp. 20%) of it (fit_ratio = 80%).
-split_cp.fit(X_train, y_train, fit_ratio=.8)
-
-# The predict returns the output of the linear model y_pred and
-# the calibrated interval [y_pred_lower, y_pred_upper].
-y_pred, y_pred_lower, y_pred_upper = split_cp.predict(X_test, alpha=alpha)
+# Obtain the model's point prediction y_pred and the prediction interval
+# PI = [y_pred_lower, y_pred_upper] for a target coverage of 1 - alpha = 90%.
+y_pred, y_pred_lower, y_pred_upper = split_cp.predict(X_test, alpha=0.1)
```
-The library provides several metrics (`deel.puncc.metrics`) and plotting capabilities (`deel.puncc.plotting`) to evaluate and visualize the results of a conformal procedure. For a target error rate of $\alpha = 0.1$, the marginal coverage reached in this example on the test set is higher than $90$% (see [Introduction tutorial](docs/puncc_intro.ipynb)):
+The library provides several metrics (`deel.puncc.metrics`) and plotting capabilities (`deel.puncc.plotting`) to evaluate and visualize the results of a conformal procedure. For a target error rate of $\alpha = 0.1$, the marginal coverage reached in this example on the test set is higher than $90$% (see [Introduction tutorial](docs/puncc_intro.ipynb)):
+
+
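+As a minimal sketch (assuming the helper functions `regression_mean_coverage`
+and `regression_sharpness` from `deel.puncc.metrics`), the evaluation could
+look like this:
+
+```python
+from deel.puncc import metrics
+
+# Fraction of test points whose true value falls inside the interval
+coverage = metrics.regression_mean_coverage(y_test, y_pred_lower, y_pred_upper)
+
+# Average width of the prediction intervals (smaller is sharper)
+width = metrics.regression_sharpness(
+    y_pred_lower=y_pred_lower, y_pred_upper=y_pred_upper
+)
+
+print(f"Marginal coverage: {coverage:.2f}")
+print(f"Average interval width: {width:.2f}")
+```
+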
### More flexibility with the API
@@ -171,19 +167,20 @@ If you use our library for your work, please cite our paper:
}
```
-## 💻 Contributing
-
-Contributions are welcome! Feel free to report an issue or open a pull
-request. Take a look at our guidelines [here](CONTRIBUTING.md).
## 🙏 Acknowledgments
This project received funding from the French ”Investing for the Future – PIA3” program within the Artificial and Natural Intelligence Toulouse Institute (ANITI). The authors gratefully acknowledge the support of the DEEL project.
-## 👨💻 Creators
+## 👨‍💻 About the Developers
-[Mouhcine MENDIL](https://github.com/M-Mouhcine) initially developed this library as a research tool, with assistance from [Lucas MOSSINA](https://github.com/lmossina). We have recently welcomed [Joseba DALMAU](https://github.com/jdalch) to the team to help enhance **puncc** and work on the development of new features.
+Puncc's development team is a group of passionate scientists and engineers committed to developing dependable and user-friendly open-source software. We are always looking for new contributors to this initiative. If you are interested in helping us develop puncc, feel free to get involved.
+
+## 💻 Contributing
+
+Contributions are welcome! Feel free to report an issue or open a pull
+request. Take a look at our guidelines [here](CONTRIBUTING.md).
## 🔑 License
diff --git a/deel/puncc/api/conformalization.py b/deel/puncc/api/conformalization.py
index ebb9a18..12c8748 100644
--- a/deel/puncc/api/conformalization.py
+++ b/deel/puncc/api/conformalization.py
@@ -34,6 +34,7 @@
from deel.puncc.api.calibration import BaseCalibrator
from deel.puncc.api.calibration import CvPlusCalibrator
from deel.puncc.api.prediction import BasePredictor
+from deel.puncc.api.prediction import DualPredictor
from deel.puncc.api.splitting import BaseSplitter
logger = logging.getLogger(__name__)
@@ -44,7 +45,7 @@ class ConformalPredictor:
:param deel.puncc.api.prediction.BasePredictor predictor: model wrapper.
:param deel.puncc.api.prediction.BaseCalibrator calibrator: nonconformity
- computation strategy and interval predictor.
+ computation strategy and set predictor.
:param deel.puncc.api.prediction.BaseSplitter splitter: fit/calibration
split strategy.
:param str method: method to handle the ensemble prediction and calibration
@@ -124,7 +125,27 @@ def __init__(
train: bool = True,
):
self.calibrator = calibrator
- self.predictor = predictor
+
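+        # Accept ready-made predictors as-is; otherwise check that the raw
+        # model exposes fit/predict before wrapping it in a BasePredictor.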
+ if isinstance(predictor, (BasePredictor, DualPredictor)):
+ self.predictor = predictor
+
+ elif not hasattr(predictor, "fit"):
+ raise RuntimeError(
+ "Provided model has no fit method. "
+ + "Use a BasePredictor or a DualPredictor to build "
+ + "a compatible predictor."
+ )
+
+ elif not hasattr(predictor, "predict"):
+ raise RuntimeError(
+ "Provided model has no predict method. "
+ + "Use a BasePredictor or a DualPredictor to build "
+ + "a compatible predictor."
+ )
+
+ else:
+ self.predictor = BasePredictor(predictor, is_trained=not train)
+
self.splitter = splitter
self.method = method
self.train = train
@@ -217,8 +238,12 @@ def fit(
# Make local copies of the structure of the predictor and the calibrator.
# In case of a K-fold like splitting strategy, these structures are
# inherited by the predictor/calibrator used in each fold.
- predictor = self.predictor.copy()
- calibrator = deepcopy(self.calibrator)
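+        # Copying is only needed for multi-fold strategies; a single split
+        # can safely reuse the original predictor and calibrator.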
+ if len(splits) > 1:
+ predictor = self.predictor.copy()
+ calibrator = deepcopy(self.calibrator)
+ else:
+ predictor = self.predictor
+ calibrator = self.calibrator
if self.train:
logger.info(f"Fitting model on fold {i+cached_len}")
@@ -290,7 +315,9 @@ def load(path):
with open(path, "rb") as input_file:
saved_dict = pickle.load(input_file)
- loaded_cp = ConformalPredictor(None, None, None)
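+        # A placeholder BasePredictor(None) satisfies the constructor's
+        # predictor type check before the saved state is restored.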
+ loaded_cp = ConformalPredictor(
+ calibrator=None, predictor=BasePredictor(None), splitter=None
+ )
loaded_cp.__dict__.update(saved_dict)
return loaded_cp
diff --git a/deel/puncc/api/corrections.py b/deel/puncc/api/corrections.py
index d1c5aec..32c1968 100644
--- a/deel/puncc/api/corrections.py
+++ b/deel/puncc/api/corrections.py
@@ -29,7 +29,7 @@
def bonferroni(alpha: float, nvars: int) -> float:
"""Bonferroni correction for multiple comparisons.
-
+
:param float alpha: nominal coverage level.
:param int nvars: number of output features.
@@ -37,7 +37,7 @@ def bonferroni(alpha: float, nvars: int) -> float:
:returns: corrected coverage level.
:rtype: float.
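+
+    Example: following the standard Bonferroni rule, the error budget is
+    split equally across outputs, so ``alpha=0.1`` with ``nvars=4`` yields
+    a corrected level of ``0.025`` per feature.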
"""
- # Sanity checks
+ # Sanity checks
if alpha <= 0 or alpha >= 1:
raise ValueError("alpha must be in (0,1)")
@@ -49,19 +49,19 @@ def bonferroni(alpha: float, nvars: int) -> float:
def weighted_bonferroni(alpha: float, weights: np.ndarray) -> np.ndarray:
"""Weighted Bonferroni correction for multiple comparisons.
-
+
:param float alpha: nominal coverage level.
:param np.ndarray weights: weights associated to each output feature.
-
-
+
+
:returns: array of corrected featurewise coverage levels.
:rtype: np.ndarray.
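+
+    Example: assuming the budget is allocated proportionally to the weights,
+    ``alpha=0.1`` with ``weights=np.array([1, 3])`` would yield featurewise
+    levels ``[0.025, 0.075]``.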
"""
- # Sanity checks
+ # Sanity checks
if alpha <= 0 or alpha >= 1:
raise ValueError("alpha must be in (0,1)")
- # Positivity check
+ # Positivity check
positiveness_condition = np.all(weights > 0)
if not positiveness_condition:
raise RuntimeError("weights must be positive")
diff --git a/deel/puncc/api/utils.py b/deel/puncc/api/utils.py
index 74cb6bd..c813726 100644
--- a/deel/puncc/api/utils.py
+++ b/deel/puncc/api/utils.py
@@ -145,7 +145,9 @@ def supported_types_check(*data: Iterable):
)
-def alpha_calib_check(alpha: Union[float, np.ndarray], n: int, complement_check: bool = False):
+def alpha_calib_check(
+ alpha: Union[float, np.ndarray], n: int, complement_check: bool = False
+):
"""Check if the value of alpha :math:`\\alpha` is consistent with the size
of calibration set :math:`n`.
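+
+    In standard split conformal prediction, a workable level must satisfy
+    :math:`\\alpha \\geq 1/(n+1)`; e.g. with :math:`n = 9` calibration
+    samples, :math:`\\alpha = 0.05` is too low.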
@@ -181,7 +183,7 @@ def alpha_calib_check(alpha: Union[float, np.ndarray], n: int, complement_check:
if np.any(alpha >= 1):
raise ValueError(
- f"Alpha={alpha} is too large. "
+ f"Alpha={alpha} is too large. "
+ "Decrease alpha such that all of its coordinates are < 1."
)
diff --git a/tests/api/test_utils.py b/tests/api/test_utils.py
index e85b190..9f1d49c 100644
--- a/tests/api/test_utils.py
+++ b/tests/api/test_utils.py
@@ -159,7 +159,6 @@ def test_too_low_alpha_multidim(self):
alpha_calib_check(alpha, n=10)
-
class QuantileCheck(unittest.TestCase):
def test_simple_quantile1d(self):
self.a_np = np.array([1, 2, 3, 4])