From a7dff8bec738b05c51cb9c9fd5992784a4486cab Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jos=C3=A9=20Luis=20Garrido-Labrador?=
Date: Wed, 8 May 2024 13:57:42 +0200
Subject: [PATCH] Add new functions for sslearn.restricted including Kuncheva
 et al. 2024

---
 CHANGELOG.md          |   5 +-
 sslearn/restricted.py | 177 +++++++++++++++++++++++++++++++++++++++++-
 test/test_general.py  |  25 +++++-
 3 files changed, 203 insertions(+), 4 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 7a8037e..63b9fc9 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -4,7 +4,10 @@ All notable changes to this project will be documented in this file.
 The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
 
-## [1.0.5] - XXXX-XX-XX
+## [1.0.5] - 2024-05-08
+
+### Added
+- `feature_fusion` and `probability_fusion` methods for restricted-set classification in the `sslearn.restricted` module.
 
 ### Fixed
 - CoForest random integer is now compatible with Windows.
diff --git a/sslearn/restricted.py b/sslearn/restricted.py
index 011be28..befb89c 100644
--- a/sslearn/restricted.py
+++ b/sslearn/restricted.py
@@ -21,9 +21,84 @@ from sklearn.base import ClassifierMixin, MetaEstimatorMixin, BaseEstimator
 from scipy.optimize import linear_sum_assignment
 import warnings
+from sklearn.base import clone
+from sklearn.feature_selection import VarianceThreshold
+from sklearn.pipeline import make_pipeline
+from sklearn.preprocessing import StandardScaler
+from .base import get_dataset
 import pandas as pd
+from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
+
+__all__ = ["conflict_rate", "combine_predictions", "WhoIsWhoClassifier", "feature_fusion", "probability_fusion"]
+
+
+def feature_fusion(classifier, X, must_link, cannot_link):
+    """
+    Restricted Set Classification for instances with pairwise constraints.
+    Combines all instances that share a must-link constraint by averaging their features.
+
+    **Parameters**
+    ----------
+    classifier : ClassifierMixin with predict_proba method
+    X : {array-like, sparse matrix} of shape (n_samples, n_features)
+        Array representing the data.
+    must_link : dict of {int: list of int}
+        Dictionary with the must links, where the key is the instance and the value is a list of instances that must have the same label.
+    cannot_link : dict of {int: list of int}
+        Dictionary with the cannot links, where the key is the frame and the value is a list of instances that cannot have the same label.
+
+    **Returns**
+    ----------
+    y : ndarray of shape (n_samples,)
+        Array with predicted labels.
+
+    **References**
+    ----------
+    L.I. Kuncheva, J.L. Garrido-Labrador, I. Ramos-Pérez, S.L. Hennessey, J.J. Rodríguez (2024).<br>
+    Semi-supervised classification with pairwise constraints: A case study on animal identification from video.<br>
+    Information Fusion,<br>
+    104, 102188, [10.1016/j.inffus.2023.102188](https://doi.org/10.1016/j.inffus.2023.102188)
+    """
+
+    X_combined = __combine_features(X, must_link)
+    y_pred_proba = classifier.predict_proba(X_combined)
+
+    return __restricted_set_classification(y_pred_proba, cannot_link, classifier.classes_)
+
+
+def probability_fusion(classifier, X, must_link, cannot_link):
+    """
+    Restricted Set Classification for instances with pairwise constraints.
+    The class probability for each instance is the mean of the probabilities reported by the classifier over its must-link group.
+
+    **Parameters**
+    ----------
+    classifier : ClassifierMixin with predict_proba method
+    X : {array-like, sparse matrix} of shape (n_samples, n_features)
+        Array representing the data.
+    must_link : dict of {int: list of int}
+        Dictionary with the must links, where the key is the instance and the value is a list of instances that must have the same label.
+    cannot_link : dict of {int: list of int}
+        Dictionary with the cannot links, where the key is the frame and the value is a list of instances that cannot have the same label.
+
+    **Returns**
+    ----------
+    y : ndarray of shape (n_samples,)
+        Array with predicted labels.
+
+    **References**
+    ----------
+    L.I. Kuncheva, J.L. Garrido-Labrador, I. Ramos-Pérez, S.L. Hennessey, J.J. Rodríguez (2024).<br>
+    Semi-supervised classification with pairwise constraints: A case study on animal identification from video.<br>
+    Information Fusion,<br>
+    104, 102188, [10.1016/j.inffus.2023.102188](https://doi.org/10.1016/j.inffus.2023.102188)
+    """
+
+    y_probs = classifier.predict_proba(X)
+    classes = classifier.classes_
+    y_probs_combined, _ = __combine_probabilities(y_probs, must_link, classes)
+    return __restricted_set_classification(y_probs_combined, cannot_link, classes)
 
 
-__all__ = ["conflict_rate", "combine_predictions", "WhoIsWhoClassifier"]
 
 
 class WhoIsWhoClassifier(BaseEstimator, ClassifierMixin, MetaEstimatorMixin):
@@ -205,4 +280,102 @@ def _hungarian(probas_matrix):
     _, col_ind = linear_sum_assignment(costs, maximize=True)
     col_ind = list(col_ind)
 
-    return col_ind
\ No newline at end of file
+    return col_ind
+
+def __combine_probabilities(y_probs, objects_in_track, classes):
+    """
+    Averages the classifier probabilities of the instances in the same track.
+
+    :param y_probs: classifier probabilities for the instances
+    :param objects_in_track: dictionary with the tracks
+    :param classes: classes used to train the classifier
+
+    :return: a tuple with the modified y_probs and the predicted classes
+    """
+
+    y_probs_combined = y_probs.copy()
+
+    for objects in objects_in_track.values():
+        if len(objects) <= 1:
+            continue
+        means = y_probs_combined[objects, :].mean(axis=0)
+        y_probs_combined[objects, :] = means
+
+    preds = classes.take(list(np.argmax(y_probs_combined, axis=1)))
+    return y_probs_combined, preds
+
+def __combine_features(X, objects_in_track):
+    """
+    Averages the features of the instances in the same track.
+
+    :param X: feature values of the instances
+    :param objects_in_track: dictionary with the tracks
+
+    :return: a modified X with averaged features
+    """
+
+    X_combined = X.copy()
+    is_df = isinstance(X, pd.DataFrame)
+    if is_df:
+        X_combined = X.values
+    for objects in objects_in_track.values():
+        if len(objects) <= 1:
+            continue
+        means = X_combined[objects].mean(axis=0)
+        X_combined[objects] = means
+    return X_combined
+
+def __restricted_set_classification(y_probs, instances_by_frame, classes):
+    """
+    Restricted Set Classification for the instances in several frames.
+
+    :param y_probs: the probabilities given by the classifier for the instances
+    :param instances_by_frame: which instances are in each frame
+    :param classes: the classes seen by the classifier
+
+    :return: the predicted labels
+    """
+
+    restricted_pred = []
+    num_conflicts = 0
+    for fr, group in instances_by_frame.items():
+        if len(group) == 0:
+            continue
+        first, last = group[0], group[-1]
+        group_probs = y_probs[first:last + 1]
+        conflict, group_pred = __restricted_set_hungarian(group_probs, classes)
+        restricted_pred.extend(group_pred)
+        num_conflicts += conflict
+
+    assert len(restricted_pred) == len(y_probs), "The number of predictions is different from the number of instances; check the cannot-link constraints: all instances must be in a cannot-link group."
+
+    return restricted_pred
+
+def __restricted_set_hungarian(probs, classes):
+    """
+    Restricted Set Classification for a set of objects that have to be of different classes.
+
+    :param probs: the probabilities given by the classifier
+    :param classes: the classes seen by the classifier
+
+    :return: a tuple with 1) whether the Hungarian method was used (0 or 1), and 2) the predicted classes
+    """
+
+    rows, cols = probs.shape
+    preds = list(np.argmax(probs, axis=1))
+
+    if rows > cols or len(preds) == len(set(preds)):
+        # return 0 if rows > cols else 1, classes.take(preds)
+        return 0, classes.take(preds)
+    costs = np.log(probs)
+
+    try:
+        row_ind, col_ind = linear_sum_assignment(costs, maximize=True)
+        col_ind = list(col_ind)
+    except:  # some of the values were -inf (log of a zero probability)
+        probs += np.nextafter(0, 1)  # add the smallest positive double to avoid log(0)
+        costs = np.log(probs)
+        row_ind, col_ind = linear_sum_assignment(costs, maximize=True)
+        col_ind = list(col_ind)
+
+    return 1, classes.take(col_ind)
diff --git a/test/test_general.py b/test/test_general.py
index ed94462..2c2fa94 100644
--- a/test/test_general.py
+++ b/test/test_general.py
@@ -15,7 +15,7 @@ sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), ".."))
 
 from sslearn.base import FakedProbaClassifier, OneVsRestSSLClassifier
 from sslearn.restricted import (WhoIsWhoClassifier, combine_predictions,
-                                conflict_rate)
+                                conflict_rate, probability_fusion, feature_fusion)
 from sslearn.utils import (calc_number_per_class, calculate_prior_probability,
                            check_n_jobs, choice_with_proportion, confidence_interval,
                            is_int, safe_division)
@@ -129,4 +129,27 @@ def test_WhoIsWhoClassifier(self):
         assert hyp.conflict_in_train == 1
         assert hyp.predict(X, group).tolist() == [0, 1, 0, 2, 1]
 
+    def test_probability_fusion(self):
+        X = np.array([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]])
+        y = np.array([0, 1, 0, 1, 2])
+        cannot_link = {0: [0, 1], 1: [2, 3, 4]}
+        must_link = {1: [1, 3], 0: [0, 2], 4: [4]}
+
+        h = GaussianNB()
+        h.fit(X, y)
+
+        probability_fusion(h, X, must_link=must_link, cannot_link=cannot_link)
+
+    def test_feature_fusion(self):
+        X = np.array([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]])
+        y = np.array([0, 1, 0, 1, 2])
+        cannot_link = {0: [0, 1], 1: [2, 3, 4]}
+        must_link = {1: [1, 3], 0: [0, 2], 4: [4]}
+
+        h = GaussianNB()
+        h.fit(X, y)
+
+        result = feature_fusion(h, X, must_link=must_link, cannot_link=cannot_link)
+
+
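
A minimal usage sketch for the two new functions, assuming the patch is applied and scikit-learn is installed. The GaussianNB base classifier and the toy data mirror the new tests in test/test_general.py and are only illustrative; any fitted classifier that exposes predict_proba is expected to work the same way.

    import numpy as np
    from sklearn.naive_bayes import GaussianNB
    from sslearn.restricted import feature_fusion, probability_fusion

    # Five instances with two features each; labels used to fit the base classifier.
    X = np.array([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]])
    y = np.array([0, 1, 0, 1, 2])

    # must_link: each key maps to instances forced to share a label (their features or
    # probabilities are averaged). cannot_link: each key (a frame) maps to instances
    # that must all receive different labels; every instance must appear in some
    # cannot-link group.
    must_link = {0: [0, 2], 1: [1, 3], 4: [4]}
    cannot_link = {0: [0, 1], 1: [2, 3, 4]}

    classifier = GaussianNB().fit(X, y)

    # Fuse at the feature level, then classify under the cannot-link restriction.
    y_feature = feature_fusion(classifier, X, must_link, cannot_link)

    # Or classify first and average the predicted probabilities per must-link group.
    y_probability = probability_fusion(classifier, X, must_link, cannot_link)

The two entry points differ only in where the fusion happens: feature_fusion averages the raw features of each must-link group before a single predict_proba call, while probability_fusion averages the per-instance predicted probabilities; both then assign labels frame by frame so that no two instances in a cannot-link group share a class.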