diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml
index a75da4f..24c8f8c 100644
--- a/.github/workflows/publish.yml
+++ b/.github/workflows/publish.yml
@@ -24,4 +24,4 @@ jobs:
         TWINE_PASSWORD: ${{ secrets.PUNCC_TO_PYPI }}
       run: |
         python setup.py sdist bdist_wheel
-        twine upload dist/*
\ No newline at end of file
+        twine upload dist/*
diff --git a/README.md b/README.md
index 7015b21..b488130 100644
--- a/README.md
+++ b/README.md
@@ -94,45 +94,41 @@ Conformal prediction enables to transform point predictions into interval predictions
 
-Many conformal prediction algorithms can easily be applied using *puncc*. The code snippet below shows the example of split conformal prediction wrapping a linear model, done in few lines of code:
+Many conformal prediction algorithms can easily be applied using *puncc*. The code snippet below shows an example of split conformal prediction with a pretrained linear model:
 
-```python
-from sklearn import linear_model
-from deel.puncc.api.prediction import BasePredictor
+```python
+from deel.puncc.api.prediction import BasePredictor
+from deel.puncc.regression import SplitCP
 
-# Load training data and test data
+# Load calibration and test data
 # ...
 
-# Instanciate a linear regression model
-# linear_model = ...
+# Pretrained regression model
+# trained_linear_model = ...
+
+# Wrap the model to enable interoperability with different ML libraries
+trained_predictor = BasePredictor(trained_linear_model)
 
-# Create a predictor to wrap the linear regression model defined earlier.
-# This enables interoperability with different ML libraries.
-# The argument `is_trained` is set to False to tell that the the linear model
-# needs to be trained before the calibration.
-lin_reg_predictor = BasePredictor(linear_model, is_trained=False)
+# Instantiate the split conformal wrapper for the linear model.
+# The `train` argument is set to False because we do not want to retrain the model.
+split_cp = SplitCP(trained_predictor, train=False)
 
-# Instanciate the split cp wrapper around the linear predictor.
-split_cp = SplitCP(lin_reg_predictor)
+# With a calibration dataset, compute (and store) nonconformity scores
+split_cp.fit(X_calib=X_calib, y_calib=y_calib)
 
-# Fit model (as is_trained` is False) on the fit dataset and
-# compute the residuals on the calibration dataset.
-# The fit (resp. calibration) subset is randomly sampled from the training
-# data and constitutes 80% (resp. 20%) of it (fit_ratio = 80%).
-split_cp.fit(X_train, y_train, fit_ratio=.8)
-
-# The predict returns the output of the linear model y_pred and
-# the calibrated interval [y_pred_lower, y_pred_upper].
-y_pred, y_pred_lower, y_pred_upper = split_cp.predict(X_test, alpha=alpha)
+# Obtain the model's point prediction y_pred and prediction interval
+# PI = [y_pred_lower, y_pred_upper] for a target coverage of 90% (1-alpha).
+y_pred, y_pred_lower, y_pred_upper = split_cp.predict(X_test, alpha=0.1)
 ```
 
-The library provides several metrics (`deel.puncc.metrics`) and plotting capabilities (`deel.puncc.plotting`) to evaluate and visualize the results of a conformal procedure. For a target error rate of $\alpha = 0.1$, the marginal coverage reached in this example on the test set is higher than $90$% (see [Introduction tutorial](docs/puncc_intro.ipynb)):
+The library provides several metrics (`deel.puncc.metrics`) and plotting capabilities (`deel.puncc.plotting`) to evaluate and visualize the results of a conformal procedure.
+For a target error rate of $\alpha = 0.1$, the marginal coverage reached in this example on the test set is higher than $90\%$ (see [Introduction tutorial](docs/puncc_intro.ipynb)):
 [Figure: 90% Prediction Interval with Split Conformal Prediction.]
 
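To make the quickstart excerpt above more concrete, here is a minimal sketch of the evaluation step it mentions, continuing from the README snippet. It assumes `y_test` holds the true test targets and that `split_cp.predict` has already produced `y_pred_lower` and `y_pred_upper`; the metric helpers come from `deel.puncc.metrics`, and the exact signatures should be treated as illustrative.

```python
from deel.puncc import metrics

# Marginal coverage: fraction of test points whose true value falls inside
# [y_pred_lower, y_pred_upper]; for split CP it should be close to 1 - alpha = 0.9.
coverage = metrics.regression_mean_coverage(y_test, y_pred_lower, y_pred_upper)

# Sharpness: average width of the prediction intervals
# (smaller is better at equal coverage).
width = metrics.regression_sharpness(
    y_pred_lower=y_pred_lower, y_pred_upper=y_pred_upper
)

print(f"Marginal coverage: {coverage:.2f}, average interval width: {width:.2f}")
```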
 ### More flexibility with the API
 
@@ -171,19 +167,20 @@ If you use our library for your work, please cite our paper:
 }
 ```
 
-## 💻 Contributing
-
-Contributions are welcome! Feel free to report an issue or open a pull
-request. Take a look at our guidelines [here](CONTRIBUTING.md).
 
 ## 🙏 Acknowledgments
 
 This project received funding from the French ”Investing for the Future – PIA3” program within the Artificial and Natural Intelligence Toulouse Institute (ANITI). The authors gratefully acknowledge the support of the DEEL project.
 
-## 👨‍💻 Creators
+## 👨‍💻 About the Developers
 
-[Mouhcine MENDIL](https://github.com/M-Mouhcine) initially developed this library as a research tool, with assistance from [Lucas MOSSINA](https://github.com/lmossina). We have recently welcomed [Joseba DALMAU](https://github.com/jdalch) to the team to help enhance **puncc** and work on the development of new features.
+Puncc's development team is a group of passionate scientists and engineers who are committed to developing dependable and user-friendly open-source software. We are always looking for new contributors to this initiative. If you are interested in helping us develop puncc, please feel free to get involved.
+
+## 💻 Contributing
+
+Contributions are welcome! Feel free to report an issue or open a pull
+request. Take a look at our guidelines [here](CONTRIBUTING.md).
 
 ## 🔑 License
diff --git a/deel/puncc/api/conformalization.py b/deel/puncc/api/conformalization.py
index ebb9a18..12c8748 100644
--- a/deel/puncc/api/conformalization.py
+++ b/deel/puncc/api/conformalization.py
@@ -34,6 +34,7 @@
 from deel.puncc.api.calibration import BaseCalibrator
 from deel.puncc.api.calibration import CvPlusCalibrator
 from deel.puncc.api.prediction import BasePredictor
+from deel.puncc.api.prediction import DualPredictor
 from deel.puncc.api.splitting import BaseSplitter
 
 logger = logging.getLogger(__name__)
@@ -44,7 +45,7 @@ class ConformalPredictor:
 
     :param deel.puncc.api.prediction.BasePredictor predictor: model wrapper.
     :param deel.puncc.api.prediction.BaseCalibrator calibrator: nonconformity
-        computation strategy and interval predictor.
+        computation strategy and set predictor.
     :param deel.puncc.api.prediction.BaseSplitter splitter: fit/calibration split
         strategy.
     :param str method: method to handle the ensemble prediction and calibration
@@ -124,7 +125,27 @@ def __init__(
         train: bool = True,
     ):
         self.calibrator = calibrator
-        self.predictor = predictor
+
+        if isinstance(predictor, (BasePredictor, DualPredictor)):
+            self.predictor = predictor
+
+        elif not hasattr(predictor, "fit"):
+            raise RuntimeError(
+                "Provided model has no fit method. "
+                + "Use a BasePredictor or a DualPredictor to build "
+                + "a compatible predictor."
+            )
+
+        elif not hasattr(predictor, "predict"):
+            raise RuntimeError(
+                "Provided model has no predict method. "
+                + "Use a BasePredictor or a DualPredictor to build "
+                + "a compatible predictor."
+            )
+
+        else:
+            self.predictor = BasePredictor(predictor, is_trained=not train)
+
         self.splitter = splitter
         self.method = method
         self.train = train
@@ -217,8 +238,12 @@ def fit(
             # Make local copies of the structure of the predictor and the calibrator.
             # In case of a K-fold like splitting strategy, these structures are
             # inherited by the predictor/calibrator used in each fold.
-            predictor = self.predictor.copy()
-            calibrator = deepcopy(self.calibrator)
+            if len(splits) > 1:
+                predictor = self.predictor.copy()
+                calibrator = deepcopy(self.calibrator)
+            else:
+                predictor = self.predictor
+                calibrator = self.calibrator
 
             if self.train:
                 logger.info(f"Fitting model on fold {i+cached_len}")
@@ -290,7 +315,9 @@ def load(path):
         with open(path, "rb") as input_file:
             saved_dict = pickle.load(input_file)
 
-        loaded_cp = ConformalPredictor(None, None, None)
+        loaded_cp = ConformalPredictor(
+            calibrator=None, predictor=BasePredictor(None), splitter=None
+        )
         loaded_cp.__dict__.update(saved_dict)
         return loaded_cp
 
diff --git a/deel/puncc/api/corrections.py b/deel/puncc/api/corrections.py
index d1c5aec..32c1968 100644
--- a/deel/puncc/api/corrections.py
+++ b/deel/puncc/api/corrections.py
@@ -29,7 +29,7 @@
 
 def bonferroni(alpha: float, nvars: int) -> float:
     """Bonferroni correction for multiple comparisons.
-    
+
     :param float alpha: nominal coverage level.
     :param int nvars: number of output features.
@@ -37,7 +37,7 @@ def bonferroni(alpha: float, nvars: int) -> float:
 
     :returns: corrected coverage level.
     :rtype: float.
     """
-    # Sanity checks 
+    # Sanity checks
     if alpha <= 0 or alpha >= 1:
         raise ValueError("alpha must be in (0,1)")
@@ -49,19 +49,19 @@
 
 def weighted_bonferroni(alpha: float, weights: np.ndarray) -> np.ndarray:
     """Weighted Bonferroni correction for multiple comparisons.
-    
+
     :param float alpha: nominal coverage level.
     :param np.ndarray weights: weights associated to each output feature.
-    
-    
+
+
     :returns: array of corrected featurewise coverage levels.
     :rtype: np.ndarray.
     """
-    # Sanity checks 
+    # Sanity checks
     if alpha <= 0 or alpha >= 1:
         raise ValueError("alpha must be in (0,1)")
 
-    # Positivity check 
+    # Positivity check
     positiveness_condition = np.all(weights > 0)
     if not positiveness_condition:
         raise RuntimeError("weights must be positive")
diff --git a/deel/puncc/api/utils.py b/deel/puncc/api/utils.py
index 74cb6bd..c813726 100644
--- a/deel/puncc/api/utils.py
+++ b/deel/puncc/api/utils.py
@@ -145,7 +145,9 @@ def supported_types_check(*data: Iterable):
     )
 
 
-def alpha_calib_check(alpha: Union[float, np.ndarray], n: int, complement_check: bool = False):
+def alpha_calib_check(
+    alpha: Union[float, np.ndarray], n: int, complement_check: bool = False
+):
     """Check if the value of alpha :math:`\\alpha` is consistent with the size
     of calibration set :math:`n`.
 
@@ -181,7 +183,7 @@ def alpha_calib_check(alpha: Union[float, np.ndarray], n: int, complement_check:
 
     if np.any(alpha >= 1):
         raise ValueError(
-            f"Alpha={alpha} is too large. "
+            f"Alpha={alpha} is too large. "
             + "Decrease alpha such that all of its coordinates are < 1."
         )
 
diff --git a/tests/api/test_utils.py b/tests/api/test_utils.py
index e85b190..9f1d49c 100644
--- a/tests/api/test_utils.py
+++ b/tests/api/test_utils.py
@@ -159,7 +159,6 @@ def test_too_low_alpha_multidim(self):
         alpha_calib_check(alpha, n=10)
 
 
-
 class QuantileCheck(unittest.TestCase):
     def test_simple_quantile1d(self):
         self.a_np = np.array([1, 2, 3, 4])
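The `__init__` change in `conformalization.py` above means a bare model exposing `fit` and `predict` is now wrapped automatically in a `BasePredictor`, with `is_trained=not train`. Here is a minimal sketch of the usage this enables, assuming a scikit-learn regressor and the high-level `SplitCP` wrapper (which forwards its model to `ConformalPredictor`); the data variables are placeholders.

```python
from sklearn.linear_model import LinearRegression
from deel.puncc.regression import SplitCP

# A plain scikit-learn model, trained outside of puncc.
model = LinearRegression()
model.fit(X_fit, y_fit)

# No manual BasePredictor wrapping: the model has fit/predict methods,
# so it is wrapped internally with is_trained=not train (here: True).
split_cp = SplitCP(model, train=False)
split_cp.fit(X_calib=X_calib, y_calib=y_calib)
y_pred, y_pred_lower, y_pred_upper = split_cp.predict(X_test, alpha=0.1)
```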
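Similarly, the functions touched in `corrections.py` may be easier to review with a usage sketch. The one below assumes the standard Bonferroni semantics suggested by the docstrings: `bonferroni` splits the risk budget `alpha` evenly across `nvars` outputs, while `weighted_bonferroni` splits it according to positive weights; the exact return values should be checked against the implementation.

```python
import numpy as np
from deel.puncc.api.corrections import bonferroni, weighted_bonferroni

# Even split of a global risk alpha = 0.1 over 3 output features:
# each feature is calibrated at level alpha / nvars.
alpha_per_feature = bonferroni(alpha=0.1, nvars=3)

# Weighted split: positive weights (summing to 1 here) assign a larger
# share of the risk budget to some features than to others.
weights = np.array([0.5, 0.3, 0.2])
featurewise_alphas = weighted_bonferroni(alpha=0.1, weights=weights)
```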