-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #19 from jlgarridol/development
Version 1.0.5.1
- Loading branch information
Showing
13 changed files
with
4,461 additions
and
3,104 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -61,6 +61,7 @@ <h2>Submodules</h2> | |
<li><a href="sslearn/base.html">base</a></li> | ||
<li><a href="sslearn/datasets.html">datasets</a></li> | ||
<li><a href="sslearn/model_selection.html">model_selection</a></li> | ||
<li><a href="sslearn/restricted (Copia en conflicto de CROSS-PC 2024-05-14).html">restricted (Copia en conflicto de CROSS-PC 2024-05-14)</a></li> | ||
<li><a href="sslearn/restricted.html">restricted</a></li> | ||
<li><a href="sslearn/subview.html">subview</a></li> | ||
<li><a href="sslearn/utils.html">utils</a></li> | ||
|
@@ -162,7 +163,7 @@ <h2 id="citing">Citing</h2> | |
</span><span id="L-10"><a href="#L-10"><span class="linenos">10</span></a> <span class="vm">__doc__</span> <span class="o">=</span> <span class="s2">"Semi-Supervised Learning (SSL) is a Python package that provides tools to train and evaluate semi-supervised learning models."</span> | ||
</span><span id="L-11"><a href="#L-11"><span class="linenos">11</span></a> | ||
</span><span id="L-12"><a href="#L-12"><span class="linenos">12</span></a> | ||
</span><span id="L-13"><a href="#L-13"><span class="linenos">13</span></a><span class="n">__version__</span><span class="o">=</span><span class="s1">'1.0.5'</span> | ||
</span><span id="L-13"><a href="#L-13"><span class="linenos">13</span></a><span class="n">__version__</span><span class="o">=</span><span class="s1">'1.0.5.1'</span> | ||
</span><span id="L-14"><a href="#L-14"><span class="linenos">14</span></a><span class="n">__AUTHOR__</span><span class="o">=</span><span class="s2">"José Luis Garrido-Labrador"</span> <span class="c1"># Author of the package</span> | ||
</span><span id="L-15"><a href="#L-15"><span class="linenos">15</span></a><span class="n">__AUTHOR_EMAIL__</span><span class="o">=</span><span class="s2">"[email protected]"</span> <span class="c1"># Author's email</span> | ||
</span><span id="L-16"><a href="#L-16"><span class="linenos">16</span></a><span class="n">__URL__</span><span class="o">=</span><span class="s2">"https://pypi.org/project/sslearn/"</span> | ||
|
Large diffs are not rendered by default.
Oops, something went wrong.
985 changes: 985 additions & 0 deletions
985
docs/sslearn/restricted (Copia en conflicto de CROSS-PC 2024-05-14).html
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,7 +10,7 @@ | |
__doc__ = "Semi-Supervised Learning (SSL) is a Python package that provides tools to train and evaluate semi-supervised learning models." | ||
|
||
|
||
__version__='1.0.5' | ||
__version__='1.0.5.1' | ||
__AUTHOR__="José Luis Garrido-Labrador" # Author of the package | ||
__AUTHOR_EMAIL__="[email protected]" # Author's email | ||
__URL__="https://pypi.org/project/sslearn/" | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
208 changes: 208 additions & 0 deletions
208
sslearn/restricted (Copia en conflicto de CROSS-PC 2024-05-14).py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,208 @@ | ||
"""Summary of module `sslearn.restricted`: | ||
This module contains classes to train a classifier using the restricted set classification approach. | ||
## Classes | ||
[WhoIsWhoClassifier](#WhoIsWhoClassifier): | ||
> Who is Who Classifier | ||
## Functions | ||
[conflict_rate](#conflict_rate): | ||
> Compute the conflict rate of a prediction, given a set of restrictions. | ||
[combine_predictions](#combine_predictions): | ||
> Combine the predictions of a group of instances to keep the restrictions. | ||
""" | ||
|
||
import numpy as np | ||
from sklearn.base import ClassifierMixin, MetaEstimatorMixin, BaseEstimator | ||
from scipy.optimize import linear_sum_assignment | ||
import warnings | ||
import pandas as pd | ||
|
||
__all__ = ["conflict_rate", "combine_predictions", "WhoIsWhoClassifier"] | ||
|
||
class WhoIsWhoClassifier(BaseEstimator, ClassifierMixin, MetaEstimatorMixin): | ||
|
||
def __init__(self, base_estimator, method="hungarian", conflict_weighted=True): | ||
""" | ||
Who is Who Classifier | ||
Kuncheva, L. I., Rodriguez, J. J., & Jackson, A. S. (2017). | ||
Restricted set classification: Who is there?. <i>Pattern Recognition</i>, 63, 158-170. | ||
Parameters | ||
---------- | ||
base_estimator : ClassifierMixin | ||
The base estimator to be used for training. | ||
method : str, optional | ||
The method to use to assing class, it can be `greedy` to first-look or `hungarian` to use the Hungarian algorithm, by default "hungarian" | ||
conflict_weighted : bool, default=True | ||
Whether to weighted the confusion rate by the number of instances with the same group. | ||
""" | ||
allowed_methods = ["greedy", "hungarian"] | ||
self.base_estimator = base_estimator | ||
self.method = method | ||
if method not in allowed_methods: | ||
raise ValueError(f"method {self.method} not supported, use one of {allowed_methods}") | ||
self.conflict_weighted = conflict_weighted | ||
|
||
|
||
def fit(self, X, y, instance_group=None, **kwards): | ||
"""Fit the model according to the given training data. | ||
Parameters | ||
---------- | ||
X : {array-like, sparse matrix} of shape (n_samples, n_features) | ||
The input samples. | ||
y : array-like of shape (n_samples,) | ||
The target values. | ||
instance_group : array-like of shape (n_samples) | ||
The group. Two instances with the same label are not allowed to be in the same group. If None, group restriction will not be used in training. | ||
Returns | ||
------- | ||
self : object | ||
Returns self. | ||
""" | ||
self.base_estimator = self.base_estimator.fit(X, y, **kwards) | ||
self.classes_ = self.base_estimator.classes_ | ||
if instance_group is not None: | ||
self.conflict_in_train = conflict_rate(self.base_estimator.predict(X), instance_group, self.conflict_weighted) | ||
else: | ||
self.conflict_in_train = None | ||
return self | ||
|
||
def conflict_rate(self, X, instance_group): | ||
"""Calculate the conflict rate of the model. | ||
Parameters | ||
---------- | ||
X : {array-like, sparse matrix} of shape (n_samples, n_features) | ||
The input samples. | ||
instance_group : array-like of shape (n_samples) | ||
The group. Two instances with the same label are not allowed to be in the same group. | ||
Returns | ||
------- | ||
float | ||
The conflict rate. | ||
""" | ||
y_pred = self.base_estimator.predict(X) | ||
return conflict_rate(y_pred, instance_group, self.conflict_weighted) | ||
|
||
def predict(self, X, instance_group): | ||
"""Predict class for X. | ||
Parameters | ||
---------- | ||
X : {array-like, sparse matrix} of shape (n_samples, n_features) | ||
The input samples. | ||
**kwards : array-like of shape (n_samples) | ||
The group. Two instances with the same label are not allowed to be in the same group. | ||
Returns | ||
------- | ||
array-like of shape (n_samples, n_classes) | ||
The class probabilities of the input samples. | ||
""" | ||
|
||
y_prob = self.predict_proba(X) | ||
|
||
y_predicted = combine_predictions(y_prob, instance_group, len(self.classes_), self.method) | ||
|
||
return self.classes_.take(y_predicted) | ||
|
||
|
||
def predict_proba(self, X): | ||
"""Predict class probabilities for X. | ||
Parameters | ||
---------- | ||
X : {array-like, sparse matrix} of shape (n_samples, n_features) | ||
The input samples. | ||
Returns | ||
------- | ||
array-like of shape (n_samples, n_classes) | ||
The class probabilities of the input samples. | ||
""" | ||
return self.base_estimator.predict_proba(X) | ||
|
||
|
||
def conflict_rate(y_pred, restrictions, weighted=True): | ||
""" | ||
Computes the conflict rate of a prediction, given a set of restrictions. | ||
Parameters | ||
---------- | ||
y_pred : array-like of shape (n_samples,) | ||
Predicted target values. | ||
restrictions : array-like of shape (n_samples,) | ||
Restrictions for each sample. If two samples have the same restriction, they cannot have the same y. | ||
weighted : bool, default=True | ||
Whether to weighted the confusion rate by the number of instances with the same group. | ||
Returns | ||
------- | ||
conflict rate : float | ||
The conflict rate. | ||
""" | ||
|
||
# Check that y_pred and restrictions have the same length | ||
if len(y_pred) != len(restrictions): | ||
raise ValueError("y_pred and restrictions must have the same length.") | ||
|
||
restricted_df = pd.DataFrame({'y_pred': y_pred, 'restrictions': restrictions}) | ||
|
||
conflicted = restricted_df.groupby('restrictions').agg({'y_pred': lambda x: np.unique(x, return_counts=True)[1][np.unique(x, return_counts=True)[1]>1].sum()}) | ||
if weighted: | ||
return conflicted.sum().y_pred / len(y_pred) | ||
else: | ||
rcount = restricted_df.groupby('restrictions').count() | ||
return (conflicted.y_pred / rcount.y_pred).sum() | ||
|
||
def combine_predictions(y_probas, instance_group, class_number, method="hungarian"): | ||
y_predicted = [] | ||
for group in np.unique(instance_group): | ||
|
||
mask = instance_group == group | ||
probas_matrix = y_probas[mask] | ||
|
||
|
||
preds = list(np.argmax(probas_matrix, axis=1)) | ||
|
||
if len(preds) == len(set(preds)) or probas_matrix.shape[0] > class_number: | ||
y_predicted.extend(preds) | ||
if probas_matrix.shape[0] > class_number: | ||
warnings.warn("That the number of instances in the group is greater than the number of classes.", UserWarning) | ||
continue | ||
|
||
if method == "greedy": | ||
y = _greedy(probas_matrix) | ||
elif method == "hungarian": | ||
y = _hungarian(probas_matrix) | ||
|
||
y_predicted.extend(y) | ||
return y_predicted | ||
|
||
def _greedy(probas_matrix): | ||
|
||
probas = probas_matrix.reshape(probas_matrix.size,) | ||
order = probas.argsort()[::-1] | ||
|
||
y_pred_group = [None for i in range(probas_matrix.shape[0])] | ||
|
||
instance_to_predict = {i for i in range(probas_matrix.shape[0])} | ||
class_predicted = set() | ||
for item in order: | ||
class_ = item % probas_matrix.shape[0] | ||
instance = item // probas_matrix.shape[0] | ||
if instance in instance_to_predict and class_ not in class_predicted: | ||
y_pred_group[instance] = class_ | ||
instance_to_predict.remove(instance) | ||
class_predicted.add(class_) | ||
|
||
return y_pred_group | ||
|
||
|
||
def _hungarian(probas_matrix): | ||
|
||
costs = np.log(probas_matrix) | ||
costs[costs == -np.inf] = 0 # if proba is 0, then the cost is 0 | ||
_, col_ind = linear_sum_assignment(costs, maximize=True) | ||
col_ind = list(col_ind) | ||
|
||
return col_ind |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters