diff --git a/sslearn/wrapper/_co.py b/sslearn/wrapper/_co.py index 0e7c2eb..a01d67d 100644 --- a/sslearn/wrapper/_co.py +++ b/sslearn/wrapper/_co.py @@ -1102,7 +1102,6 @@ class CoTrainingByCommittee(BaseCoTraining): Pisa, 2008, pp. 563-572, [10.1109/ICDMW.2008.27](https://doi.org/10.1109/ICDMW.2008.27) """ - def __init__( self, ensemble_estimator=BaggingClassifier(), diff --git a/sslearn/wrapper/_self.py b/sslearn/wrapper/_self.py index f58d5eb..e18799d 100644 --- a/sslearn/wrapper/_self.py +++ b/sslearn/wrapper/_self.py @@ -197,7 +197,7 @@ def __init__( poolsize : float, optional Max number of unlabel instances candidates to pseudolabel, by default 0.25 rejection_threshold : float, optional - significance level, by default 0.1 + significance level, by default 0.05 graph_neighbors : int, optional Number of neighbors for each sample., by default 1 random_state : int, RandomState instance, optional diff --git a/test/test_wrapper.py b/test/test_wrapper.py index 85d7b11..a14092e 100644 --- a/test/test_wrapper.py +++ b/test/test_wrapper.py @@ -16,7 +16,7 @@ from sslearn.model_selection import artificial_ssl_dataset from sslearn.wrapper import ( CoTraining, CoForest, CoTrainingByCommittee, DemocraticCoLearning, Rasco, RelRasco, - SelfTraining, Setred, TriTraining, WiWTriTraining, DeTriTraining + SelfTraining, Setred, TriTraining, DeTriTraining ) X, y = read_csv(os.path.join(os.path.dirname(os.path.realpath(__file__)), "example_files", "abalone.csv"), format="pandas") @@ -238,40 +238,40 @@ def test_all_label(self): groups = np.array(groups) groups = groups[:X.shape[0]] -class TestWiWTriTraining: +# class TestWiWTriTraining: - def test_basic(self): - clf = WiWTriTraining(base_estimator=DecisionTreeClassifier()) - clf.fit(X, y, instance_group=groups) - clf.predict(X, instance_group=groups) - clf.predict_proba(X) - - clf = WiWTriTraining(DecisionTreeClassifier()) - clf.fit(X2, y2, instance_group=groups) - clf.predict(X2, instance_group=groups) - clf.predict_proba(X2) - - def test_multiple(self): - clf = WiWTriTraining(base_estimator=[DecisionTreeClassifier(max_depth=1), DecisionTreeClassifier(max_depth=2), DecisionTreeClassifier(max_depth=3)]) - clf.fit(X, y, instance_group=groups) - clf.predict(X, instance_group=groups) - clf.predict_proba(X) +# def test_basic(self): +# clf = WiWTriTraining(base_estimator=DecisionTreeClassifier()) +# clf.fit(X, y, instance_group=groups) +# clf.predict(X, instance_group=groups) +# clf.predict_proba(X) + +# clf = WiWTriTraining(DecisionTreeClassifier()) +# clf.fit(X2, y2, instance_group=groups) +# clf.predict(X2, instance_group=groups) +# clf.predict_proba(X2) + +# def test_multiple(self): +# clf = WiWTriTraining(base_estimator=[DecisionTreeClassifier(max_depth=1), DecisionTreeClassifier(max_depth=2), DecisionTreeClassifier(max_depth=3)]) +# clf.fit(X, y, instance_group=groups) +# clf.predict(X, instance_group=groups) +# clf.predict_proba(X) - def test_random_state(self): - for i in range(10): - clf = WiWTriTraining(base_estimator=KNeighborsClassifier(), random_state=i) - clf.fit(X, y, instance_group=groups) - y1 = clf.predict(X, instance_group=groups) - - clf = WiWTriTraining(base_estimator=KNeighborsClassifier(), random_state=i) - clf.fit(X, y, instance_group=groups) - y2 = clf.predict(X, instance_group=groups) - - assert np.all(y1 == y2) - - def test_all_label(self): - clf = WiWTriTraining(base_estimator=KNeighborsClassifier()) - clf.fit(X, y, instance_group=groups) - clf.predict(X, instance_group=groups) - clf.predict_proba(X) +# def test_random_state(self): +# for i in range(10): +# clf = WiWTriTraining(base_estimator=KNeighborsClassifier(), random_state=i) +# clf.fit(X, y, instance_group=groups) +# y1 = clf.predict(X, instance_group=groups) + +# clf = WiWTriTraining(base_estimator=KNeighborsClassifier(), random_state=i) +# clf.fit(X, y, instance_group=groups) +# y2 = clf.predict(X, instance_group=groups) + +# assert np.all(y1 == y2) + +# def test_all_label(self): +# clf = WiWTriTraining(base_estimator=KNeighborsClassifier()) +# clf.fit(X, y, instance_group=groups) +# clf.predict(X, instance_group=groups) +# clf.predict_proba(X)