Skip to content

Commit

Permalink
Bump black from 22.8.0 to 24.3.0
Browse files Browse the repository at this point in the history
GitOrigin-RevId: 6c8901c03ec9141224d62ec83e8578a94c5e72fd
  • Loading branch information
dependabot[bot] authored and theonlyrob committed Apr 2, 2024
1 parent e110d64 commit c9d3982
Show file tree
Hide file tree
Showing 14 changed files with 101 additions and 78 deletions.
6 changes: 2 additions & 4 deletions src/gretel_trainer/benchmark/custom/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@


class CustomModel(Protocol):
def train(self, source: Dataset, **kwargs) -> None:
...
def train(self, source: Dataset, **kwargs) -> None: ...

def generate(self, **kwargs) -> pd.DataFrame:
...
def generate(self, **kwargs) -> pd.DataFrame: ...
21 changes: 7 additions & 14 deletions src/gretel_trainer/benchmark/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,27 +36,20 @@ def cannot_proceed(self) -> bool:

class Strategy(Protocol):
@property
def dataset(self) -> Dataset:
...
def dataset(self) -> Dataset: ...

@property
def evaluate_ref_data(self) -> str:
...
def evaluate_ref_data(self) -> str: ...

def runnable(self) -> bool:
...
def runnable(self) -> bool: ...

def train(self) -> None:
...
def train(self) -> None: ...

def generate(self) -> None:
...
def generate(self) -> None: ...

def get_train_time(self) -> Optional[float]:
...
def get_train_time(self) -> Optional[float]: ...

def get_generate_time(self) -> Optional[float]:
...
def get_generate_time(self) -> Optional[float]: ...


class Executor:
Expand Down
1 change: 1 addition & 0 deletions src/gretel_trainer/relational/connectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
which you can then use with the "MultiTable" class to process data with
Gretel Transforms, Classify, Synthetics, or a combination of both.
"""

from __future__ import annotations

import logging
Expand Down
6 changes: 2 additions & 4 deletions src/gretel_trainer/relational/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,13 +179,11 @@ def get_foreign_keys(
) -> list: # can't specify element type (ForeignKey) without cyclic dependency
...

def get_table_columns(self, table: str) -> list[str]:
...
def get_table_columns(self, table: str) -> list[str]: ...

def get_invented_table_metadata(
self, table: str
) -> Optional[InventedTableMetadata]:
...
) -> Optional[InventedTableMetadata]: ...


@dataclass
Expand Down
8 changes: 5 additions & 3 deletions src/gretel_trainer/relational/multi_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
take extracted data from a database or data warehouse, and process it
with Gretel using Transforms, Classify, and Synthetics.
"""

from __future__ import annotations

import json
Expand Down Expand Up @@ -574,9 +575,10 @@ def run_transforms(
if isinstance(data_source, pd.DataFrame):
data_source.to_csv(transforms_run_path, index=False)
else:
with open_artifact(data_source, "rb") as src, open_artifact(
transforms_run_path, "wb"
) as dest:
with (
open_artifact(data_source, "rb") as src,
open_artifact(transforms_run_path, "wb") as dest,
):
shutil.copyfileobj(src, dest)
transforms_run_paths[table] = transforms_run_path

Expand Down
7 changes: 4 additions & 3 deletions src/gretel_trainer/relational/sdk_extras.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,10 @@ def download_file_artifact(
out_path: Union[str, Path],
) -> bool:
try:
with gretel_object.get_artifact_handle(artifact_name) as src, open_artifact(
out_path, "wb"
) as dest:
with (
gretel_object.get_artifact_handle(artifact_name) as src,
open_artifact(out_path, "wb") as dest,
):
shutil.copyfileobj(src, dest)
return True
except:
Expand Down
7 changes: 4 additions & 3 deletions src/gretel_trainer/relational/strategies/independent.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,10 @@ def prepare_training_data(
continue

source_path = rel_data.get_table_source(table)
with open_artifact(source_path, "rb") as src, open_artifact(
path, "wb"
) as dest:
with (
open_artifact(source_path, "rb") as src,
open_artifact(path, "wb") as dest,
):
pd.DataFrame(columns=use_columns).to_csv(dest, index=False)
for chunk in pd.read_csv(src, usecols=use_columns, chunksize=10_000):
chunk.to_csv(dest, index=False, mode="a", header=False)
Expand Down
6 changes: 2 additions & 4 deletions src/gretel_trainer/relational/table_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,12 @@ def is_complete(self) -> bool:
@overload
def _field_from_json(
self, report_json: Optional[dict], entry: str, field: Literal["score"]
) -> Optional[int]:
...
) -> Optional[int]: ...

@overload
def _field_from_json(
self, report_json: Optional[dict], entry: str, field: Literal["grade"]
) -> Optional[str]:
...
) -> Optional[str]: ...

def _field_from_json(
self, report_json: Optional[dict], entry: str, field: str
Expand Down
33 changes: 11 additions & 22 deletions src/gretel_trainer/relational/task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,39 +35,28 @@ def maybe_start_job(self, job: Job, table_name: str, action: str) -> None:

class Task(Protocol):
@property
def ctx(self) -> TaskContext:
...
def ctx(self) -> TaskContext: ...

def action(self, job: Job) -> str:
...
def action(self, job: Job) -> str: ...

@property
def table_collection(self) -> list[str]:
...
def table_collection(self) -> list[str]: ...

def more_to_do(self) -> bool:
...
def more_to_do(self) -> bool: ...

def is_finished(self, table: str) -> bool:
...
def is_finished(self, table: str) -> bool: ...

def get_job(self, table: str) -> Job:
...
def get_job(self, table: str) -> Job: ...

def handle_completed(self, table: str, job: Job) -> None:
...
def handle_completed(self, table: str, job: Job) -> None: ...

def handle_failed(self, table: str, job: Job) -> None:
...
def handle_failed(self, table: str, job: Job) -> None: ...

def handle_in_progress(self, table: str, job: Job) -> None:
...
def handle_in_progress(self, table: str, job: Job) -> None: ...

def handle_lost_contact(self, table: str, job: Job) -> None:
...
def handle_lost_contact(self, table: str, job: Job) -> None: ...

def each_iteration(self) -> None:
...
def each_iteration(self) -> None: ...


def run_task(task: Task, extended_sdk: ExtendedGretelSDK) -> None:
Expand Down
7 changes: 4 additions & 3 deletions src/gretel_trainer/relational/tasks/classify.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,9 @@ def _write_results(self, job: Job, table: str) -> None:

destpath = self.output_handler.filepath_for(filename)

with job.get_artifact_handle(artifact_name) as src, open_artifact(
str(destpath), "wb"
) as dest:
with (
job.get_artifact_handle(artifact_name) as src,
open_artifact(str(destpath), "wb") as dest,
):
shutil.copyfileobj(src, dest)
self.result_filepaths[table] = destpath
1 change: 1 addition & 0 deletions src/gretel_trainer/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
}
}
"""

from __future__ import annotations

import json
Expand Down
9 changes: 4 additions & 5 deletions tests/relational/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,10 @@ def output_handler(tmpdir, project):

@pytest.fixture()
def project():
with patch(
"gretel_trainer.relational.multi_table.create_project"
) as create_project, patch(
"gretel_trainer.relational.multi_table.get_project"
) as get_project:
with (
patch("gretel_trainer.relational.multi_table.create_project") as create_project,
patch("gretel_trainer.relational.multi_table.get_project") as get_project,
):
project = Mock()
project.name = "name"
project.display_name = "display_name"
Expand Down
47 changes: 38 additions & 9 deletions tests/relational/test_ancestral_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@ def test_preparing_training_data_does_not_mutate_source_data(pets):

strategy = AncestralStrategy()

with tempfile.NamedTemporaryFile() as pets_dest, tempfile.NamedTemporaryFile() as humans_dest:
with (
tempfile.NamedTemporaryFile() as pets_dest,
tempfile.NamedTemporaryFile() as humans_dest,
):
strategy.prepare_training_data(
pets, {"pets": pets_dest.name, "humans": humans_dest.name}
)
Expand All @@ -32,7 +35,10 @@ def test_preparing_training_data_does_not_mutate_source_data(pets):
def test_prepare_training_data_subset_of_tables(pets):
strategy = AncestralStrategy()

with tempfile.NamedTemporaryFile() as pets_dest, tempfile.NamedTemporaryFile() as humans_dest:
with (
tempfile.NamedTemporaryFile() as pets_dest,
tempfile.NamedTemporaryFile() as humans_dest,
):
# We aren't synthesizing the "humans" table, so it is not in this list argument...
training_data = strategy.prepare_training_data(pets, {"pets": pets_dest.name})

Expand All @@ -53,7 +59,10 @@ def test_prepare_training_data_subset_of_tables(pets):
def test_prepare_training_data_returns_multigenerational_data(pets):
strategy = AncestralStrategy()

with tempfile.NamedTemporaryFile() as pets_dest, tempfile.NamedTemporaryFile() as humans_dest:
with (
tempfile.NamedTemporaryFile() as pets_dest,
tempfile.NamedTemporaryFile() as humans_dest,
):
training_data = strategy.prepare_training_data(
pets, {"pets": pets_dest.name, "humans": humans_dest.name}
)
Expand All @@ -64,7 +73,10 @@ def test_prepare_training_data_returns_multigenerational_data(pets):


def test_prepare_training_data_drops_highly_unique_categorical_ancestor_fields(art):
with tempfile.NamedTemporaryFile() as artists_csv, tempfile.NamedTemporaryFile() as paintings_csv:
with (
tempfile.NamedTemporaryFile() as artists_csv,
tempfile.NamedTemporaryFile() as paintings_csv,
):
pd.DataFrame(
data={
"id": [f"A{i}" for i in range(100)],
Expand All @@ -84,7 +96,10 @@ def test_prepare_training_data_drops_highly_unique_categorical_ancestor_fields(a

strategy = AncestralStrategy()

with tempfile.NamedTemporaryFile() as artists_dest, tempfile.NamedTemporaryFile() as paintings_dest:
with (
tempfile.NamedTemporaryFile() as artists_dest,
tempfile.NamedTemporaryFile() as paintings_dest,
):
training_data = strategy.prepare_training_data(
art,
{
Expand All @@ -111,7 +126,10 @@ def test_prepare_training_data_drops_highly_nan_ancestor_fields(art):
else:
highly_nan_names.append("some name")

with tempfile.NamedTemporaryFile() as artists_csv, tempfile.NamedTemporaryFile() as paintings_csv:
with (
tempfile.NamedTemporaryFile() as artists_csv,
tempfile.NamedTemporaryFile() as paintings_csv,
):
pd.DataFrame(
data={
"id": [f"A{i}" for i in range(100)],
Expand All @@ -131,7 +149,10 @@ def test_prepare_training_data_drops_highly_nan_ancestor_fields(art):

strategy = AncestralStrategy()

with tempfile.NamedTemporaryFile() as artists_dest, tempfile.NamedTemporaryFile() as paintings_dest:
with (
tempfile.NamedTemporaryFile() as artists_dest,
tempfile.NamedTemporaryFile() as paintings_dest,
):
training_data = strategy.prepare_training_data(
art,
{
Expand All @@ -155,7 +176,10 @@ def test_prepare_training_data_translates_alphanumeric_keys_and_adds_min_max_rec
):
strategy = AncestralStrategy()

with tempfile.NamedTemporaryFile() as artists_dest, tempfile.NamedTemporaryFile() as paintings_dest:
with (
tempfile.NamedTemporaryFile() as artists_dest,
tempfile.NamedTemporaryFile() as paintings_dest,
):
training_data = strategy.prepare_training_data(
art,
{
Expand Down Expand Up @@ -191,7 +215,12 @@ def test_prepare_training_data_translates_alphanumeric_keys_and_adds_min_max_rec

def test_prepare_training_data_with_composite_keys(tpch):
strategy = AncestralStrategy()
with tempfile.NamedTemporaryFile() as supplier_dest, tempfile.NamedTemporaryFile() as part_dest, tempfile.NamedTemporaryFile() as partsupp_dest, tempfile.NamedTemporaryFile() as lineitem_dest:
with (
tempfile.NamedTemporaryFile() as supplier_dest,
tempfile.NamedTemporaryFile() as part_dest,
tempfile.NamedTemporaryFile() as partsupp_dest,
tempfile.NamedTemporaryFile() as lineitem_dest,
):
training_data = strategy.prepare_training_data(
tpch,
{
Expand Down
20 changes: 16 additions & 4 deletions tests/relational/test_independent_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@ def test_preparing_training_data_does_not_mutate_source_data(pets):

strategy = IndependentStrategy()

with tempfile.NamedTemporaryFile() as pets_dest, tempfile.NamedTemporaryFile() as humans_dest:
with (
tempfile.NamedTemporaryFile() as pets_dest,
tempfile.NamedTemporaryFile() as humans_dest,
):
strategy.prepare_training_data(
pets, {"pets": pets_dest.name, "humans": humans_dest.name}
)
Expand All @@ -30,7 +33,10 @@ def test_preparing_training_data_does_not_mutate_source_data(pets):
def test_prepare_training_data_removes_primary_and_foreign_keys(pets):
strategy = IndependentStrategy()

with tempfile.NamedTemporaryFile() as pets_dest, tempfile.NamedTemporaryFile() as humans_dest:
with (
tempfile.NamedTemporaryFile() as pets_dest,
tempfile.NamedTemporaryFile() as humans_dest,
):
training_data = strategy.prepare_training_data(
pets, {"pets": pets_dest.name, "humans": humans_dest.name}
)
Expand All @@ -42,7 +48,10 @@ def test_prepare_training_data_removes_primary_and_foreign_keys(pets):
def test_prepare_training_data_subset_of_tables(pets):
strategy = IndependentStrategy()

with tempfile.NamedTemporaryFile() as pets_dest, tempfile.NamedTemporaryFile() as humans_dest:
with (
tempfile.NamedTemporaryFile() as pets_dest,
tempfile.NamedTemporaryFile() as humans_dest,
):
training_data = strategy.prepare_training_data(
pets, {"humans": humans_dest.name}
)
Expand All @@ -53,7 +62,10 @@ def test_prepare_training_data_subset_of_tables(pets):
def test_prepare_training_data_join_table(insurance):
strategy = IndependentStrategy()

with tempfile.NamedTemporaryFile() as beneficiary_dest, tempfile.NamedTemporaryFile() as policies_dest:
with (
tempfile.NamedTemporaryFile() as beneficiary_dest,
tempfile.NamedTemporaryFile() as policies_dest,
):
training_data = strategy.prepare_training_data(
insurance,
{
Expand Down

0 comments on commit c9d3982

Please sign in to comment.