Builtin predictors (#44)
* feat: enable the conformal procedure without API predictor wrappers for models that natively implement fit and predict.
M-Mouhcine committed Dec 21, 2023
1 parent 80e8784 commit 70700e2
Showing 6 changed files with 72 additions and 47 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/publish.yml
@@ -24,4 +24,4 @@ jobs:
       TWINE_PASSWORD: ${{ secrets.PUNCC_TO_PYPI }}
     run: |
       python setup.py sdist bdist_wheel
-      twine upload dist/*
+      twine upload dist/*
59 changes: 28 additions & 31 deletions README.md
@@ -94,45 +94,41 @@ Conformal prediction enables the transformation of point predictions into interval predictions
 <img src="docs/assets/cp_process.png"/>
 </figure>
 
-Many conformal prediction algorithms can easily be applied using *puncc*. The code snippet below shows the example of split conformal prediction wrapping a linear model, done in a few lines of code:
+Many conformal prediction algorithms can easily be applied using *puncc*. The code snippet below shows the example of split conformal prediction with a pretrained linear model:
 
 ```python
-from sklearn import linear_model
 from deel.puncc.api.prediction import BasePredictor
 from deel.puncc.regression import SplitCP
 
-# Load training data and test data
+# Load calibration and test data
 # ...
 
-# Instantiate a linear regression model
-# linear_model = ...
+# Pretrained regression model
+# trained_linear_model = ...
 
-# Create a predictor to wrap the linear regression model defined earlier.
-# This enables interoperability with different ML libraries.
-# The argument `is_trained` is set to False to tell that the linear model
-# needs to be trained before the calibration.
-lin_reg_predictor = BasePredictor(linear_model, is_trained=False)
+# Wrap the model to enable interoperability with different ML libraries
+trained_predictor = BasePredictor(trained_linear_model)
 
-# Instantiate the split cp wrapper around the linear predictor.
-split_cp = SplitCP(lin_reg_predictor)
+# Instantiate the split conformal wrapper for the linear model.
+# The train argument is set to False because we do not want to retrain the model.
+split_cp = SplitCP(trained_predictor, train=False)
 
-# Fit the model (as `is_trained` is False) on the fit dataset and
-# compute the residuals on the calibration dataset.
-# The fit (resp. calibration) subset is randomly sampled from the training
-# data and constitutes 80% (resp. 20%) of it (fit_ratio = 80%).
-split_cp.fit(X_train, y_train, fit_ratio=.8)
+# With a calibration dataset, compute (and store) nonconformity scores.
+split_cp.fit(X_calib=X_calib, y_calib=y_calib)
 
-# The predict method returns the output of the linear model y_pred and
-# the calibrated interval [y_pred_lower, y_pred_upper].
-y_pred, y_pred_lower, y_pred_upper = split_cp.predict(X_test, alpha=alpha)
+# Obtain the model's point prediction y_pred and the prediction interval
+# PI = [y_pred_lower, y_pred_upper] for a target coverage of 90% (1 - alpha).
+y_pred, y_pred_lower, y_pred_upper = split_cp.predict(X_test, alpha=0.1)
 ```

-The library provides several metrics (`deel.puncc.metrics`) and plotting capabilities (`deel.puncc.plotting`) to evaluate and visualize the results of a conformal procedure. For a target error rate of $\alpha = 0.1$, the marginal coverage reached in this example on the test set is higher than $90$% (see [Introduction tutorial](docs/puncc_intro.ipynb)):
-
+The library provides several metrics (`deel.puncc.metrics`) and plotting capabilities (`deel.puncc.plotting`) to evaluate and visualize the results of a conformal procedure. For a target error rate of $\alpha = 0.1$, the marginal coverage reached in this example on the test set is higher than $90$% (see [Introduction tutorial](docs/puncc_intro.ipynb)):
+<div align="center">
 <figure style="text-align:center">
-<img src="docs/assets/results_quickstart_split_cp_pi.png" alt="90% Prediction Interval with the Split Conformal Prediction Method"/>
+<img src="docs/assets/results_quickstart_split_cp_pi.png" alt="90% Prediction Interval with the Split Conformal Prediction Method" width="70%"/>
+<div align=center>90% Prediction Interval with Split Conformal Prediction.</div>
 </figure>
+</div>
+<br>
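As an illustration, the reported coverage can be recomputed from the outputs of `split_cp.predict`. A minimal sketch: `regression_mean_coverage` and `regression_sharpness` are the helpers we take from `deel.puncc.metrics`, and `y_test` is illustrative:

```python
from deel.puncc import metrics

# Fraction of test points whose true value lies in [y_pred_lower, y_pred_upper].
coverage = metrics.regression_mean_coverage(y_test, y_pred_lower, y_pred_upper)

# Average interval width: at equal coverage, narrower intervals are more informative.
width = metrics.regression_sharpness(y_pred_lower, y_pred_upper)

print(f"Empirical coverage: {coverage:.2%}, average width: {width:.2f}")
```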

### More flexibility with the API
@@ -171,19 +167,20 @@ If you use our library for your work, please cite our paper:
}
```

-## 💻 Contributing
-
-Contributions are welcome! Feel free to report an issue or open a pull
-request. Take a look at our guidelines [here](CONTRIBUTING.md).
-
 ## 🙏 Acknowledgments
 
 <img align="right" src="https://www.deel.ai/wp-content/uploads/2021/05/logo-DEEL.png" width="25%">
 This project received funding from the French "Investing for the Future – PIA3" program within the Artificial and Natural Intelligence Toulouse Institute (ANITI). The authors gratefully acknowledge the support of the <a href="https://www.deel.ai/">DEEL</a> project.
 
-## 👨‍💻 Creators
+## 👨‍💻 About the Developers
 
-[Mouhcine MENDIL](https://github.com/M-Mouhcine) initially developed this library as a research tool, with assistance from [Lucas MOSSINA](https://github.com/lmossina). We have recently welcomed [Joseba DALMAU](https://github.com/jdalch) to the team to help enhance **puncc** and work on the development of new features.
+Puncc's development team is a group of passionate scientists and engineers committed to developing dependable and user-friendly open-source software. We are always looking for new contributors to this initiative. If you are interested in helping us develop puncc, please feel free to get involved.
+
+## 💻 Contributing
+
+Contributions are welcome! Feel free to report an issue or open a pull
+request. Take a look at our guidelines [here](CONTRIBUTING.md).

## 🔑 License

37 changes: 32 additions & 5 deletions deel/puncc/api/conformalization.py
@@ -34,6 +34,7 @@
 from deel.puncc.api.calibration import BaseCalibrator
 from deel.puncc.api.calibration import CvPlusCalibrator
 from deel.puncc.api.prediction import BasePredictor
+from deel.puncc.api.prediction import DualPredictor
 from deel.puncc.api.splitting import BaseSplitter
 
 logger = logging.getLogger(__name__)
@@ -44,7 +45,7 @@ class ConformalPredictor:
 
     :param deel.puncc.api.prediction.BasePredictor predictor: model wrapper.
     :param deel.puncc.api.prediction.BaseCalibrator calibrator: nonconformity
-        computation strategy and interval predictor.
+        computation strategy and set predictor.
     :param deel.puncc.api.prediction.BaseSplitter splitter: fit/calibration
         split strategy.
     :param str method: method to handle the ensemble prediction and calibration
@@ -124,7 +125,27 @@ def __init__(
         train: bool = True,
     ):
         self.calibrator = calibrator
-        self.predictor = predictor
+
+        if isinstance(predictor, (BasePredictor, DualPredictor)):
+            self.predictor = predictor
+
+        elif not hasattr(predictor, "fit"):
+            raise RuntimeError(
+                "Provided model has no fit method. "
+                + "Use a BasePredictor or a DualPredictor to build "
+                + "a compatible predictor."
+            )
+
+        elif not hasattr(predictor, "predict"):
+            raise RuntimeError(
+                "Provided model has no predict method. "
+                + "Use a BasePredictor or a DualPredictor to build "
+                + "a compatible predictor."
+            )
+
+        else:
+            self.predictor = BasePredictor(predictor, is_trained=not train)
+
         self.splitter = splitter
         self.method = method
         self.train = train
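This is the crux of the change: any model that natively implements `fit` and `predict` can now be passed to a conformal procedure directly and is wrapped in a `BasePredictor` internally. A minimal sketch of the new usage, assuming `SplitCP` forwards its predictor argument to `ConformalPredictor` as above (data variables are illustrative):

```python
from sklearn.linear_model import LinearRegression

from deel.puncc.regression import SplitCP

# The raw sklearn model is accepted as-is: it implements fit and predict,
# so ConformalPredictor wraps it in BasePredictor(model, is_trained=False).
model = LinearRegression()
split_cp = SplitCP(model)

# Fit on 80% of the data, calibrate on the remaining 20% (X/y are illustrative).
split_cp.fit(X_train, y_train, fit_ratio=0.8)
y_pred, y_pred_lower, y_pred_upper = split_cp.predict(X_test, alpha=0.1)
```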
@@ -217,8 +238,12 @@ def fit(
         # Make local copies of the structure of the predictor and the calibrator.
         # In case of a K-fold like splitting strategy, these structures are
         # inherited by the predictor/calibrator used in each fold.
-        predictor = self.predictor.copy()
-        calibrator = deepcopy(self.calibrator)
+        if len(splits) > 1:
+            predictor = self.predictor.copy()
+            calibrator = deepcopy(self.calibrator)
+        else:
+            predictor = self.predictor
+            calibrator = self.calibrator
 
         if self.train:
             logger.info(f"Fitting model on fold {i+cached_len}")
@@ -290,7 +315,9 @@ def load(path):
         with open(path, "rb") as input_file:
             saved_dict = pickle.load(input_file)
 
-        loaded_cp = ConformalPredictor(None, None, None)
+        loaded_cp = ConformalPredictor(
+            calibrator=None, predictor=BasePredictor(None), splitter=None
+        )
         loaded_cp.__dict__.update(saved_dict)
         return loaded_cp

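The placeholder `BasePredictor(None)` exists only to satisfy the new constructor checks before the pickled state is injected. For reference, a round-trip sketch, where `conformal_predictor` is a fitted instance, the path and `X_new` are illustrative, and we assume `save` is the pickle-based counterpart of `load`:

```python
from deel.puncc.api.conformalization import ConformalPredictor

# Persist a fitted conformal predictor (assumed pickle-based save helper) ...
conformal_predictor.save("cp.pkl")

# ... and restore it later, ready to predict without refitting.
loaded_cp = ConformalPredictor.load("cp.pkl")
y_pred, y_lo, y_hi = loaded_cp.predict(X_new, alpha=0.1)
```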
14 changes: 7 additions & 7 deletions deel/puncc/api/corrections.py
@@ -29,15 +29,15 @@
 
 def bonferroni(alpha: float, nvars: int) -> float:
     """Bonferroni correction for multiple comparisons.
 
     :param float alpha: nominal coverage level.
     :param int nvars: number of output features.
 
     :returns: corrected coverage level.
     :rtype: float.
     """
-    # Sanity checks
+    # Sanity checks
     if alpha <= 0 or alpha >= 1:
         raise ValueError("alpha must be in (0,1)")

@@ -49,19 +49,19 @@ def weighted_bonferroni(alpha: float, weights: np.ndarray) -> np.ndarray:
 
 def weighted_bonferroni(alpha: float, weights: np.ndarray) -> np.ndarray:
     """Weighted Bonferroni correction for multiple comparisons.
 
     :param float alpha: nominal coverage level.
     :param np.ndarray weights: weights associated to each output feature.
 
     :returns: array of corrected featurewise coverage levels.
     :rtype: np.ndarray.
     """
-    # Sanity checks
+    # Sanity checks
     if alpha <= 0 or alpha >= 1:
         raise ValueError("alpha must be in (0,1)")
 
-    # Positivity check
+    # Positivity check
     positiveness_condition = np.all(weights > 0)
     if not positiveness_condition:
         raise RuntimeError("weights must be positive")
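Both helpers split a global miscoverage budget `alpha` across output dimensions. A quick numeric sketch (the weighted values assume the correction distributes `alpha` proportionally to the normalized weights):

```python
import numpy as np

from deel.puncc.api.corrections import bonferroni, weighted_bonferroni

# Plain Bonferroni: each of the 5 outputs gets alpha / nvars.
print(bonferroni(alpha=0.1, nvars=5))  # 0.02

# Weighted variant: featurewise levels proportional to the weights.
w = np.array([1.0, 1.0, 2.0])
print(weighted_bonferroni(alpha=0.1, weights=w))  # expected: [0.025, 0.025, 0.05]
```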
6 changes: 4 additions & 2 deletions deel/puncc/api/utils.py
@@ -145,7 +145,9 @@ def supported_types_check(*data: Iterable):
     )
 
 
-def alpha_calib_check(alpha: Union[float, np.ndarray], n: int, complement_check: bool = False):
+def alpha_calib_check(
+    alpha: Union[float, np.ndarray], n: int, complement_check: bool = False
+):
     """Check if the value of alpha :math:`\\alpha` is consistent with the size
     of calibration set :math:`n`.
@@ -181,7 +183,7 @@ def alpha_calib_check(
 
     if np.any(alpha >= 1):
         raise ValueError(
-            f"Alpha={alpha} is too large. "
+            f"Alpha={alpha} is too large. "
             + "Decrease alpha such that all of its coordinates are < 1."
         )

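For intuition about the check: with `n` calibration points, the conformal quantile is only defined when `alpha >= 1/(n+1)`, which is what `alpha_calib_check` guards. A sketch with illustrative numbers:

```python
from deel.puncc.api.utils import alpha_calib_check

# With n = 9 calibration points, the smallest admissible miscoverage level
# is 1 / (n + 1) = 0.1: the empirical (1 - alpha) quantile of 10 values
# must exist.
alpha_calib_check(alpha=0.1, n=9)     # passes silently
# alpha_calib_check(alpha=0.05, n=9)  # raises: alpha < 1/(n+1)
```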
1 change: 0 additions & 1 deletion tests/api/test_utils.py
@@ -159,7 +159,6 @@ def test_too_low_alpha_multidim(self):
         alpha_calib_check(alpha, n=10)
 
 
-
 class QuantileCheck(unittest.TestCase):
     def test_simple_quantile1d(self):
         self.a_np = np.array([1, 2, 3, 4])
