Skip to content

Commit

Permalink
Conformal Classification with LAC (#60)
Browse files Browse the repository at this point in the history
* Add LAC classification method

* correct lac implementation

* Correct LAC implementation

* add lac test

* correct lac test

* Add LAC to theory overview

* Correct linting errors

* doc: add theory overview of CAD and update readme to include implemented CP methods

---------

Co-authored-by: M-Mouhcine <[email protected]>
  • Loading branch information
jdalch and M-Mouhcine authored Nov 5, 2024
1 parent 2babf1f commit 8c00b6c
Show file tree
Hide file tree
Showing 11 changed files with 436 additions and 69 deletions.
28 changes: 26 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ y_pred, y_pred_lower, y_pred_upper = split_cp.predict(X_test, alpha=0.1)
```


The library provides several metrics (`deel.puncc.metrics`) and plotting capabilities (`deel.puncc.plotting`) to evaluate and visualize the results of a conformal procedure. For a target error rate of $\alpha = 0.1$, the marginal coverage reached in this example on the test set is higher than $90$% (see [Introduction tutorial](docs/puncc_intro.ipynb)):
The library provides several metrics (`deel.puncc.metrics`) and plotting capabilities (`deel.puncc.plotting`) to evaluate and visualize the results of a conformal procedure. For a target error rate of $\alpha = 0.1$, the marginal coverage reached in this example on the test set is higher than $90$% (see [**Introduction tutorial**](docs/puncc_intro.ipynb)):
<div align="center">
<figure style="text-align:center">
<img src="docs/assets/results_quickstart_split_cp_pi.png" alt="90% Prediction Interval with the Split Conformal Prediction Method" width="70%"/>
Expand All @@ -138,7 +138,31 @@ The library provides several metrics (`deel.puncc.metrics`) and plotting capabil
- A direct approach to run state-of-the-art conformal prediction procedures. This is what we used in the previous conformal regression example.
- **Low-level API**: a more flexible approach based of full customization of the prediction model, the choice of nonconformity scores and the split between fit and calibration datasets.

A quick comparison of both approaches is provided in the [API tutorial](docs/api_intro.ipynb) for a regression problem.
A quick comparison of both approaches is provided in the [**API tutorial**](docs/api_intro.ipynb) for a regression problem.

<figure style="text-align:center">
<img src="docs/assets/puncc_architecture.png" width="100%"/>
</figure>

### 🖥️ Implemented Algorithms
<details>
<summary>Overview of Implemented Methods from the Literature:</summary>

| Procedure Type | Procedure Name | Description (more details in [Theory overview](https://deel-ai.github.io/puncc/theory_overview.html)) |
|-----------------------------------------|------------------------------------------------------|-------------------------------------------------------|
| Conformal Regression | [`deel.puncc.regression.SplitCP`](https://deel-ai.github.io/puncc/regression.html#deel.puncc.regression.SplitCP) | Split Conformal Regression |
| Conformal Regression | [`deel.puncc.regression.LocallyAdaptiveCP`](https://deel-ai.github.io/puncc/regression.html#deel.puncc.regression.LocallyAdaptiveCP) | Locally Adaptive Conformal Regression |
| Conformal Regression | [`deel.puncc.regression.CQR`](https://deel-ai.github.io/puncc/regression.html#deel.puncc.regression.CQR) | Conformalized Quantile Regression |
| Conformal Regression | [`deel.puncc.regression.CvPlus`](https://deel-ai.github.io/puncc/regression.html#deel.puncc.regression.CVPlus) | CV + (cross-validation) |
| Conformal Regression | [`deel.puncc.regression.EnbPI`](https://deel-ai.github.io/puncc/regression.html#deel.puncc.regression.EnbPI) | Ensemble Batch Prediction Intervals method |
| Conformal Regression | [`deel.puncc.regression.aEnbPI`](https://deel-ai.github.io/puncc/regression.html#deel.puncc.regression.AdaptiveEnbPI) | Locally adaptive Ensemble Batch Prediction Intervals method |
| Conformal Classification | [`deel.puncc.classification.LAC`](https://deel-ai.github.io/puncc/classification.html#deel.puncc.classification.LAC) | Least Ambiguous Set-Valued Classifiers |
| Conformal Classification | [`deel.puncc.classification.APS`](https://deel-ai.github.io/puncc/classification.html#deel.puncc.classification.APS) | Adaptive Prediction Sets |
| Conformal Classification | [`deel.puncc.classification.RAPS`](https://deel-ai.github.io/puncc/classification.html#deel.puncc.classification.RAPS) | Regularized Adaptive Prediction Sets (APS is a special case where $\lambda = 0$) |
| Conformal Anomaly Detection | [`deel.puncc.anomaly_detection.SplitCAD`](https://deel-ai.github.io/puncc/anomaly_detection.html#deel.puncc.anomaly_detection.SplitCAD) | Split Conformal Anomaly detection (used to control the maximum false positive rate) |
| Conformal Object Detection | [`deel.puncc.object_detection.SplitBoxWise`](https://deel-ai.github.io/puncc/object_detection.html#deel.puncc.object_detection.SplitBoxWise) | Box-wise split conformal object detection |

</details>

## 📚 Citation

Expand Down
30 changes: 30 additions & 0 deletions deel/puncc/api/nonconformity_scores.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,36 @@
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Classification ~~~~~~~~~~~~~~~~~~~~~~~~~~~~


def lac_score(
Y_pred: Iterable,
y_true: Iterable,
) -> Iterable:
"""LAC nonconformity score.
:param Iterable Y_pred:
:math:`Y_{\\text{pred}} = (P_{\\text{C}_1}, ..., P_{\\text{C}_n})`
where :math:`P_{\\text{C}_i}` is logit associated to class i.
:param Iterable y_true: true labels.
:returns: RAPS nonconformity scores.
:rtype: Iterable
:raises TypeError: unsupported data types.
"""
supported_types_check(Y_pred, y_true)

# Check if logits sum is close to one
logit_normalization_check(Y_pred)

if not isinstance(Y_pred, np.ndarray):
raise NotImplementedError(
"LAC nonconformity score only implemented for ndarrays"
)

# Compute and return the LAC nonconformity score
return 1 - Y_pred[np.arange(y_true.shape[0]), y_true]


def raps_score(
Y_pred: Iterable,
y_true: Iterable,
Expand Down
32 changes: 32 additions & 0 deletions deel/puncc/api/prediction_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,38 @@
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Classification ~~~~~~~~~~~~~~~~~~~~~~~~~~~~


def lac_set(
Y_pred, scores_quantile
) -> List:
"""LAC prediction set.
:param Iterable Y_pred:
:math:`Y_{\\text{pred}} = (P_{\\text{C}_1}, ..., P_{\\text{C}_n})`
where :math:`P_{\\text{C}_i}` is logit associated to class i.
:param ndarray scores_quantile: quantile of nonconformity scores computed
on a calibration set for a given :math:`\\alpha`
:returns: LAC prediction sets.
:rtype: Iterable
"""
# Check if logits sum is close to one
logit_normalization_check(Y_pred)

pred_len = len(Y_pred)

logger.debug(f"Shape of Y_pred: {Y_pred.shape}")

# Build prediction sets
prediction_sets = [
np.where(Y_pred[i] >= 1 - scores_quantile)[0].tolist() for i in range(pred_len)
]

return (prediction_sets,)


def raps_set(
Y_pred, scores_quantile, lambd: float = 0, k_reg: int = 1, rand: bool = True
) -> List:
Expand Down
133 changes: 130 additions & 3 deletions deel/puncc/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,133 @@
from deel.puncc.api.prediction import BasePredictor
from deel.puncc.api.splitting import IdSplitter
from deel.puncc.api.splitting import RandomSplitter
from deel.puncc.regression import SplitCP


class LAC(SplitCP):
"""Implementation of the Least Ambiguous Set-Valued Classifier (LAC).
For more details, we refer the user to the
:ref:`theory overview page <theory lac>`.
:param BasePredictor predictor: a predictor implementing fit and predict.
:param bool train: if False, prediction model(s) will not be trained and
will be used as is. Defaults to True.
.. _example lac:
Example::
from deel.puncc.classification import LAC
from deel.puncc.api.prediction import BasePredictor
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from deel.puncc.metrics import classification_mean_coverage
from deel.puncc.metrics import classification_mean_size
import numpy as np
from tensorflow.keras.utils import to_categorical
# Generate a random regression problem
X, y = make_classification(n_samples=1000, n_features=4, n_informative=2,
n_classes = 2,random_state=0, shuffle=False)
# Split data into train and test
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=.2, random_state=0
)
# Split train data into fit and calibration
X_fit, X_calib, y_fit, y_calib = train_test_split(
X_train, y_train, test_size=.2, random_state=0
)
# One hot encoding of classes
y_fit_cat = to_categorical(y_fit)
y_calib_cat = to_categorical(y_calib)
y_test_cat = to_categorical(y_test)
# Create rf classifier
rf_model = RandomForestClassifier(n_estimators=100, random_state=0)
# Create a wrapper of the random forest model to redefine its predict method
# into logits predictions. Make sure to subclass BasePredictor.
# Note that we needed to build a new wrapper (over BasePredictor) only because
# the predict(.) method of RandomForestClassifier does not predict logits.
# Otherwise, it is enough to use BasePredictor (e.g., neural network with softmax).
class RFPredictor(BasePredictor):
def predict(self, X, **kwargs):
return self.model.predict_proba(X, **kwargs)
# Wrap model in the newly created RFPredictor
rf_predictor = RFPredictor(rf_model)
# CP method initialization
lac_cp = LAC(rf_predictor)
# The call to `fit` trains the model and computes the nonconformity
# scores on the calibration set
lac_cp.fit(X_fit=X_fit, y_fit=y_fit, X_calib=X_calib, y_calib=y_calib)
# The predict method infers prediction sets with respect to
# the significance level alpha = 20%
y_pred, set_pred = lac_cp.predict(X_test, alpha=.2)
# Compute marginal coverage
coverage = classification_mean_coverage(y_test, set_pred)
size = classification_mean_size(set_pred)
print(f"Marginal coverage: {np.round(coverage, 2)}")
print(f"Average prediction set size: {np.round(size, 2)}")
"""

def __init__(
self,
predictor: Union[BasePredictor, Any],
train: bool = True,
random_state: float = None,
):
super().__init__(
predictor=predictor,
train=train,
random_state=random_state,
)
self.calibrator = BaseCalibrator(
nonconf_score_func=nonconformity_scores.lac_score,
pred_set_func=prediction_sets.lac_set,
weight_func=None,
)
self.conformal_predictor = ConformalPredictor(
predictor=self.predictor,
calibrator=self.calibrator,
splitter=object(),
train=self.train,
)

def predict(self, X_test: Iterable, alpha: float) -> Tuple:
"""Conformal set predictions (w.r.t target miscoverage alpha)
for new samples.
:param Iterable X_test: features of new samples.
:param float alpha: target maximum miscoverage.
:returns: Tuple composed of the model estimate y_pred and the
prediction set set_pred
:rtype: Tuple
"""

if self.conformal_predictor is None:
raise RuntimeError("Fit method should be called before predict.")

(y_pred, set_pred) = self.conformal_predictor.predict(
X_test, alpha=alpha
)

return y_pred, set_pred


class RAPS:
Expand Down Expand Up @@ -128,7 +255,7 @@ def predict(self, X, **kwargs):
raps_cp.fit(X_fit=X_fit, y_fit=y_fit, X_calib=X_calib, y_calib=y_calib)
# The predict method infers prediction intervals with respect to
# The predict method infers prediction sets with respect to
# the significance level alpha = 20%
y_pred, set_pred = raps_cp.predict(X_test, alpha=.2)
Expand Down Expand Up @@ -252,7 +379,7 @@ def fit(
self.conformal_predictor.fit(X=X, y=y, **kwargs)

def predict(self, X_test: Iterable, alpha: float) -> Tuple:
"""Conformal interval predictions (w.r.t target miscoverage alpha)
"""Conformal set predictions (w.r.t target miscoverage alpha)
for new samples.
:param Iterable X_test: features of new samples.
Expand Down Expand Up @@ -344,7 +471,7 @@ def predict(self, X, **kwargs):
# scores on the calibration set
aps_cp.(X_fit=X_fit, y_fit=y_fit, X_calib=X_calib, y_calib=y_calib)
# The predict method infers prediction intervals with respect to
# The predict method infers prediction sets with respect to
# the significance level alpha = 20%
y_pred, set_pred = aps_cp.predict(X_test, alpha=.2)
Expand Down
Binary file modified docs/assets/puncc_architecture.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading

0 comments on commit 8c00b6c

Please sign in to comment.