Skip to content

Commit

Permalink
Classification metric fix [type:bug] (#260)
Browse files Browse the repository at this point in the history
* reuse encoders

* ensure categorical encoder is trained on real and synthetic

* better transformer

* remove unnecessary imports

* better error message

* compatibility with DDIM

* ensure classification metrics are always used when the task is classification

* fix mlp arguments typo

* add task_type to fix wrongful evaluation

* formatting

* ensure KFold works for regression

* remove debugging statements

* specify task type
  • Loading branch information
bvanbreugel authored Apr 15, 2024
1 parent 44cd36d commit 943fa28
Show file tree
Hide file tree
Showing 15 changed files with 39 additions and 39 deletions.
46 changes: 21 additions & 25 deletions src/synthcity/metrics/eval_performance.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,10 +153,8 @@ def _evaluate_performance_regression(
@validate_arguments(config=dict(arbitrary_types_allowed=True))
def _evaluate_standard_performance(
self,
clf_model: Any,
clf_args: Dict,
regression_model: Any,
regression_args: Any,
model: Any,
model_args: Dict,
X_gt: DataLoader,
X_syn: DataLoader,
) -> Dict:
Expand All @@ -179,17 +177,13 @@ def _evaluate_standard_performance(
ood_X_gt, ood_y_gt = X_gt.test().unpack()
iter_X_syn, iter_y_syn = X_syn.unpack()

if len(id_y_gt.unique()) < 5:
if self._task_type == "classification":
eval_cbk = self._evaluate_performance_classification
skf = StratifiedKFold(
n_splits=self._n_folds, shuffle=True, random_state=self._random_state
)
model = clf_model
model_args = clf_args
else:
elif self._task_type == "regression":
eval_cbk = self._evaluate_performance_regression
model = regression_model
model_args = regression_args
skf = KFold(
n_splits=self._n_folds, shuffle=True, random_state=self._random_state
)
Expand Down Expand Up @@ -680,19 +674,21 @@ def evaluate(
X_syn,
)
elif self._task_type == "classification" or self._task_type == "regression":
xgb_clf_args = {
if self._task_type == "classification":
model = XGBClassifier
else:
model = XGBRegressor

model_args = {
"n_jobs": 2,
"verbosity": 0,
"depth": 3,
"random_state": self._random_state,
}

xgb_reg_args = copy.deepcopy(xgb_clf_args)
return self._evaluate_standard_performance(
XGBClassifier,
xgb_clf_args,
XGBRegressor,
xgb_reg_args,
model,
model_args,
X_gt,
X_syn,
)
Expand Down Expand Up @@ -743,10 +739,15 @@ def evaluate(
) -> Dict:
if self._task_type == "survival_analysis":
return self._evaluate_survival_model(CoxPHSurvivalAnalysis, {}, X_gt, X_syn)
elif self._task_type == "classification" or self._task_type == "regression":
elif self._task_type == "classification":
return self._evaluate_standard_performance(
LogisticRegression,
{"random_state": self._random_state},
X_gt,
X_syn,
)
elif self._task_type == "regression":
return self._evaluate_standard_performance(
LinearRegression,
{},
X_gt,
Expand Down Expand Up @@ -887,21 +888,16 @@ def evaluate(
if X_gt.type() == "images":
return self._evaluate_images(X_gt, X_syn)

mlp_args = {
model_args = {
"n_units_in": X_gt.shape[1] - 1,
"n_units_out": 1,
"random_state": self._random_state,
"task_type": self._task_type,
}
clf_args = copy.deepcopy(mlp_args)
clf_args["task_type"] = "classification"
reg_args = copy.deepcopy(mlp_args)
reg_args["task_type"] = "regression"

return self._evaluate_standard_performance(
MLP,
clf_args,
MLP,
reg_args,
model_args,
X_gt,
X_syn,
)
Expand Down
4 changes: 3 additions & 1 deletion tests/plugins/core/models/test_gan.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,9 @@ def _tracker(*args: Any, **kwargs: Any) -> float:
n_iter_print=20,
patience=2,
batch_size=len(X),
patience_metric=WeightedMetrics(metrics=[patience_metric], weights=[1]),
patience_metric=WeightedMetrics(
metrics=[patience_metric], weights=[1], task_type="regression"
),
generator_extra_penalty_cbks=[_tracker],
)
model.fit(X)
Expand Down
4 changes: 3 additions & 1 deletion tests/plugins/core/models/test_tabular_gan.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,9 @@ def test_gan_generation_with_early_stopping(patience_metric: Tuple[str, str]) ->
generator_n_iter=1000,
encoder_max_clusters=5,
patience=2,
patience_metric=WeightedMetrics(metrics=[patience_metric], weights=[1]),
patience_metric=WeightedMetrics(
metrics=[patience_metric], weights=[1], task_type="classification"
),
)
model.fit(X)

Expand Down
2 changes: 1 addition & 1 deletion tests/plugins/domain_adaptation/test_radialgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def test_eval_performance_radialgan() -> None:

for retry in range(2):
test_plugin = plugin(n_iter=500, batch_size=50)
evaluator = PerformanceEvaluatorXGB()
evaluator = PerformanceEvaluatorXGB(task_type="classification")

test_plugin.fit(X)
X_syn = test_plugin.generate()
Expand Down
2 changes: 1 addition & 1 deletion tests/plugins/generic/test_arf.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def test_eval_performance_arf(compress_dataset: bool) -> None:

for retry in range(2):
test_plugin = plugin(**plugin_args)
evaluator = PerformanceEvaluatorXGB()
evaluator = PerformanceEvaluatorXGB(task_type="classification")

test_plugin.fit(X)
X_syn = test_plugin.generate(count=100)
Expand Down
2 changes: 1 addition & 1 deletion tests/plugins/generic/test_ctgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def test_eval_performance_ctgan(compress_dataset: bool) -> None:
n_iter=5000,
compress_dataset=compress_dataset,
)
evaluator = PerformanceEvaluatorXGB()
evaluator = PerformanceEvaluatorXGB(task_type="classification")

test_plugin.fit(X)
X_syn = test_plugin.generate()
Expand Down
2 changes: 1 addition & 1 deletion tests/plugins/generic/test_ddpm.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def test_eval_performance_ddpm(compress_dataset: bool) -> None:

for _ in range(2):
test_plugin = plugin(**plugin_params, compress_dataset=compress_dataset)
evaluator = PerformanceEvaluatorXGB()
evaluator = PerformanceEvaluatorXGB(task_type="classification")

test_plugin.fit(X)
X_syn = test_plugin.generate()
Expand Down
2 changes: 1 addition & 1 deletion tests/plugins/generic/test_great.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ def test_eval_performance_great(compress_dataset: bool) -> None:

for retry in range(2):
test_plugin = plugin(**plugin_args)
evaluator = PerformanceEvaluatorXGB()
evaluator = PerformanceEvaluatorXGB(task_type="classification")

test_plugin.fit(X)
X_syn = test_plugin.generate(count=100, max_length=100)
Expand Down
2 changes: 1 addition & 1 deletion tests/plugins/generic/test_nflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def test_eval_performance_nflow(compress_dataset: bool) -> None:

for retry in range(2):
test_plugin = plugin(n_iter=5000, compress_dataset=compress_dataset)
evaluator = PerformanceEvaluatorXGB()
evaluator = PerformanceEvaluatorXGB(task_type="classification")

test_plugin.fit(X)
X_syn = test_plugin.generate()
Expand Down
2 changes: 1 addition & 1 deletion tests/plugins/generic/test_rtvae.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def test_eval_performance_rtvae() -> None:

for retry in range(2):
test_plugin = plugin(n_iter=1000)
evaluator = PerformanceEvaluatorXGB()
evaluator = PerformanceEvaluatorXGB(task_type="classification")

test_plugin.fit(X)
X_syn = test_plugin.generate()
Expand Down
2 changes: 1 addition & 1 deletion tests/plugins/generic/test_tvae.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def test_eval_performance_tvae() -> None:

for retry in range(2):
test_plugin = plugin(n_iter=1000)
evaluator = PerformanceEvaluatorXGB()
evaluator = PerformanceEvaluatorXGB(task_type="classification")

test_plugin.fit(X)
X_syn = test_plugin.generate()
Expand Down
2 changes: 1 addition & 1 deletion tests/plugins/privacy/test_adsgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def test_eval_performance(compress_dataset: bool) -> None:

for retry in range(2):
test_plugin = plugin(n_iter=5000, compress_dataset=compress_dataset)
evaluator = PerformanceEvaluatorXGB()
evaluator = PerformanceEvaluatorXGB(task_type="classification")

test_plugin.fit(X)
X_syn = test_plugin.generate()
Expand Down
2 changes: 1 addition & 1 deletion tests/plugins/privacy/test_aim.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def test_eval_performance_aim(compress_dataset: bool) -> None:

for retry in range(2):
test_plugin = plugin(**plugin_args)
evaluator = PerformanceEvaluatorXGB()
evaluator = PerformanceEvaluatorXGB(task_type="classification")

test_plugin.fit(X)
X_syn = test_plugin.generate(count=1000)
Expand Down
2 changes: 1 addition & 1 deletion tests/plugins/privacy/test_dpgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def test_eval_performance_dpgan() -> None:

for retry in range(2):
test_plugin = plugin(n_iter=300)
evaluator = PerformanceEvaluatorXGB()
evaluator = PerformanceEvaluatorXGB(task_type="classification")

test_plugin.fit(X)
X_syn = test_plugin.generate()
Expand Down
2 changes: 1 addition & 1 deletion tests/plugins/privacy/test_pategan.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def test_eval_performance() -> None:
test_plugin = plugin(
n_iter=200, generator_n_layers_hidden=1, n_teachers=2, lamda=2e-4
)
evaluator = PerformanceEvaluatorXGB()
evaluator = PerformanceEvaluatorXGB(task_type="classification")

test_plugin.fit(X)
X_syn = test_plugin.generate()
Expand Down

0 comments on commit 943fa28

Please sign in to comment.