diff --git a/CHANGELOG.md b/CHANGELOG.md
index 7a8037e..63b9fc9 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -4,7 +4,10 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
-## [1.0.5] - XXXX-XX-XX
+## [1.0.5] - 2024-05-08
+
+### Added
+- `feature_fusion` and `probability_fusion` methods for restricted in `sslearn.restricted` module.
### Fixed
- CoForest random integer is now compatible with Windows.
diff --git a/sslearn/restricted.py b/sslearn/restricted.py
index 011be28..befb89c 100644
--- a/sslearn/restricted.py
+++ b/sslearn/restricted.py
@@ -21,9 +21,84 @@
from sklearn.base import ClassifierMixin, MetaEstimatorMixin, BaseEstimator
from scipy.optimize import linear_sum_assignment
import warnings
+from sklearn.base import clone
+from sklearn.feature_selection import VarianceThreshold
+from sklearn.pipeline import make_pipeline
+from sklearn.preprocessing import StandardScaler
+from .base import get_dataset
import pandas as pd
+from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
+
+__all__ = ["conflict_rate", "combine_predictions", "WhoIsWhoClassifier", "feature_fusion", "probability_fusion"]
+
+
+def feature_fusion(classifier, X, must_link, cannot_link):
+ """
+ Restricted Set Classification for the instances with pairwise constraints.
+ Combine all instances that have the must-link constraint with the average of their features.
+
+ **Parameters**
+ ----------
+ classifier : ClassifierMixin with predict_proba method
+ X : {array-like, sparse matrix} of shape (n_samples, n_features)
+ Array representing the data.
+ must_link : dict of {int: list of int}
+ Dictionary with the must links, where the key is the instance and the value is a list of instances that must have the same label.
+ cannot_link : dict of {int: list of int}
+ Dictionary with the cannot links, where the value is a list of instances that cannot have the same label.
+
+ **Returns**
+ ----------
+ y : ndarray of shape (n_samples,)
+ Array with predicted labels.
+
+ **References**
+ ----------
+ L.I. Kuncheva, J.L. Garrido-Labrador, I. Ramos-Pérez, S.L. Hennessey, J.J. Rodríguez (2024).
+ Semi-supervised classification with pairwise constraints: A case study on animal identification from video.
+ Information Fusion,
+ 104, 102188, [10.1016/j.inffus.2023.102188](https://doi.org/10.1016/j.inffus.2023.102188)
+ """
+
+ X_combined = __combine_features(X, must_link)
+ y_pred_proba = classifier.predict_proba(X_combined)
+
+ return __restricted_set_classification(y_pred_proba, cannot_link, classifier.classes_)
+
+
+def probability_fusion(classifier, X, must_link, cannot_link):
+ """
+ Restricted Set Classification for the instances with pairwise constraints.
+ The class probability for each instance is defined as the mean of the probabilities reported by the classifier according to the must-link constraint.
+
+ **Parameters**
+ ----------
+ classifier : ClassifierMixin with predict_proba method
+ X : {array-like, sparse matrix} of shape (n_samples, n_features)
+ Array representing the data.
+ must_link : dict of {int: list of int}
+ Dictionary with the must links, where the key is the instance and the value is a list of instances that must have the same label.
+ cannot_link : dict of {int: list of int}
+ Dictionary with the cannot links, where the value is a list of instances that cannot have the same label.
+
+ **Returns**
+ ----------
+ y : ndarray of shape (n_samples,)
+ Array with predicted labels.
+
+ **References**
+ ----------
+ L.I. Kuncheva, J.L. Garrido-Labrador, I. Ramos-Pérez, S.L. Hennessey, J.J. Rodríguez (2024).
+ Semi-supervised classification with pairwise constraints: A case study on animal identification from video.
+ Information Fusion,
+ 104, 102188, [10.1016/j.inffus.2023.102188](https://doi.org/10.1016/j.inffus.2023.102188)
+ """
+
+ y_probs = classifier.predict_proba(X)
+ classes = classifier.classes_
+ y_probs_combined, _ = __combine_probabilities(y_probs, must_link, classes)
+ return __restricted_set_classification(y_probs_combined, cannot_link, classes)
-__all__ = ["conflict_rate", "combine_predictions", "WhoIsWhoClassifier"]
class WhoIsWhoClassifier(BaseEstimator, ClassifierMixin, MetaEstimatorMixin):
@@ -205,4 +280,102 @@ def _hungarian(probas_matrix):
_, col_ind = linear_sum_assignment(costs, maximize=True)
col_ind = list(col_ind)
- return col_ind
\ No newline at end of file
+ return col_ind
+
+def __combine_probabilities(y_probs, objects_in_track, classes):
+ """"
+ Averages the classifier probabilities of the instances in the same track.
+
+ :param y_probs: classifier probabilities for the instances
+ :param objects_in_track: dictionary with the tracks
+ :param classes: classes used to train the classifier
+
+ :return: a tuple with the modified y_probs ans the predicted classes
+ """
+
+ y_probs_combined = y_probs.copy()
+
+ for objects in objects_in_track.values():
+ if len(objects) <= 1:
+ continue
+ means = y_probs_combined[objects, :].mean(axis=0)
+ y_probs_combined[objects, :] = means
+
+ preds = classes.take(list(np.argmax(y_probs_combined, axis=1)))
+ return y_probs_combined, preds
+
+def __combine_features(X, objects_in_track):
+ """
+ Averages the features of the instances in the same track.
+
+ :param X: feature values of the instances.
+ :param objects_in_track: dictionary with the tracks
+
+ :return: a modified X with averaged features
+ """
+
+ X_combined = X.copy()
+ is_df = isinstance(X, pd.DataFrame)
+ if is_df:
+ X_combined = X.values
+ for objects in objects_in_track.values():
+ if len(objects) <= 1:
+ continue
+ means = X_combined[objects].mean(axis=0)
+ X_combined[objects] = means
+ return X_combined
+
+def __restricted_set_classification(y_probs, instances_by_frame, classes):
+ """
+ Restricted Set Classification for the instances in several frames
+
+ :param y_probs: the probabilities given by the classifier for the instances
+ :param instances_by_frame: which instances are in each frame
+ :param classes: the classes seen by the classifier
+
+ :return: the predicted labels
+ """
+
+ restricted_pred = []
+ num_conflicts = 0
+ for fr, group in instances_by_frame.items():
+ if len(group) == 0:
+ continue
+ first, last = group[0], group[-1]
+ group_probs = y_probs[first:last + 1]
+ conflict, group_pred = __restricted_set_hungarian(group_probs, classes)
+ restricted_pred.extend(group_pred)
+ num_conflicts += conflict
+
+ assert len(restricted_pred) == len(y_probs), "The number of predictions is different from the number of instances, check cannot link constraints, all instances must be in a cannot-link group."
+
+ return restricted_pred
+
+def __restricted_set_hungarian(probs, classes):
+ """
+ Restricted Set Classification for a set of objects that have to be of different classes
+
+ :param probs: the probabilities given by the classifier
+ :param classes: the classes seen by the classifier
+
+ :return: a tuple with 1) the Hungarian method was used (0 or 1), and 2) the predicted classes
+ """
+
+ rows, cols = probs.shape
+ preds = list(np.argmax(probs, axis=1))
+
+ if rows > cols or len(preds) == len(set(preds)):
+ # return 0 if rows > cols else 1, classes.take(preds)
+ return 0, classes.take(preds)
+ costs = np.log(probs)
+
+ try:
+ row_ind, col_ind = linear_sum_assignment(costs, maximize=True)
+ col_ind = list(col_ind)
+ except: # some of the values was -Inf
+ probs += np.nextafter(0, 1) # small double value
+ costs = np.log(probs)
+ row_ind, col_ind = linear_sum_assignment(costs, maximize=True)
+ col_ind = list(col_ind)
+
+ return 1, classes.take(col_ind)
diff --git a/test/test_general.py b/test/test_general.py
index ed94462..2c2fa94 100644
--- a/test/test_general.py
+++ b/test/test_general.py
@@ -15,7 +15,7 @@
sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), ".."))
from sslearn.base import FakedProbaClassifier, OneVsRestSSLClassifier
from sslearn.restricted import (WhoIsWhoClassifier, combine_predictions,
- conflict_rate)
+ conflict_rate, probability_fusion, feature_fusion)
from sslearn.utils import (calc_number_per_class, calculate_prior_probability,
check_n_jobs, choice_with_proportion,
confidence_interval, is_int, safe_division)
@@ -129,4 +129,27 @@ def test_WhoIsWhoClassifier(self):
assert hyp.conflict_in_train == 1
assert hyp.predict(X, group).tolist() == [0, 1, 0, 2, 1]
+ def test_probability_fusion(self):
+ X = np.array([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]])
+ y = np.array([0, 1, 0, 1, 2])
+ cannot_link = {0: [0, 1], 1: [2, 3, 4]}
+ must_link = {1: [1, 3], 0: [0, 2], 4: [4]}
+
+ h = GaussianNB()
+ h.fit(X, y)
+
+ probability_fusion(h, X, must_link=must_link, cannot_link=cannot_link)
+
+ def test_feature_fusion(self):
+ X = np.array([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]])
+ y = np.array([0, 1, 0, 1, 2])
+ cannot_link = {0: [0, 1], 1: [2, 3, 4]}
+ must_link = {1: [1, 3], 0: [0, 2], 4: [4]}
+
+ h = GaussianNB()
+ h.fit(X, y)
+
+ result = feature_fusion(h, X, must_link=must_link, cannot_link=cannot_link)
+
+