Commit d536d2e · migrate pydantic (#295)

* migrate pydantic
* abstract n_folds
* tell bandit to ignore torch save/load warnings
* Update notebook tests

Authored by robsdavis on Oct 1, 2024 · parent a856d6b
Showing 16 changed files with 883 additions and 315 deletions.
.github/workflows/test_tutorials.yml (2 changes: 1 addition & 1 deletion)

@@ -40,4 +40,4 @@ jobs:
           python -m pip install ipykernel
           python -m ipykernel install --user
       - name: Run the tutorials
-        run: python tests/nb_eval.py --nb_dir tutorials/ --tutorial_tests minimal_tests
+        run: python tests/nb_eval.py --nb_dir tutorials/ --tutorial_tests minimal_tests --timeout 3600
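The tutorial runner now gets an explicit timeout of 3600 seconds. `tests/nb_eval.py` is the project's own script, so the exact semantics of the flag are not visible in this diff; if it wraps nbconvert in the usual way, the value presumably feeds a per-cell execution budget, roughly like this hypothetical sketch (the notebook path is invented):

```python
# Hypothetical core of a notebook-evaluation script: execute one notebook
# with a timeout, using the standard nbconvert API. Not the project's code.
import nbformat
from nbconvert.preprocessors import ExecutePreprocessor

nb = nbformat.read("tutorials/example.ipynb", as_version=4)
ep = ExecutePreprocessor(timeout=3600)  # seconds allowed per cell
ep.preprocess(nb)
```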
setup.cfg (2 changes: 1 addition & 1 deletion)

@@ -47,7 +47,7 @@ install_requires =
     tenacity
     tqdm
     loguru
-    pydantic<2.0
+    pydantic
     cloudpickle
     scipy
     xgboost<3.0.0
src/synthcity/benchmark/__init__.py (4 changes: 4 additions & 0 deletions)

@@ -57,6 +57,7 @@ def evaluate(
         strict_augmentation: bool = False,
         ad_hoc_augment_vals: Optional[Dict] = None,
         use_metric_cache: bool = True,
+        n_eval_folds: int = 5,
         **generate_kwargs: Any,
     ) -> pd.DataFrame:
         """Benchmark the performance of several algorithms.
@@ -102,6 +103,8 @@ def evaluate(
                 A dictionary containing the number of each class to augment the real data with. This is only required if using the rule="ad-hoc" option. Defaults to None.
             use_metric_cache: bool
                 If the current metric has been previously run and is cached, it will be reused for the experiments. Defaults to True.
+            n_eval_folds: int
+                the KFolds used by MetricEvaluators in the benchmarks. Defaults to 5.
             plugin_kwargs:
                 Optional kwargs for each algorithm. Example {"adsgan": {"n_iter": 10}},
         """
@@ -295,6 +298,7 @@ def evaluate(
                 task_type=task_type,
                 workspace=workspace,
                 use_cache=use_metric_cache,
+                n_folds=n_eval_folds,
             )

     mean_score = evaluation["mean"].to_dict()
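For context, here is how the new knob could be exercised end to end. This is a minimal sketch under assumed API details: the dataset, the `marginal_distributions` plugin choice, and the tuple layout of the first argument are illustrative, not part of this diff.

```python
from sklearn.datasets import load_diabetes

from synthcity.benchmark import Benchmarks
from synthcity.plugins.core.dataloader import GenericDataLoader

X, y = load_diabetes(return_X_y=True, as_frame=True)
X["target"] = y
loader = GenericDataLoader(X, target_column="target")

# n_eval_folds is forwarded to Metrics.evaluate as n_folds (see eval.py below).
score = Benchmarks.evaluate(
    [("experiment_1", "marginal_distributions", {})],  # (name, plugin, plugin kwargs)
    loader,
    n_eval_folds=3,
)
```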
src/synthcity/metrics/_utils.py (6 changes: 3 additions & 3 deletions)

@@ -332,7 +332,7 @@ def f() -> None:
                 "epoch": epoch,
             },
             workspace / "DomiasMIA_bnaf_checkpoint.pt",
-        )
+        ) # nosec B614

     return f

@@ -348,7 +348,7 @@ def f() -> None:

     log.info("Loading model..")
     if (workspace / "checkpoint.pt").exists():
-        checkpoint = torch.load(workspace / "checkpoint.pt")
+        checkpoint = torch.load(workspace / "checkpoint.pt")  # nosec B614
         model.load_state_dict(checkpoint["model"])
         optimizer.load_state_dict(checkpoint["optimizer"])

@@ -453,7 +453,7 @@ def train(
                 "epoch": epoch,
             },
             workspace / "checkpoint.pt",
-        )
+        ) # nosec B614
     log.debug(
         f"""
 ###### Stop training after {epoch + 1} epochs!
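Bandit's B614 check flags `torch.save` and `torch.load` because PyTorch checkpoints are pickle-based, and unpickling an untrusted file can execute arbitrary code; the `# nosec B614` markers tell bandit these are trusted checkpoints the library itself just wrote. The pattern in isolation (file name and payload invented for illustration):

```python
import torch

state = {"weights": torch.zeros(3), "epoch": 0}

# Both calls trip bandit's B614 (pickle-based serialization); the inline
# nosec comments suppress the warning for this trusted local file.
torch.save(state, "checkpoint.pt")  # nosec B614
restored = torch.load("checkpoint.pt")  # nosec B614
```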
src/synthcity/metrics/eval.py (4 changes: 4 additions & 0 deletions)

@@ -119,6 +119,7 @@ def evaluate(
         random_state: int = 0,
         workspace: Path = Path("workspace"),
         use_cache: bool = True,
+        n_folds: int = 5,
     ) -> pd.DataFrame:
         """Core evaluation logic for the metrics
@@ -238,6 +239,7 @@
                 random_state=random_state,
                 workspace=workspace,
                 use_cache=use_cache,
+                n_folds=n_folds,
             ),
             X_gt,
             X_augmented,
@@ -251,6 +253,7 @@
                 random_state=random_state,
                 workspace=workspace,
                 use_cache=use_cache,
+                n_folds=n_folds,
             ),
             X_gt,
             X_syn,
@@ -267,6 +270,7 @@
                 random_state=random_state,
                 workspace=workspace,
                 use_cache=use_cache,
+                n_folds=n_folds,
             ),
             X_gt.sample(eval_cnt),
             X_syn.sample(eval_cnt),
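With this change, `n_folds` reaches every `MetricEvaluator` the function builds, so the cross-validation used inside the individual metrics is configurable from one place instead of being hard-coded. Calling the metrics layer directly might look like the following sketch; the import paths, `GenericDataLoader` usage, and toy frames are assumptions based on this file, not code from the diff:

```python
import pandas as pd

from synthcity.metrics.eval import Metrics
from synthcity.plugins.core.dataloader import GenericDataLoader

df = pd.DataFrame({"a": range(100), "b": range(100)})
real, synth = GenericDataLoader(df), GenericDataLoader(df.sample(frac=1.0))

# Every MetricEvaluator constructed inside evaluate() now receives n_folds.
report = Metrics.evaluate(real, synth, n_folds=3)
```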
src/synthcity/plugins/core/constraints.py (10 changes: 6 additions & 4 deletions)

@@ -4,11 +4,13 @@

 # third party
 import numpy as np
 import pandas as pd
-from pydantic import BaseModel, validate_arguments, validator
+from pydantic import BaseModel, field_validator, validate_arguments

 # synthcity absolute
 import synthcity.logger as log

+Rule = Tuple[str, str, Any]  # Define a type alias for clarity
+

 class Constraints(BaseModel):
     """
@@ -41,10 +43,10 @@ class Constraints(BaseModel):
     and thresh is the threshold or data type.
     """

-    rules: list = []
+    rules: list[Rule] = []

-    @validator("rules")
-    def _validate_rules(cls: Any, rules: List, values: dict, **kwargs: Any) -> List:
+    @field_validator("rules", mode="before")
+    def _validate_rules(cls: Any, rules: List) -> List:
         supported_ops: list = [
             "<",
             ">=",
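This file shows the pydantic v1 to v2 migration pattern in miniature: `validator` becomes `field_validator`, the `values`/`**kwargs` parameters drop out of the validator signature, and `mode="before"` is the v2 spelling of v1's `pre=True` (the hook runs before field coercion, which matters now that `rules` is typed as `list[Rule]`). A self-contained sketch of the same pattern; note the operator list is truncated in this diff, so everything past `"<"` and `">="` below is an assumption, and `RuleSet` is a simplified stand-in, not the project's `Constraints` class:

```python
from typing import Any, List, Tuple

from pydantic import BaseModel, field_validator

Rule = Tuple[str, str, Any]  # (feature, op, threshold)


class RuleSet(BaseModel):  # simplified stand-in for Constraints
    rules: List[Rule] = []

    @field_validator("rules", mode="before")
    def _validate_rules(cls, rules: List) -> List:
        # Runs before pydantic coerces each entry to Tuple[str, str, Any].
        supported_ops = ["<", ">="]  # the real list continues past this diff
        for feature, op, thresh in rules:
            if op not in supported_ops:
                raise ValueError(f"unsupported op: {op}")
        return rules


rs = RuleSet(rules=[("age", ">=", 18)])  # e.g. keep rows where age >= 18
```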
(diff for the remaining 10 changed files not shown)
