Skip to content

Commit

Permalink
Add new functions for sslearn.restricted including Kuncheva et al. 2024
Browse files Browse the repository at this point in the history
  • Loading branch information
jlgarridol committed May 8, 2024
1 parent 858bbf6 commit a7dff8b
Show file tree
Hide file tree
Showing 3 changed files with 203 additions and 4 deletions.
5 changes: 4 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [1.0.5] - XXXX-XX-XX
## [1.0.5] - 2024-05-08

### Added
- `feature_fusion` and `probability_fusion` functions for restricted set classification with pairwise constraints in the `sslearn.restricted` module.

### Fixed
- CoForest random integer is now compatible with Windows.
Expand Down
177 changes: 175 additions & 2 deletions sslearn/restricted.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,84 @@
from sklearn.base import ClassifierMixin, MetaEstimatorMixin, BaseEstimator
from scipy.optimize import linear_sum_assignment
import warnings
from sklearn.base import clone
from sklearn.feature_selection import VarianceThreshold
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from .base import get_dataset
import pandas as pd
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis

__all__ = ["conflict_rate", "combine_predictions", "WhoIsWhoClassifier", "feature_fusion", "probability_fusion"]


def feature_fusion(classifier, X, must_link, cannot_link):
    """
    Restricted Set Classification with pairwise constraints, fusing features.

    Every set of instances joined by a must-link constraint has its feature
    vectors replaced by their average before the classifier is queried. The
    resulting class probabilities are then resolved per cannot-link group.

    **Parameters**
    ----------
    classifier : ClassifierMixin with predict_proba method
        Its `predict_proba` method and `classes_` attribute are used.
    X : {array-like, sparse matrix} of shape (n_samples, n_features)
        Array representing the data.
    must_link : dict of {int: list of int}
        Dictionary with the must links, where the key is the instance and the value is a list of instances that must have the same label.
    cannot_link : dict of {int: list of int}
        Dictionary with the cannot links, where the value is a list of instances that cannot have the same label.

    **Returns**
    ----------
    y : list of length n_samples
        Predicted labels.

    **References**
    ----------
    L.I. Kuncheva, J.L. Garrido-Labrador, I. Ramos-Pérez, S.L. Hennessey, J.J. Rodríguez (2024).<br>
    Semi-supervised classification with pairwise constraints: A case study on animal identification from video.<br>
    <i>Information Fusion,</i><br>
    104, 102188, [10.1016/j.inffus.2023.102188](https://doi.org/10.1016/j.inffus.2023.102188)
    """
    fused_features = __combine_features(X, must_link)
    return __restricted_set_classification(
        classifier.predict_proba(fused_features),
        cannot_link,
        classifier.classes_,
    )


def probability_fusion(classifier, X, must_link, cannot_link):
    """
    Restricted Set Classification with pairwise constraints, fusing probabilities.

    The classifier is queried on the untouched instances; afterwards the class
    probabilities of instances joined by a must-link constraint are replaced
    by their mean. The fused probabilities are then resolved per cannot-link
    group.

    **Parameters**
    ----------
    classifier : ClassifierMixin with predict_proba method
        Its `predict_proba` method and `classes_` attribute are used.
    X : {array-like, sparse matrix} of shape (n_samples, n_features)
        Array representing the data.
    must_link : dict of {int: list of int}
        Dictionary with the must links, where the key is the instance and the value is a list of instances that must have the same label.
    cannot_link : dict of {int: list of int}
        Dictionary with the cannot links, where the value is a list of instances that cannot have the same label.

    **Returns**
    ----------
    y : list of length n_samples
        Predicted labels.

    **References**
    ----------
    L.I. Kuncheva, J.L. Garrido-Labrador, I. Ramos-Pérez, S.L. Hennessey, J.J. Rodríguez (2024).<br>
    Semi-supervised classification with pairwise constraints: A case study on animal identification from video.<br>
    <i>Information Fusion,</i><br>
    104, 102188, [10.1016/j.inffus.2023.102188](https://doi.org/10.1016/j.inffus.2023.102188)
    """
    raw_probs = classifier.predict_proba(X)
    labels_seen = classifier.classes_
    # Only the fused probability matrix is needed; the argmax labels that the
    # helper also returns are discarded.
    fused_probs = __combine_probabilities(raw_probs, must_link, labels_seen)[0]
    return __restricted_set_classification(fused_probs, cannot_link, labels_seen)

# Public API of the module. This assignment is the last one executed, so it
# must list every public name, including the pairwise-constraint helpers.
__all__ = ["conflict_rate", "combine_predictions", "WhoIsWhoClassifier", "feature_fusion", "probability_fusion"]

class WhoIsWhoClassifier(BaseEstimator, ClassifierMixin, MetaEstimatorMixin):

Expand Down Expand Up @@ -205,4 +280,102 @@ def _hungarian(probas_matrix):
_, col_ind = linear_sum_assignment(costs, maximize=True)
col_ind = list(col_ind)

return col_ind
return col_ind

def __combine_probabilities(y_probs, objects_in_track, classes):
    """
    Averages the classifier probabilities of the instances in the same track.

    The input matrix is left untouched; the averaging is done on a copy.
    Tracks with zero or one instance are skipped.

    :param y_probs: classifier probabilities for the instances
    :param objects_in_track: dictionary with the tracks
    :param classes: classes used to train the classifier
    :return: a tuple with the modified y_probs and the predicted classes
    """
    fused = y_probs.copy()

    shared_tracks = (track for track in objects_in_track.values() if len(track) > 1)
    for track in shared_tracks:
        # Every member of the track receives the track's mean probability row.
        fused[track, :] = fused[track, :].mean(axis=0)

    winners = np.argmax(fused, axis=1)
    return fused, classes.take(list(winners))

def __combine_features(X, objects_in_track):
    """
    Averages the features of the instances in the same track.

    The caller's data is never modified: averaging happens on a copy.
    Integer and boolean arrays are converted to float first so that the
    means are not truncated when written back.

    :param X: feature values of the instances (NumPy array or pandas DataFrame).
    :param objects_in_track: dictionary with the tracks
    :return: a modified copy of X with averaged features; a DataFrame input
        is returned as a NumPy array
    """
    if isinstance(X, pd.DataFrame):
        # ``X.values`` can be a view on the frame's data, so writing into it
        # would silently change the caller's DataFrame. Ask for a real copy.
        X_combined = X.to_numpy(copy=True)
    else:
        X_combined = X.copy()

    # Writing float means into an integer/bool array would truncate them.
    if isinstance(X_combined, np.ndarray) and X_combined.dtype.kind in "biu":
        X_combined = X_combined.astype(float)

    for objects in objects_in_track.values():
        if len(objects) <= 1:
            continue
        X_combined[objects] = X_combined[objects].mean(axis=0)
    return X_combined

def __restricted_set_classification(y_probs, instances_by_frame, classes):
    """
    Restricted Set Classification for the instances in several frames.

    Each frame (cannot-link group) is resolved independently and the label
    of every instance is stored at that instance's own row index, so groups
    do not have to be contiguous index ranges nor listed in index order.

    :param y_probs: the probabilities given by the classifier for the instances
    :param instances_by_frame: which instances (row indices of y_probs) are in each frame
    :param classes: the classes seen by the classifier
    :return: list with the predicted label of every instance, ordered by row index
    :raises AssertionError: if some instance is missing from the frames or
        appears in more than one of them
    """
    y_probs = np.asarray(y_probs)
    n_instances = len(y_probs)

    restricted_pred = [None] * n_instances
    times_assigned = [0] * n_instances

    for group in instances_by_frame.values():
        if len(group) == 0:
            continue
        rows = list(group)
        # Select exactly the rows of this group (not the first..last span).
        _, group_pred = __restricted_set_hungarian(y_probs[rows], classes)
        for row, label in zip(rows, group_pred):
            restricted_pred[row] = label
            times_assigned[row] += 1

    # Raised explicitly (instead of ``assert``) so the check survives ``-O``;
    # the exception type is kept for callers that already catch it.
    if any(count != 1 for count in times_assigned):
        raise AssertionError(
            "The number of predictions is different from the number of instances, check cannot link constraints, all instances must be in a cannot-link group."
        )

    return restricted_pred

