Skip to content

Commit

Permalink
Repair docs
Browse files Browse the repository at this point in the history
  • Loading branch information
jlgarridol committed May 2, 2024
1 parent bbb5774 commit 8190c6c
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 37 deletions.
1 change: 0 additions & 1 deletion sslearn/wrapper/_co.py
Original file line number Diff line number Diff line change
Expand Up @@ -1102,7 +1102,6 @@ class CoTrainingByCommittee(BaseCoTraining):
Pisa, 2008, pp. 563-572, [10.1109/ICDMW.2008.27](https://doi.org/10.1109/ICDMW.2008.27)
"""


def __init__(
self,
ensemble_estimator=BaggingClassifier(),
Expand Down
2 changes: 1 addition & 1 deletion sslearn/wrapper/_self.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def __init__(
poolsize : float, optional
Max number of unlabel instances candidates to pseudolabel, by default 0.25
rejection_threshold : float, optional
significance level, by default 0.1
significance level, by default 0.05
graph_neighbors : int, optional
Number of neighbors for each sample., by default 1
random_state : int, RandomState instance, optional
Expand Down
70 changes: 35 additions & 35 deletions test/test_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from sslearn.model_selection import artificial_ssl_dataset
from sslearn.wrapper import (
CoTraining, CoForest, CoTrainingByCommittee, DemocraticCoLearning, Rasco, RelRasco,
SelfTraining, Setred, TriTraining, WiWTriTraining, DeTriTraining
SelfTraining, Setred, TriTraining, DeTriTraining
)

X, y = read_csv(os.path.join(os.path.dirname(os.path.realpath(__file__)), "example_files", "abalone.csv"), format="pandas")
Expand Down Expand Up @@ -238,40 +238,40 @@ def test_all_label(self):
groups = np.array(groups)
groups = groups[:X.shape[0]]

class TestWiWTriTraining:
# class TestWiWTriTraining:

def test_basic(self):
clf = WiWTriTraining(base_estimator=DecisionTreeClassifier())
clf.fit(X, y, instance_group=groups)
clf.predict(X, instance_group=groups)
clf.predict_proba(X)

clf = WiWTriTraining(DecisionTreeClassifier())
clf.fit(X2, y2, instance_group=groups)
clf.predict(X2, instance_group=groups)
clf.predict_proba(X2)

def test_multiple(self):
clf = WiWTriTraining(base_estimator=[DecisionTreeClassifier(max_depth=1), DecisionTreeClassifier(max_depth=2), DecisionTreeClassifier(max_depth=3)])
clf.fit(X, y, instance_group=groups)
clf.predict(X, instance_group=groups)
clf.predict_proba(X)
# def test_basic(self):
# clf = WiWTriTraining(base_estimator=DecisionTreeClassifier())
# clf.fit(X, y, instance_group=groups)
# clf.predict(X, instance_group=groups)
# clf.predict_proba(X)

# clf = WiWTriTraining(DecisionTreeClassifier())
# clf.fit(X2, y2, instance_group=groups)
# clf.predict(X2, instance_group=groups)
# clf.predict_proba(X2)

# def test_multiple(self):
# clf = WiWTriTraining(base_estimator=[DecisionTreeClassifier(max_depth=1), DecisionTreeClassifier(max_depth=2), DecisionTreeClassifier(max_depth=3)])
# clf.fit(X, y, instance_group=groups)
# clf.predict(X, instance_group=groups)
# clf.predict_proba(X)

def test_random_state(self):
for i in range(10):
clf = WiWTriTraining(base_estimator=KNeighborsClassifier(), random_state=i)
clf.fit(X, y, instance_group=groups)
y1 = clf.predict(X, instance_group=groups)

clf = WiWTriTraining(base_estimator=KNeighborsClassifier(), random_state=i)
clf.fit(X, y, instance_group=groups)
y2 = clf.predict(X, instance_group=groups)

assert np.all(y1 == y2)

def test_all_label(self):
clf = WiWTriTraining(base_estimator=KNeighborsClassifier())
clf.fit(X, y, instance_group=groups)
clf.predict(X, instance_group=groups)
clf.predict_proba(X)
# def test_random_state(self):
# for i in range(10):
# clf = WiWTriTraining(base_estimator=KNeighborsClassifier(), random_state=i)
# clf.fit(X, y, instance_group=groups)
# y1 = clf.predict(X, instance_group=groups)

# clf = WiWTriTraining(base_estimator=KNeighborsClassifier(), random_state=i)
# clf.fit(X, y, instance_group=groups)
# y2 = clf.predict(X, instance_group=groups)

# assert np.all(y1 == y2)

# def test_all_label(self):
# clf = WiWTriTraining(base_estimator=KNeighborsClassifier())
# clf.fit(X, y, instance_group=groups)
# clf.predict(X, instance_group=groups)
# clf.predict_proba(X)

0 comments on commit 8190c6c

Please sign in to comment.