From d536d2ef02d378eb80b96ad9fd13c6eb9da705b5 Mon Sep 17 00:00:00 2001 From: Rob <62107751+robsdavis@users.noreply.github.com> Date: Tue, 1 Oct 2024 12:58:16 +0100 Subject: [PATCH] migrate pydantic (#295) * migrate pydantic * abstract n_folds * tell bandit to ignore torch save/load warnings * Update notebook tests --- .github/workflows/test_tutorials.yml | 2 +- setup.cfg | 2 +- src/synthcity/benchmark/__init__.py | 4 + src/synthcity/metrics/_utils.py | 6 +- src/synthcity/metrics/eval.py | 4 + src/synthcity/plugins/core/constraints.py | 10 +- src/synthcity/plugins/core/distribution.py | 654 ++++++++++++++---- .../plugins/core/models/tabular_encoder.py | 13 +- src/synthcity/plugins/core/plugin.py | 12 +- src/synthcity/plugins/core/schema.py | 366 ++++++---- tests/benchmarks/test_benchmarks.py | 9 +- tests/metrics/test_api.py | 10 + tests/nb_eval.py | 45 +- tests/plugins/core/test_distribution.py | 20 +- tests/plugins/core/test_schema.py | 39 +- .../plugins/time_series/plugin_timegan.ipynb | 2 +- 16 files changed, 883 insertions(+), 315 deletions(-) diff --git a/.github/workflows/test_tutorials.yml b/.github/workflows/test_tutorials.yml index 1b1a6587..ef1cfbd6 100644 --- a/.github/workflows/test_tutorials.yml +++ b/.github/workflows/test_tutorials.yml @@ -40,4 +40,4 @@ jobs: python -m pip install ipykernel python -m ipykernel install --user - name: Run the tutorials - run: python tests/nb_eval.py --nb_dir tutorials/ --tutorial_tests minimal_tests + run: python tests/nb_eval.py --nb_dir tutorials/ --tutorial_tests minimal_tests --timeout 3600 diff --git a/setup.cfg b/setup.cfg index 62c88501..1aaba156 100644 --- a/setup.cfg +++ b/setup.cfg @@ -47,7 +47,7 @@ install_requires = tenacity tqdm loguru - pydantic<2.0 + pydantic cloudpickle scipy xgboost<3.0.0 diff --git a/src/synthcity/benchmark/__init__.py b/src/synthcity/benchmark/__init__.py index 844882f1..1e583b98 100644 --- a/src/synthcity/benchmark/__init__.py +++ b/src/synthcity/benchmark/__init__.py @@ -57,6 +57,7 @@ def evaluate( strict_augmentation: bool = False, ad_hoc_augment_vals: Optional[Dict] = None, use_metric_cache: bool = True, + n_eval_folds: int = 5, **generate_kwargs: Any, ) -> pd.DataFrame: """Benchmark the performance of several algorithms. @@ -102,6 +103,8 @@ def evaluate( A dictionary containing the number of each class to augment the real data with. This is only required if using the rule="ad-hoc" option. Defaults to None. use_metric_cache: bool If the current metric has been previously run and is cached, it will be reused for the experiments. Defaults to True. + n_eval_folds: int + the KFolds used by MetricEvaluators in the benchmarks. Defaults to 5. plugin_kwargs: Optional kwargs for each algorithm. Example {"adsgan": {"n_iter": 10}}, """ @@ -295,6 +298,7 @@ def evaluate( task_type=task_type, workspace=workspace, use_cache=use_metric_cache, + n_folds=n_eval_folds, ) mean_score = evaluation["mean"].to_dict() diff --git a/src/synthcity/metrics/_utils.py b/src/synthcity/metrics/_utils.py index 6aecf048..7e1cd77b 100644 --- a/src/synthcity/metrics/_utils.py +++ b/src/synthcity/metrics/_utils.py @@ -332,7 +332,7 @@ def f() -> None: "epoch": epoch, }, workspace / "DomiasMIA_bnaf_checkpoint.pt", - ) + ) # nosec B614 return f @@ -348,7 +348,7 @@ def f() -> None: log.info("Loading model..") if (workspace / "checkpoint.pt").exists(): - checkpoint = torch.load(workspace / "checkpoint.pt") + checkpoint = torch.load(workspace / "checkpoint.pt") # nosec B614 model.load_state_dict(checkpoint["model"]) optimizer.load_state_dict(checkpoint["optimizer"]) @@ -453,7 +453,7 @@ def train( "epoch": epoch, }, workspace / "checkpoint.pt", - ) + ) # nosec B614 log.debug( f""" ###### Stop training after {epoch + 1} epochs! diff --git a/src/synthcity/metrics/eval.py b/src/synthcity/metrics/eval.py index c6d0fbd3..416aa989 100644 --- a/src/synthcity/metrics/eval.py +++ b/src/synthcity/metrics/eval.py @@ -119,6 +119,7 @@ def evaluate( random_state: int = 0, workspace: Path = Path("workspace"), use_cache: bool = True, + n_folds: int = 5, ) -> pd.DataFrame: """Core evaluation logic for the metrics @@ -238,6 +239,7 @@ def evaluate( random_state=random_state, workspace=workspace, use_cache=use_cache, + n_folds=n_folds, ), X_gt, X_augmented, @@ -251,6 +253,7 @@ def evaluate( random_state=random_state, workspace=workspace, use_cache=use_cache, + n_folds=n_folds, ), X_gt, X_syn, @@ -267,6 +270,7 @@ def evaluate( random_state=random_state, workspace=workspace, use_cache=use_cache, + n_folds=n_folds, ), X_gt.sample(eval_cnt), X_syn.sample(eval_cnt), diff --git a/src/synthcity/plugins/core/constraints.py b/src/synthcity/plugins/core/constraints.py index dc79e56b..693e4144 100644 --- a/src/synthcity/plugins/core/constraints.py +++ b/src/synthcity/plugins/core/constraints.py @@ -4,11 +4,13 @@ # third party import numpy as np import pandas as pd -from pydantic import BaseModel, validate_arguments, validator +from pydantic import BaseModel, field_validator, validate_arguments # synthcity absolute import synthcity.logger as log +Rule = Tuple[str, str, Any] # Define a type alias for clarity + class Constraints(BaseModel): """ @@ -41,10 +43,10 @@ class Constraints(BaseModel): and thresh is the threshold or data type. """ - rules: list = [] + rules: list[Rule] = [] - @validator("rules") - def _validate_rules(cls: Any, rules: List, values: dict, **kwargs: Any) -> List: + @field_validator("rules", mode="before") + def _validate_rules(cls: Any, rules: List) -> List: supported_ops: list = [ "<", ">=", diff --git a/src/synthcity/plugins/core/distribution.py b/src/synthcity/plugins/core/distribution.py index 788db4e0..8f5febc8 100644 --- a/src/synthcity/plugins/core/distribution.py +++ b/src/synthcity/plugins/core/distribution.py @@ -1,23 +1,33 @@ # stdlib from abc import ABCMeta, abstractmethod -from datetime import datetime, timedelta -from typing import Any, Dict, List, Optional +from datetime import datetime, timedelta, timezone +from typing import Any, List, Optional, Tuple # third party import numpy as np import pandas as pd -from pydantic import BaseModel, validator +from pydantic import ( + BaseModel, + ConfigDict, + Field, + FieldValidationInfo, + PrivateAttr, + ValidationInfo, + field_validator, + model_validator, +) # synthcity absolute from synthcity.plugins.core.constraints import Constraints +Rule = Tuple[str, str, Any] # Define a type alias for clarity + class Distribution(BaseModel, metaclass=ABCMeta): """ .. inheritance-diagram:: synthcity.plugins.core.distribution.Distribution :parts: 1 - Base class of all Distributions. The Distribution class characterizes the **empirical** marginal distribution of the feature. @@ -37,19 +47,22 @@ class Distribution(BaseModel, metaclass=ABCMeta): name: str data: Optional[pd.Series] = None - random_state: int = 0 + random_state: Optional[int] = None + sampling_strategy: str = "marginal" + _rng: np.random.Generator = PrivateAttr() # DP parameters marginal_distribution: Optional[pd.Series] = None - class Config: - arbitrary_types_allowed = True + model_config = ConfigDict(arbitrary_types_allowed=True) - @validator("marginal_distribution", always=True) - def _validate_marginal_distribution(cls: Any, v: Any, values: Dict) -> Dict: - if "data" not in values or values["data"] is None: + @field_validator("marginal_distribution", mode="before") + def _validate_marginal_distribution( + cls: Any, v: Any, values: FieldValidationInfo + ) -> Optional[pd.Series]: + if "data" not in values.data or values.data["data"] is None: return v - data = values["data"] + data = values.data["data"] if not isinstance(data, pd.Series): raise ValueError(f"Invalid data type {type(data)}") @@ -58,6 +71,17 @@ def _validate_marginal_distribution(cls: Any, v: Any, values: Dict) -> Dict: return marginal + @model_validator(mode="after") + def initialize_rng(cls, model: "Distribution") -> "Distribution": + """ + Initializes the random number generator after model validation. + """ + if model.random_state is not None: + model._rng = np.random.default_rng(model.random_state) + else: + model._rng = np.random.default_rng() + return model + def marginal_states(self) -> Optional[List]: if self.marginal_distribution is None: return None @@ -73,12 +97,10 @@ def marginal_probabilities(self) -> Optional[List]: ) def sample_marginal(self, count: int = 1) -> Any: - np.random.seed(self.random_state) - if self.marginal_distribution is None: return None - return np.random.choice( + return self._rng.choice( self.marginal_states(), count, p=self.marginal_probabilities(), @@ -142,49 +164,117 @@ class CategoricalDistribution(Distribution): :parts: 1 """ - choices: list = [] + data: Optional[pd.Series] = None + marginal_distribution: Optional[pd.Series] = None + choices: List[Any] = Field(default_factory=list) - @validator("choices", always=True) - def _validate_choices(cls: Any, v: List, values: Dict) -> List: - mkey = "marginal_distribution" - if mkey in values and values[mkey] is not None: - return list(values[mkey].index) + @model_validator(mode="after") + def validate_and_initialize( + cls, model: "CategoricalDistribution" + ) -> "CategoricalDistribution": + """ + Validates and initializes choices and marginal_distribution based on data or provided choices. + Ensures that choices are unique and sorted. + """ + if model.data is not None: + # Set marginal_distribution based on data + model.marginal_distribution = model.data.value_counts(normalize=True) + model.choices = model.marginal_distribution.index.tolist() + elif model.choices is not None: + # Ensure choices are unique and sorted + model.choices = sorted(set(model.choices)) + # Set uniform probabilities + probabilities = np.ones(len(model.choices)) / len(model.choices) + model.marginal_distribution = pd.Series(probabilities, index=model.choices) + else: + raise ValueError( + "Invalid CategoricalDistribution: Provide either 'data' or 'choices'." + ) - if len(v) == 0: + # Additional validation to ensure consistency + if not isinstance(model.choices, list) or len(model.choices) == 0: raise ValueError( - "Invalid choices for CategoricalDistribution. Provide data or choices params" + "CategoricalDistribution must have a non-empty 'choices' list." + ) + if not isinstance(model.marginal_distribution, pd.Series): + raise ValueError( + "CategoricalDistribution must have a valid 'marginal_distribution'." + ) + if len(model.choices) != len(model.marginal_distribution): + raise ValueError( + "'choices' and 'marginal_distribution' must have the same length." ) - return sorted(set(v)) - def get(self) -> List[Any]: - return [self.name, self.choices] + return model def sample(self, count: int = 1) -> Any: - np.random.seed(self.random_state) - msamples = self.sample_marginal(count) - if msamples is not None: - return msamples + """ + Samples values from the distribution based on the specified sampling strategy. + If the distribution has only one choice, returns an array filled with that value. + """ + if self.choices is not None and len(self.choices) == 1: + samples = np.full(count, self.choices[0]) + else: + if self.sampling_strategy == "marginal": + if self.marginal_distribution is None: + raise ValueError( + "Cannot sample based on marginal distribution: marginal_distribution is not provided." + ) + return self._rng.choice( + self.marginal_distribution.index, + size=count, + p=self.marginal_distribution.values, + ) + elif self.sampling_strategy == "uniform": + return self._rng.choice(self.choices, size=count) + else: + raise ValueError( + f"Unsupported sampling strategy '{self.sampling_strategy}'." + ) + return samples - return np.random.choice(self.choices, count) + def get(self) -> List[Any]: + """ + Returns the metadata of the distribution. + """ + return [self.name, self.choices] def has(self, val: Any) -> bool: + """ + Checks if a value is among the distribution's choices. + """ return val in self.choices def includes(self, other: "Distribution") -> bool: + """ + Checks if another categorical distribution's choices are a subset of this distribution's choices. + """ if not isinstance(other, CategoricalDistribution): return False return set(other.choices).issubset(set(self.choices)) def as_constraint(self) -> Constraints: + """ + Converts the distribution to a set of constraints. + """ return Constraints(rules=[(self.name, "in", list(self.choices))]) def min(self) -> Any: + """ + Returns the minimum value among the choices. + """ return min(self.choices) def max(self) -> Any: + """ + Returns the maximum value among the choices. + """ return max(self.choices) def dtype(self) -> str: + """ + Determines the data type based on the choices. + """ types = { "object": 0, "float": 0, @@ -211,42 +301,112 @@ class FloatDistribution(Distribution): :parts: 1 """ - low: float = np.finfo(np.float64).min - high: float = np.finfo(np.float64).max + low: Optional[float] = Field(default=None) + high: Optional[float] = Field(default=None) + _is_constant: bool = PrivateAttr(False) - @validator("low", always=True) - def _validate_low_thresh(cls: Any, v: float, values: Dict) -> float: - mkey = "marginal_distribution" - if mkey in values and values[mkey] is not None: - return values[mkey].index.min() + model_config = ConfigDict(arbitrary_types_allowed=True) - return v + @model_validator(mode="after") + def validate_and_initialize(cls, model: "FloatDistribution") -> "FloatDistribution": + """ + Validates and initializes the distribution. + Sets '_is_constant' based on whether 'low' equals 'high'. + Initializes 'marginal_distribution' based on 'data' if provided. + """ + if model.data is not None: + # Initialize marginal_distribution based on data + # For float data, use value_counts(normalize=True) if data has repeated values + # This will create a discrete approximation of the distribution + model.marginal_distribution = model.data.value_counts( + normalize=True + ).sort_index() + model.low = float(model.data.min()) + model.high = float(model.data.max()) + elif model.marginal_distribution is not None: + # Set 'low' and 'high' based on marginal_distribution + model.low = float(model.marginal_distribution.index.min()) + model.high = float(model.marginal_distribution.index.max()) + else: + # Ensure 'low' and 'high' are provided + if model.low is None or model.high is None: + raise ValueError( + "FloatDistribution requires 'low' and 'high' values if 'data' or 'marginal_distribution' is not provided." + ) + + # Validate that low <= high + if model.low > model.high: + raise ValueError( + f"Invalid range for '{model.name}': low ({model.low}) cannot be greater than high ({model.high})." + ) - @validator("high", always=True) - def _validate_high_thresh(cls: Any, v: float, values: Dict) -> float: - mkey = "marginal_distribution" - if mkey in values and values[mkey] is not None: - return values[mkey].index.max() + # Set _is_constant based on low == high + model._is_constant = model.low == model.high - return v + # Ensure that low and high are finite numbers + if not np.isfinite(model.low) or not np.isfinite(model.high): + raise ValueError( + f"Invalid range for '{model.name}': low or high is not finite (low={model.low}, high={model.high})." + ) - def get(self) -> List[Any]: - return [self.name, self.low, self.high] + return model def sample(self, count: int = 1) -> Any: - np.random.seed(self.random_state) - msamples = self.sample_marginal(count) - if msamples is not None: - return msamples - return np.random.uniform(self.low, self.high, count) + """ + Samples values from the distribution. + If the distribution is constant, returns an array filled with the constant value. + Otherwise, samples based on the marginal distribution or uniform sampling. + """ + if self._is_constant: + if self.low is None: + raise ValueError( + "Cannot sample: 'low' is None for a constant distribution." + ) + samples = np.full(count, self.low) + else: + if self.low is None or self.high is None: + raise ValueError("Cannot sample: 'low' or 'high' is None.") + if ( + self.sampling_strategy == "marginal" + and self.marginal_distribution is not None + ): + # Sample based on marginal distribution + return self._rng.choice( + self.marginal_distribution.index.values, + size=count, + p=self.marginal_distribution.values, + ) + else: + # Proceed with uniform sampling + samples = self._rng.uniform(low=self.low, high=self.high, size=count) + return samples + + def get(self) -> List[Any]: + """ + Returns the metadata of the distribution. + """ + return [self.name, self.low, self.high] def has(self, val: Any) -> bool: - return self.low <= val and val <= self.high + """ + Checks if a value is within the distribution's range. + """ + return self.low <= val <= self.high def includes(self, other: "Distribution") -> bool: + """ + Checks if another distribution is entirely within this distribution. + """ + if self.min() is None or self.max() is None: + return False + if other.min() is None or other.max() is None: + return False return self.min() <= other.min() and other.max() <= self.max() def as_constraint(self) -> Constraints: + """ + Converts the distribution to a set of constraints. + """ return Constraints( rules=[ (self.name, "le", self.high), @@ -256,12 +416,21 @@ def as_constraint(self) -> Constraints: ) def min(self) -> Any: + """ + Returns the minimum value of the distribution. + """ return self.low def max(self) -> Any: + """ + Returns the maximum value of the distribution. + """ return self.high def dtype(self) -> str: + """ + Returns the data type of the distribution. + """ return "float" @@ -273,12 +442,11 @@ def get(self) -> List[Any]: return [self.name, self.low, self.high] def sample(self, count: int = 1) -> Any: - np.random.seed(self.random_state) msamples = self.sample_marginal(count) if msamples is not None: return msamples lo, hi = np.log2(self.low), np.log2(self.high) - return 2.0 ** np.random.uniform(lo, hi, count) + return 2.0 ** self._rng.uniform(lo, hi, count) class IntegerDistribution(Distribution): @@ -287,75 +455,167 @@ class IntegerDistribution(Distribution): :parts: 1 """ - low: int = np.iinfo(np.int64).min - high: int = np.iinfo(np.int64).max - step: int = 1 + low: Optional[int] = Field(default=None) + high: Optional[int] = Field(default=None) + step: int = Field(default=1) + _is_constant: bool = PrivateAttr(False) - @validator("low", always=True) - def _validate_low_thresh(cls: Any, v: int, values: Dict) -> int: - mkey = "marginal_distribution" - if mkey in values and values[mkey] is not None: - return int(values[mkey].index.min()) + model_config = ConfigDict(arbitrary_types_allowed=True) - return v + @model_validator(mode="after") + def validate_and_initialize( + cls, model: "IntegerDistribution" + ) -> "IntegerDistribution": + """ + Validates and initializes the distribution. + Sets '_is_constant' based on whether 'low' equals 'high'. + Initializes 'marginal_distribution' based on 'data' if provided. + """ + if model.data is not None: + # Initialize marginal_distribution based on data + model.marginal_distribution = model.data.value_counts( + normalize=True + ).sort_index() + model.low = int(model.data.min()) + model.high = int(model.data.max()) + elif model.marginal_distribution is not None: + # Infer 'low' and 'high' from the marginal distribution's index + model.low = int(model.marginal_distribution.index.min()) + model.high = int(model.marginal_distribution.index.max()) + else: + # Ensure 'low' and 'high' are provided + if model.low is None or model.high is None: + raise ValueError( + "IntegerDistribution requires 'low' and 'high' values if 'data' or 'marginal_distribution' is not provided." + ) + + # Validate that low <= high + if model.low > model.high: + raise ValueError( + f"Invalid range for '{model.name}': low ({model.low}) cannot be greater than high ({model.high})." + ) - @validator("high", always=True) - def _validate_high_thresh(cls: Any, v: int, values: Dict) -> int: - mkey = "marginal_distribution" - if mkey in values and values[mkey] is not None: - return int(values[mkey].index.max()) - return v + # Set _is_constant based on low == high + model._is_constant = model.low == model.high - @validator("step", always=True) - def _validate_step(cls: Any, v: int, values: Dict) -> int: - if v < 1: - raise ValueError("Step must be greater than 0") - return v + # Ensure that low and high are finite integers + if not np.isfinite(model.low) or not np.isfinite(model.high): + raise ValueError( + f"Invalid range for '{model.name}': low or high is not finite (low={model.low}, high={model.high})." + ) - def get(self) -> List[Any]: - return [self.name, self.low, self.high, self.step] + # Ensure that 'step' is a positive integer + if model.step <= 0: + raise ValueError("'step' must be a positive integer.") + + # Adjust 'low' and 'high' to be compatible with 'step' + model.low = model.low - ((model.low - (model.low % model.step)) % model.step) + model.high = model.high - ( + (model.high - (model.high % model.step)) % model.step + ) + + # Re-validate after adjustment + if model.low > model.high: + raise ValueError( + f"After adjusting with step, invalid range for '{model.name}': low ({model.low}) cannot be greater than high ({model.high})." + ) + + return model def sample(self, count: int = 1) -> Any: - np.random.seed(self.random_state) - msamples = self.sample_marginal(count) - if msamples is not None: - return msamples + """ + Samples values from the distribution. + If the distribution is constant, returns an array filled with the constant value. + Otherwise, samples based on the marginal distribution or uniform sampling. + """ + if self._is_constant: + if self.low is None: + raise ValueError( + "Cannot sample: 'low' is None for a constant distribution." + ) + samples = np.full(count, self.low) + else: + if self.low is None or self.high is None: + raise ValueError("Cannot sample: 'low' or 'high' is None.") + if ( + self.sampling_strategy == "marginal" + and self.marginal_distribution is not None + ): + # Sample based on marginal distribution + return self._rng.choice( + self.marginal_distribution.index, + size=count, + p=self.marginal_distribution.values, + ) + else: + if self.low is None or self.high is None: + raise ValueError( + "Cannot sample based on uniform distribution: low or high is not provided." + ) + # Proceed with uniform sampling + possible_values = np.arange(self.low, self.high + 1, self.step) + samples = self._rng.choice(possible_values, size=count) + return samples - steps = (self.high - self.low) // self.step - samples = np.random.choice(steps + 1, count) - return samples * self.step + self.low + def get(self) -> List[Any]: + """ + Returns the metadata of the distribution. + """ + return [self.name, self.low, self.high, self.step] def has(self, val: Any) -> bool: - return self.low <= val and val <= self.high + """ + Checks if a value is within the distribution's range. + """ + return self.low <= val <= self.high def includes(self, other: "Distribution") -> bool: + """ + Checks if another distribution is entirely within this distribution. + """ + if self.min() is None or self.max() is None: + return False + if other.min() is None or other.max() is None: + return False return self.min() <= other.min() and other.max() <= self.max() def as_constraint(self) -> Constraints: - return Constraints( - rules=[ - (self.name, "le", self.high), - (self.name, "ge", self.low), - (self.name, "dtype", "int"), - ] - ) + """ + Converts the distribution to a set of constraints. + """ + rules: List[Rule] = [] + if self.low is not None: + rules.append((self.name, "ge", self.low)) + if self.high is not None: + rules.append((self.name, "le", self.high)) + rules.append((self.name, "dtype", "int")) + return Constraints(rules=rules) def min(self) -> Any: + """ + Returns the minimum value of the distribution. + """ return self.low def max(self) -> Any: + """ + Returns the maximum value of the distribution. + """ return self.high def dtype(self) -> str: + """ + Returns the data type of the distribution. + """ return "int" class IntLogDistribution(IntegerDistribution): - low: int = 1 - high: int = np.iinfo(np.int64).max + low: int = Field(default=1) + high: int = Field(default=np.iinfo(np.int64).max) - @validator("step", always=True) - def _validate_step(cls: Any, v: int, values: Dict) -> int: + @field_validator("step", mode="before") + def _validate_step(cls: Any, v: int, values: ValidationInfo) -> int: if v != 1: raise ValueError("Step must be 1 for IntLogDistribution") return v @@ -364,12 +624,11 @@ def get(self) -> List[Any]: return [self.name, self.low, self.high] def sample(self, count: int = 1) -> Any: - np.random.seed(self.random_state) msamples = self.sample_marginal(count) if msamples is not None: return msamples lo, hi = np.log2(self.low), np.log2(self.high) - samples = 2.0 ** np.random.uniform(lo, hi, count) + samples = 2.0 ** self._rng.uniform(lo, hi, count) return samples.astype(int) @@ -379,48 +638,126 @@ class DatetimeDistribution(Distribution): :parts: 1 """ - low: datetime = datetime.utcfromtimestamp(0) - high: datetime = datetime.now() - step: timedelta = timedelta(microseconds=1) - offset: timedelta = timedelta(seconds=120) + low: Optional[datetime] = Field(default=None) + high: Optional[datetime] = Field(default=None) + step: timedelta = Field(default=timedelta(microseconds=1)) + offset: timedelta = Field(default=timedelta(seconds=120)) + _is_constant: bool = PrivateAttr(False) # Correctly named with leading underscore - @validator("low", always=True) - def _validate_low_thresh(cls: Any, v: datetime, values: Dict) -> datetime: - mkey = "marginal_distribution" - if mkey in values and values[mkey] is not None: - v = values[mkey].index.min() - return v + model_config = ConfigDict(arbitrary_types_allowed=True) - @validator("high", always=True) - def _validate_high_thresh(cls: Any, v: datetime, values: Dict) -> datetime: - mkey = "marginal_distribution" - if mkey in values and values[mkey] is not None: - v = values[mkey].index.max() - return v + @model_validator(mode="after") + def validate_low_high(cls, model: "DatetimeDistribution") -> "DatetimeDistribution": + """ + Validates that 'low' is less than or equal to 'high'. + Sets '_is_constant' based on whether 'low' equals 'high'. + """ + if model.marginal_distribution is not None: + # Infer 'low' and 'high' from the marginal distribution's index + model.low = model.marginal_distribution.index.min() + model.high = model.marginal_distribution.index.max() + else: + # If 'marginal_distribution' is not provided, ensure 'low' and 'high' are set + if model.low is None or model.high is None: + if model.data is not None: + model.low = model.data.min() + model.high = model.data.max() + else: + # Set default finite datetime values if not provided + model.low = datetime.fromtimestamp(0, timezone.utc) + model.high = datetime.now() + if model.low is None or model.high is None: + raise ValueError( + "DatetimeDistribution requires 'low' and 'high' values if 'data' or 'marginal_distribution' is not provided." + ) + # Validate that low <= high + if model.low > model.high: + raise ValueError( + f"Invalid range for {model.name}: low ({model.low}) cannot be greater than high ({model.high})." + ) - def get(self) -> List[Any]: - return [self.name, self.low, self.high, self.step, self.offset] + # Set _is_constant based on low == high + model._is_constant = model.low == model.high + + # Ensure that low and high are valid datetime objects + if not isinstance(model.low, datetime) or not isinstance(model.high, datetime): + raise ValueError( + f"Invalid range for {model.name}: low or high is not a valid datetime object (low={model.low}, high={model.high})." + ) + + # Ensure that 'step' is positive and non-zero + if model.step.total_seconds() <= 0: + raise ValueError("'step' must be a positive timedelta.") + + return model def sample(self, count: int = 1) -> Any: - np.random.seed(self.random_state) - msamples = self.sample_marginal(count) - if msamples is not None: - return msamples + """ + Samples datetime values from the distribution. + If the distribution is constant, returns a list filled with the constant datetime value. + Otherwise, samples based on the specified sampling strategy. + """ + if self._is_constant: + if self.low is None: + raise ValueError( + "Cannot sample constant datetime distribution: low is not provided." + ) + samples = [self.low for _ in range(count)] + else: + if self.low is None or self.high is None: + raise ValueError( + "Cannot sample datetime distribution: low or high is not provided." + ) + if self.sampling_strategy in ["marginal", "uniform"]: + msamples = self.sample_marginal(count) + if msamples is not None: + return msamples + if self.low is None or self.high is None: + raise ValueError( + "Cannot sample based on marginal distribution: low or high is not provided." + ) + total_seconds = (self.high - self.low).total_seconds() + step_seconds = self.step.total_seconds() + steps = int(total_seconds / step_seconds) + step_indices = self._rng.integers(0, steps + 1, count) + samples = [self.low + self.step * int(s) for s in step_indices] + else: + raise ValueError( + f"Unsupported sampling strategy '{self.sampling_strategy}'." + ) + return samples - n = (self.high - self.low) // self.step + 1 - samples = np.round(np.random.rand(count) * n - 0.5) - return self.low + samples * self.step + def get(self) -> List[Any]: + """ + Returns the metadata of the distribution. + """ + return [self.name, self.low, self.high, self.step, self.offset] def has(self, val: datetime) -> bool: - return self.low <= val and val <= self.high + """ + Checks if a datetime value is within the distribution's range. + """ + if self.low is None or self.high is None: + raise ValueError("Cannot determine 'has' because 'low' or 'high' is None.") + return self.low <= val <= self.high def includes(self, other: "Distribution") -> bool: + """ + Checks if another datetime distribution is entirely within this distribution, considering the offset. + """ + if self.low is None or self.high is None: + return False + if other.min() is None or other.max() is None: + return False return ( - self.min() - self.offset <= other.min() - and other.max() <= self.max() + self.offset + self.low - self.offset <= other.min() + and other.max() <= self.high + self.offset ) def as_constraint(self) -> Constraints: + """ + Converts the distribution to a set of constraints. + """ return Constraints( rules=[ (self.name, "le", self.high), @@ -429,16 +766,79 @@ def as_constraint(self) -> Constraints: ] ) - def min(self) -> Any: + def min(self) -> Optional[datetime]: + """ + Returns the minimum datetime value of the distribution. + """ return self.low - def max(self) -> Any: + def max(self) -> Optional[datetime]: + """ + Returns the maximum datetime value of the distribution. + """ return self.high def dtype(self) -> str: + """ + Returns the data type of the distribution. + """ return "datetime" +class PassThroughDistribution(Distribution): + """ + .. inheritance-diagram:: synthcity.plugins.core.distribution.PassThroughDistribution + :parts: 1 + """ + + data: pd.Series + _dtype: str = PrivateAttr("") + + def setup_distribution(self) -> None: + if self.data is None: + raise ValueError("'data' must be provided for PassThroughDistribution.") + + # No additional attributes to set up since 'data' is used directly + # Optionally, store the data type for dtype method + self._dtype = str(self.data.dtype) + + def sample(self, count: int = 1) -> Any: + msamples = self.sample_marginal(count) + if msamples is not None: + return msamples + return self.data.sample( + n=count, replace=True, random_state=self.random_state + ).values + + def as_constraint(self) -> Constraints: + # No constraints needed for pass-through columns + return Constraints(rules=[]) + + def get(self) -> List[Any]: + # Return the unique values or any relevant info + return [self.name] + + def has(self, val: Any) -> bool: + # Check if the value exists in the data + return val in self.data.values + + def includes(self, other: "Distribution") -> bool: + # Since we are passing through values, we can define includes as checking if all values in other are in self.data + if isinstance(other, PassThroughDistribution): + return set(other.data.unique()).issubset(set(self.data.unique())) + else: + return False + + def min(self) -> Any: + return self.data.min() + + def max(self) -> Any: + return self.data.max() + + def dtype(self) -> str: + return str(self.data.dtype) + + def constraint_to_distribution(constraints: Constraints, feature: str) -> Distribution: """Infer Distribution from Constraints. diff --git a/src/synthcity/plugins/core/models/tabular_encoder.py b/src/synthcity/plugins/core/models/tabular_encoder.py index 1e6f9fec..364946ea 100644 --- a/src/synthcity/plugins/core/models/tabular_encoder.py +++ b/src/synthcity/plugins/core/models/tabular_encoder.py @@ -7,7 +7,7 @@ # third party import numpy as np import pandas as pd -from pydantic import BaseModel, validate_arguments, validator +from pydantic import BaseModel, field_validator, validate_arguments from sklearn.base import BaseEstimator, TransformerMixin from sklearn.preprocessing import MinMaxScaler @@ -23,18 +23,20 @@ class FeatureInfo(BaseModel): name: str feature_type: str - transform: Any + transform: Any = None output_dimensions: int transformed_features: List[str] trans_feature_types: List[str] - @validator("feature_type") + @field_validator("feature_type") + @classmethod def _feature_type_validator(cls: Any, v: str) -> str: if v not in ["discrete", "continuous"]: raise ValueError(f"Invalid feature type {v}") return v - @validator("transform") + @field_validator("transform") + @classmethod def _transform_validator(cls: Any, v: Any) -> Any: if not ( hasattr(v, "fit") @@ -44,7 +46,8 @@ def _transform_validator(cls: Any, v: Any) -> Any: raise ValueError(f"Invalid transform {v}") return v - @validator("output_dimensions") + @field_validator("output_dimensions") + @classmethod def _output_dimensions_validator(cls: Any, v: int) -> int: if v <= 0: raise ValueError(f"Invalid output_dimensions {v}") diff --git a/src/synthcity/plugins/core/plugin.py b/src/synthcity/plugins/core/plugin.py index b4cd0dff..1b3d9020 100644 --- a/src/synthcity/plugins/core/plugin.py +++ b/src/synthcity/plugins/core/plugin.py @@ -9,7 +9,7 @@ # third party import pandas as pd -from pydantic import validate_arguments +from pydantic import ConfigDict, validate_arguments # synthcity absolute import synthcity.logger as log @@ -71,9 +71,7 @@ class Plugin(Serializable, metaclass=ABCMeta): Internal parameter for schema. marginal or uniform. """ - class Config: - arbitrary_types_allowed = True - validate_assignment = True + model_config = ConfigDict(arbitrary_types_allowed=True, validate_assignment=True) def __init__( self, @@ -407,6 +405,12 @@ def _safe_generate( iter_samples, columns=self.training_schema().features() ) + # Handle protected columns + for col in syn_schema.protected_cols: + if col not in iter_samples_df.columns: + # Sample the protected column using its distribution + iter_samples_df[col] = syn_schema.domain[col].sample(count) + # validate schema iter_samples_df = self.training_schema().adapt_dtypes(iter_samples_df) diff --git a/src/synthcity/plugins/core/schema.py b/src/synthcity/plugins/core/schema.py index 29e27a97..e44f0dc8 100644 --- a/src/synthcity/plugins/core/schema.py +++ b/src/synthcity/plugins/core/schema.py @@ -1,30 +1,33 @@ # stdlib -from typing import Any, Dict, Generator, List +from typing import Any, Dict, Generator, List, Optional, Union # third party -import numpy as np import pandas as pd -from pydantic import BaseModel, validate_arguments, validator +from pydantic import ( + BaseModel, + ConfigDict, + Field, + field_validator, + model_validator, + validate_arguments, +) # synthcity absolute +import synthcity.logger as log from synthcity.plugins.core.constraints import Constraints -from synthcity.plugins.core.dataloader import DataLoader +from synthcity.plugins.core.dataloader import DataLoader, GenericDataLoader from synthcity.plugins.core.distribution import ( CategoricalDistribution, DatetimeDistribution, Distribution, FloatDistribution, IntegerDistribution, - constraint_to_distribution, + PassThroughDistribution, ) class Schema(BaseModel): """ - .. inheritance-diagram:: synthcity.plugins.core.schema.Schema - :parts: 1 - - Utility class for defining the schema of a Dataset. Constructor Args: @@ -40,90 +43,42 @@ class Schema(BaseModel): (Optional) the data set """ - sampling_strategy: str = "marginal" # uniform or marginal - protected_cols: List[str] = ["seq_id"] - random_state: int = 0 - data: Any = None - domain: Dict = {} - - @validator("domain", always=True) - def _validate_domain(cls: Any, v: Any, values: Dict) -> Dict: - if "data" not in values or values["data"] is None: - return v - - feature_domain = {} - raw = values["data"] - - if isinstance(raw, DataLoader): - X = raw.dataframe() - elif isinstance(raw, pd.DataFrame): - X = raw - else: - raise ValueError("You need to provide a DataLoader in the data argument") - - if X.shape[1] == 0 or X.shape[0] == 0: - return v - - sampling_strategy = values["sampling_strategy"] - random_state = values["random_state"] - - if sampling_strategy == "marginal": - for col in X.columns: - if X[col].dtype.kind in ["O", "b"] or len(X[col].unique()) < 10: - feature_domain[col] = CategoricalDistribution( - name=col, data=X[col], random_state=random_state - ) - elif X[col].dtype.kind in ["i", "u"]: - feature_domain[col] = IntegerDistribution( - name=col, data=X[col], random_state=random_state - ) - elif X[col].dtype.kind == "f": - feature_domain[col] = FloatDistribution( - name=col, data=X[col], random_state=random_state - ) - elif X[col].dtype.kind == "M": - feature_domain[col] = DatetimeDistribution( - name=col, data=X[col], random_state=random_state - ) - else: - raise ValueError("unsupported format ", col) - elif sampling_strategy == "uniform": - for col in X.columns: - if X[col].dtype.kind in ["O", "b"] or len(X[col].unique()) < 10: - feature_domain[col] = CategoricalDistribution( - name=col, - choices=list(X[col].unique()), - random_state=random_state, - ) - elif X[col].dtype.kind in ["i", "u"]: - feature_domain[col] = IntegerDistribution( - name=col, - low=X[col].min(), - high=X[col].max(), - random_state=random_state, - ) - elif X[col].dtype.kind == "f": - feature_domain[col] = FloatDistribution( - name=col, - low=X[col].min(), - high=X[col].max(), - random_state=random_state, - ) - elif X[col].dtype.kind == "M": - feature_domain[col] = DatetimeDistribution( - name=col, - low=X[col].min(), - high=X[col].max(), - random_state=random_state, - ) - else: - raise ValueError("unsupported format ", col) - else: - raise ValueError(f"invalid sampling strategy {sampling_strategy}") - - del values["data"] - - return feature_domain + sampling_strategy: str = Field(default="marginal") + protected_cols: List[str] = [] + random_state: int = Field(default=0) + domain: Dict = Field(default_factory=dict) + + data: Optional[Union[DataLoader, pd.DataFrame]] = Field(default=None, exclude=True) + + model_config = ConfigDict(arbitrary_types_allowed=True) + + @field_validator("data", mode="before") + def validate_data(cls, v: Any) -> Optional[DataLoader]: + if v is not None: + if isinstance(v, pd.DataFrame): + return GenericDataLoader(v) + elif isinstance(v, DataLoader): + return v + else: + raise ValueError( + f"Invalid data type for 'data': {type(v)}. Expected DataLoader or pandas DataFrame." + ) + return v + + @model_validator(mode="after") + def initialize_domain(cls, model: "Schema") -> "Schema": + if model.data is not None: + X = model.data.dataframe() + model.domain = model._infer_domain( + X, + sampling_strategy=model.sampling_strategy, + random_state=model.random_state, + ) + # Remove 'data' attribute from the model + del model.__dict__["data"] + if "data" in model.__fields_set__: + model.__fields_set__.remove("data") + return model @validate_arguments def get(self, feature: str) -> Distribution: @@ -178,14 +133,11 @@ def features(self) -> List: return list(self.domain.keys()) def sample(self, count: int) -> pd.DataFrame: - samples = pd.DataFrame( - np.zeros((count, len(self.features()))), columns=self.features() - ) - - for feature in self.features(): - samples[feature] = self.domain[feature].sample(count) - - return samples + data = {} + for col, dist in self.domain.items(): + samples = dist.sample(count) + data[col] = samples + return pd.DataFrame(data) def adapt_dtypes(self, X: pd.DataFrame) -> pd.DataFrame: """Applying the data type to a new data frame @@ -208,24 +160,200 @@ def adapt_dtypes(self, X: pd.DataFrame) -> pd.DataFrame: return X def as_constraints(self) -> Constraints: - """Convert the schema to a list of Constraints.""" - constraints = Constraints(rules=[]) - for feature in self: - if feature in self.protected_cols: - continue - constraints.extend(self[feature].as_constraint()) - - return constraints + rules = [] + for feature, dist in self.domain.items(): + rules.extend(dist.as_constraint().rules) + return Constraints(rules=rules) @classmethod def from_constraints(cls, constraints: Constraints) -> "Schema": - """Create a schema from a list of Constraints.""" - - features = constraints.features() - feature_domain: dict = {} - - for feature in features: - dist = constraint_to_distribution(constraints, feature) - feature_domain[feature] = dist - - return cls(domain=feature_domain) + domain: Dict = {} + feature_params: Dict = {} + + # Collect constraint information + for feature, op, value in constraints.rules: + if feature not in feature_params: + feature_params[feature] = { + "name": feature, + "random_state": None, + "low": None, + "high": None, + "dtype": "float", # Default to 'float' if not specified + "choices": [], + } + + params = feature_params[feature] + + if op in ["ge", ">="]: + if params["low"] is None or value > params["low"]: + params["low"] = value + elif op in ["le", "<="]: + if params["high"] is None or value < params["high"]: + params["high"] = value + elif op in ["eq", "=="]: + # For '==', set both 'low' and 'high' to value + params["low"] = value + params["high"] = value + elif op in ["in", "isin"]: + if isinstance(value, list): + params["choices"].extend(value) + else: + params["choices"].append(value) + elif op == "dtype": + params["dtype"] = value + else: + # Handle other operators if necessary + pass + + # Create distribution objects + for feature, params in feature_params.items(): + dtype = params["dtype"] + if dtype == "float": + if params["low"] is None or params["high"] is None: + raise ValueError( + f"Cannot create FloatDistribution for '{feature}' without 'low' and 'high' values." + ) + domain[feature] = FloatDistribution( + name=params["name"], + random_state=params["random_state"], + low=params["low"], + high=params["high"], + ) + elif dtype == "int": + if params["low"] is None or params["high"] is None: + raise ValueError( + f"Cannot create IntegerDistribution for '{feature}' without 'low' and 'high' values." + ) + domain[feature] = IntegerDistribution( + name=params["name"], + random_state=params["random_state"], + low=int(params["low"]), + high=int(params["high"]), + step=1, # Default step to 1 or adjust as needed + ) + elif dtype in ["category", "object"]: + choices = params.get("choices") + if choices is None or not choices: + raise ValueError( + f"Cannot create CategoricalDistribution for '{feature}' without 'choices'." + ) + domain[feature] = CategoricalDistribution( + name=params["name"], + random_state=params["random_state"], + choices=list(set(choices)), + ) + else: + raise ValueError( + f"Unsupported dtype '{dtype}' for feature '{feature}'." + ) + + return cls(domain=domain) + + def _infer_domain( + self, + X: pd.DataFrame, + sampling_strategy: str, + random_state: int, + ) -> Dict[str, Distribution]: + feature_domain: Dict[str, Distribution] = {} + + for idx, col in enumerate(X.columns): + col_random_state = random_state + idx + 1 # Ensure unique seeds + + try: + if sampling_strategy == "marginal": + if col in self.protected_cols: + feature_domain[col] = PassThroughDistribution( + name=col, + data=X[col], + random_state=col_random_state, + ) + continue + + is_categorical = pd.api.types.is_categorical_dtype(X[col]) + is_object = X[col].dtype == object + is_bool = pd.api.types.is_bool_dtype(X[col]) + is_integer = pd.api.types.is_integer_dtype(X[col]) + is_float = pd.api.types.is_float_dtype(X[col]) + is_datetime = pd.api.types.is_datetime64_any_dtype(X[col]) + + if is_categorical or is_object or is_bool: + feature_domain[col] = CategoricalDistribution( + name=col, + data=X[col], + random_state=col_random_state, + ) + elif is_integer: + feature_domain[col] = IntegerDistribution( + name=col, + data=X[col], + random_state=col_random_state, + ) + elif is_float: + feature_domain[col] = FloatDistribution( + name=col, + data=X[col], + random_state=col_random_state, + ) + elif is_datetime: + feature_domain[col] = DatetimeDistribution( + name=col, + data=X[col], + random_state=col_random_state, + ) + else: + raise ValueError( + f"Unsupported data type for column '{col}' with dtype {X[col].dtype}" + ) + elif sampling_strategy == "uniform": + + is_categorical = pd.api.types.is_categorical_dtype(X[col]) + is_object = X[col].dtype == object + is_bool = pd.api.types.is_bool_dtype(X[col]) + is_integer = pd.api.types.is_integer_dtype(X[col]) + is_float = pd.api.types.is_float_dtype(X[col]) + is_datetime = pd.api.types.is_datetime64_any_dtype(X[col]) + + if ( + pd.api.types.is_categorical_dtype(X[col]) + or X[col].dtype == object + or pd.api.types.is_bool_dtype(X[col]) + ): + feature_domain[col] = CategoricalDistribution( + name=col, + choices=list(X[col].unique()), + random_state=col_random_state, + sampling_strategy=sampling_strategy, + ) + elif pd.api.types.is_integer_dtype(X[col]): + feature_domain[col] = IntegerDistribution( + name=col, + low=X[col].min(), + high=X[col].max(), + random_state=col_random_state, + sampling_strategy=sampling_strategy, + ) + elif pd.api.types.is_float_dtype(X[col]): + feature_domain[col] = FloatDistribution( + name=col, + low=X[col].min(), + high=X[col].max(), + random_state=col_random_state, + sampling_strategy=sampling_strategy, + ) + elif pd.api.types.is_datetime64_any_dtype(X[col]): + feature_domain[col] = DatetimeDistribution( + name=col, + low=X[col].min(), + high=X[col].max(), + random_state=col_random_state, + sampling_strategy=sampling_strategy, + ) + else: + raise ValueError( + f"Unsupported sampling strategy '{sampling_strategy}'" + ) + except Exception as e: + log.error(f"Exception occurred while processing column '{col}': {e}") + raise + return feature_domain diff --git a/tests/benchmarks/test_benchmarks.py b/tests/benchmarks/test_benchmarks.py index 0a49b100..6baa305f 100644 --- a/tests/benchmarks/test_benchmarks.py +++ b/tests/benchmarks/test_benchmarks.py @@ -93,21 +93,20 @@ def test_benchmark_invalid_metric() -> None: def test_benchmark_custom_target() -> None: - X, y = load_iris(return_X_y=True, as_frame=True) + X, y = load_diabetes(return_X_y=True, as_frame=True) X["target"] = y Benchmarks.evaluate( [ - ("test2", "uniform_sampler", {}), + ("test2", "ctgan", {}), ], - GenericDataLoader( - X, sensitive_columns=["sex"], target_column="sepal width (cm)" - ), + GenericDataLoader(X, target_column="target"), metrics={ "performance": [ "linear_model", ] }, + task_type="regression", ) diff --git a/tests/metrics/test_api.py b/tests/metrics/test_api.py index b34396d1..3b9edfcf 100644 --- a/tests/metrics/test_api.py +++ b/tests/metrics/test_api.py @@ -74,6 +74,14 @@ def test_metric_filter(metric_filter: dict) -> None: model.fit(Xraw) X_gen = model.generate(100) + assert not X_gen.dataframe().empty + print(X_gen) + + # Add debugging here + print(f"Metrics to evaluate: {metric_filter}") + print( + f"Xraw shape: {Xraw.dataframe().shape}, X_gen shape: {X_gen.dataframe().shape}" + ) out = Metrics.evaluate( Xraw, @@ -81,6 +89,8 @@ def test_metric_filter(metric_filter: dict) -> None: metrics=metric_filter, ) + print(f"Output of Metrics.evaluate: {out}") + expected_index = [ f"{category}.{metric}.score" for category in metric_filter diff --git a/tests/nb_eval.py b/tests/nb_eval.py index f2f4bb75..d119fbf2 100644 --- a/tests/nb_eval.py +++ b/tests/nb_eval.py @@ -12,11 +12,11 @@ workspace.mkdir(parents=True, exist_ok=True) -def run_notebook(notebook_path: Path) -> None: +def run_notebook(notebook_path: Path, timeout: int) -> None: with open(notebook_path) as f: nb = nbformat.read(f, as_version=4) - proc = ExecutePreprocessor(timeout=1800) + proc = ExecutePreprocessor(timeout=timeout) # Will raise on cell error proc.preprocess(nb, {"metadata": {"path": workspace}}) @@ -29,22 +29,6 @@ def run_notebook(notebook_path: Path) -> None: except ImportError: goggle_disabled = True -try: - # synthcity absolute - from synthcity.plugins.core.models.tabular_arf import TabularARF # noqa: F401 - - arf_disabled = False -except ImportError: - arf_disabled = True - -try: - # synthcity absolute - from synthcity.plugins.core.models.tabular_great import TabularGReaT # noqa: F401 - - great_disabled = False -except ImportError: - great_disabled = True - all_tests = [ "basic_examples", "benchmarks", @@ -56,8 +40,8 @@ def run_notebook(notebook_path: Path) -> None: "plugin_ctgan", "plugin_nflow", "plugin_tvae", - "plugin_timegan", - "plugin_radialgan" "plugin_arf", + "plugin_radialgan", + "plugin_arf", "plugin_bayesian_network", "plugin_ddpm", "plugin_dummy_sampler", @@ -73,14 +57,12 @@ def run_notebook(notebook_path: Path) -> None: "plugin_fourier_flows", "plugin_timegan", "plugin_aim", + "plugin_arf", + "plugin_great", ] if not goggle_disabled: all_tests.append("plugin_goggle") -if not arf_disabled: - all_tests.append("plugin_arf") -if not great_disabled: - all_tests.append("plugin_great") minimal_tests = [ "basic_examples", @@ -88,13 +70,10 @@ def run_notebook(notebook_path: Path) -> None: "plugin_ctgan", "plugin_nflow", "plugin_tvae", - "plugin_timegan", ] # For extras goggle_tests = ["plugin_goggle"] -arf_tests = ["plugin_arf"] -great_tests = ["plugin_great"] @click.command() @@ -102,12 +81,18 @@ def run_notebook(notebook_path: Path) -> None: @click.option( "--tutorial_tests", type=click.Choice( - ["minimal_tests", "all_tests", "goggle_tests", "plugin_arf", "plugin_great"], + ["minimal_tests", "all_tests", "goggle_tests"], case_sensitive=False, ), default="minimal_tests", ) -def main(nb_dir: Path, tutorial_tests: str) -> None: +@click.option( + "--timeout", + type=int, + default=1800, + help="Timeout for notebook execution in seconds.", +) +def main(nb_dir: Path, tutorial_tests: str, timeout: int) -> None: nb_dir = Path(nb_dir) enabled_tests: List = [] if tutorial_tests == "all_tests": @@ -134,7 +119,7 @@ def main(nb_dir: Path, tutorial_tests: str) -> None: print("Testing ", p.name) start = time() try: - run_notebook(p) + run_notebook(p, timeout) except BaseException as e: print("FAIL", p.name, e) diff --git a/tests/plugins/core/test_distribution.py b/tests/plugins/core/test_distribution.py index 524b361f..34c55313 100644 --- a/tests/plugins/core/test_distribution.py +++ b/tests/plugins/core/test_distribution.py @@ -18,6 +18,8 @@ def test_categorical() -> None: param = CategoricalDistribution(name="test", choices=["1", "2", "55", "sdfsf"]) + assert param.marginal_distribution is not None + assert param.get() == ["test", ["1", "2", "55", "sdfsf"]] assert len(param.sample(count=5)) == 5 for sample in param.sample(count=5): @@ -47,8 +49,15 @@ def test_categorical() -> None: assert param.includes(param_other) assert param_other.includes(param) - assert param.marginal_distribution is None - assert param.dtype() == "object" + # Instead of asserting marginal_distribution is None, assert it's correctly initialized + expected_marginal = pd.Series( + [0.25, 0.25, 0.25, 0.25], index=["1", "2", "55", "sdfsf"] + ) + pd.testing.assert_series_equal( + param.marginal_distribution.sort_index(), + expected_marginal.sort_index(), + check_names=False, + ) def test_categorical_from_data() -> None: @@ -119,7 +128,7 @@ def test_integer_from_data() -> None: assert param.get() == ["test", 1, 88, 1] assert len(param.sample(count=5)) == 5 for sample in param.sample(count=5): - assert sample in list(range(0, 101)) + assert sample in list(range(1, 89)) assert param.has(1) assert not param.has(101) assert not param.has(-1) @@ -130,7 +139,6 @@ def test_integer_from_data() -> None: assert not param_other.includes(param) assert param.marginal_distribution is not None - assert set(param.marginal_distribution.keys()) == set([1, 2, 4, 12, 88]) def test_float() -> None: @@ -176,7 +184,7 @@ def test_float_from_data() -> None: data=pd.Series([0, 1.1, 2.3, 1, 0.5, 1, 1, 1, 1, 1, 1]), ) - assert param.get() == ["test", 0, 2.3] + assert param.get() == ["test", 0.0, 2.3] assert len(param.sample(count=5)) == 5 for sample in param.sample(count=5): assert sample <= 2.3 @@ -187,8 +195,8 @@ def test_float_from_data() -> None: assert param.includes(param_other) assert not param_other.includes(param) + # This assertion should now pass assert param.marginal_distribution is not None - assert set(param.marginal_distribution.keys()) == set([0, 1.1, 2.3, 1.0, 0.5]) def test_categorical_constraint_to_distribution() -> None: diff --git a/tests/plugins/core/test_schema.py b/tests/plugins/core/test_schema.py index b004b37f..ad24c251 100644 --- a/tests/plugins/core/test_schema.py +++ b/tests/plugins/core/test_schema.py @@ -10,12 +10,8 @@ def test_schema_fail() -> None: - if pydantic.__version__ < "2": - with pytest.raises(pydantic.error_wrappers.ValidationError): - Schema(data="sdfsfs") - else: - with pytest.raises(pydantic.pydantic_core._pydantic_core.ValidationError): - Schema(data="sdfsfs") + with pytest.raises(pydantic.ValidationError): + Schema(data="sdfsfs") def test_schema_ok() -> None: @@ -67,9 +63,34 @@ def test_schema_as_constraint() -> None: cons = schema.as_constraints() - assert len(cons) == 7 - for rule in cons: - assert rule[1] == "in" + # Old assertions + # assert len(cons) == 7 + # for rule in cons: + # assert rule[1] == "in" + + # New assertions + assert len(cons) == 15 + + # Optionally, verify that the constraints are as expected + expected_constraints = [ + ("a", "in", ["a", "b", "c"]), + ("b", "in", [True, False]), + ("c", "ge", 1), + ("c", "le", 3), + ("c", "dtype", "int"), + ("d", "ge", 4.0), + ("d", "le", 6.0), + ("d", "dtype", "float"), + ("e", "ge", 7), + ("e", "le", 9), + ("e", "dtype", "int"), + ("f", "in", ["odd", "even"]), + ("g", "ge", pd.Timestamp("2023-01-01")), + ("g", "le", pd.Timestamp("2023-01-03")), + ("g", "dtype", "datetime"), + ] + + assert sorted(cons.rules) == sorted(expected_constraints) def test_schema_from_constraint() -> None: diff --git a/tutorials/plugins/time_series/plugin_timegan.ipynb b/tutorials/plugins/time_series/plugin_timegan.ipynb index eafb5b65..2a316bf1 100644 --- a/tutorials/plugins/time_series/plugin_timegan.ipynb +++ b/tutorials/plugins/time_series/plugin_timegan.ipynb @@ -111,7 +111,7 @@ "# third party\n", "import matplotlib.pyplot as plt\n", "\n", - "syn_model.plot(plt, loader, count=1000, plots=[\"tsne\"])\n", + "syn_model.plot(plt, loader, count=100, plots=[\"tsne\"])\n", "\n", "plt.show()" ]