def __restricted_set_hungarian(probs, classes):
    """
    Restricted Set Classification for a set of objects that have to be of different classes.

    When the plain argmax labels are already all different, or when there are
    more objects than classes (so distinct labels are impossible), the argmax
    labels are returned. Otherwise the assignment that maximises the sum of
    log-probabilities is computed with the Hungarian method.

    The input matrix is never modified.

    :param probs: the probabilities given by the classifier, one row per object
    :param classes: the classes seen by the classifier
    :return: a tuple with 1) whether the Hungarian method was used (0 or 1), and 2) the predicted classes
    """
    rows, cols = probs.shape
    preds = list(np.argmax(probs, axis=1))

    if rows > cols or len(preds) == len(set(preds)):
        return 0, classes.take(preds)

    # Zero probabilities are expected here; log(0) = -inf is handled below,
    # so the divide-by-zero warning is silenced.
    with np.errstate(divide="ignore"):
        costs = np.log(probs)

    try:
        _, col_ind = linear_sum_assignment(costs, maximize=True)
    except ValueError:
        # -inf entries left no feasible assignment. Retry with the smallest
        # positive double added, working on a new array so the caller's
        # probabilities stay untouched.
        costs = np.log(probs + np.nextafter(0, 1))
        _, col_ind = linear_sum_assignment(costs, maximize=True)

    return 1, classes.take(list(col_ind))
25 changes: 24 additions & 1 deletion test/test_general.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), ".."))
from sslearn.base import FakedProbaClassifier, OneVsRestSSLClassifier
from sslearn.restricted import (WhoIsWhoClassifier, combine_predictions,
conflict_rate)
conflict_rate, probability_fusion, feature_fusion)
from sslearn.utils import (calc_number_per_class, calculate_prior_probability,
check_n_jobs, choice_with_proportion,
confidence_interval, is_int, safe_division)
Expand Down Expand Up @@ -129,4 +129,27 @@ def test_WhoIsWhoClassifier(self):
assert hyp.conflict_in_train == 1
assert hyp.predict(X, group).tolist() == [0, 1, 0, 2, 1]

def test_probability_fusion(self):
    """probability_fusion labels every instance and honours the cannot-link groups."""
    X = np.array([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]])
    y = np.array([0, 1, 0, 1, 2])
    cannot_link = {0: [0, 1], 1: [2, 3, 4]}
    must_link = {1: [1, 3], 0: [0, 2], 4: [4]}

    h = GaussianNB()
    h.fit(X, y)

    result = probability_fusion(h, X, must_link=must_link, cannot_link=cannot_link)

    assert len(result) == len(X)
    # No group is larger than the number of classes, so the labels inside
    # each cannot-link group must all be different.
    for group in cannot_link.values():
        labels = [result[i] for i in group]
        assert len(set(labels)) == len(labels)

def test_feature_fusion(self):
    """feature_fusion labels every instance and honours the cannot-link groups."""
    X = np.array([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]])
    y = np.array([0, 1, 0, 1, 2])
    cannot_link = {0: [0, 1], 1: [2, 3, 4]}
    must_link = {1: [1, 3], 0: [0, 2], 4: [4]}

    h = GaussianNB()
    h.fit(X, y)

    result = feature_fusion(h, X, must_link=must_link, cannot_link=cannot_link)

    assert len(result) == len(X)
    # No group is larger than the number of classes, so the labels inside
    # each cannot-link group must all be different.
    for group in cannot_link.values():
        labels = [result[i] for i in group]
        assert len(set(labels)) == len(labels)



0 comments on commit a7dff8b

Please sign in to comment.