Skip to content

Commit

Permalink
PLAT 1424: Guard against foreign key cycles
Browse files Browse the repository at this point in the history
GitOrigin-RevId: e2950807af423ca51e71d261e4a63f9d64328ec9
  • Loading branch information
mikeknep committed Dec 20, 2023
1 parent 32ac0e7 commit 35684de
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 1 deletion.
13 changes: 12 additions & 1 deletion src/gretel_trainer/relational/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import pandas as pd
import smart_open

from networkx.algorithms.cycles import simple_cycles
from networkx.algorithms.dag import dag_longest_path_length, topological_sort
from networkx.classes.function import number_of_edges
from pandas.api.types import is_string_dtype
Expand Down Expand Up @@ -202,6 +203,13 @@ def is_empty(self) -> bool:
"""
return not self.graph.number_of_nodes() > 0

@property
def foreign_key_cycles(self) -> list[list[str]]:
"""
Returns lists of tables that have cyclic foreign key relationships.
"""
return list(simple_cycles(self.graph))

def restore(self, tableset: dict[str, pd.DataFrame]) -> dict[str, pd.DataFrame]:
"""Restores a given tableset (presumably output from some MultiTable workflow,
i.e. transforms or synthetics) to its original shape (specifically, "re-nests"
Expand Down Expand Up @@ -828,7 +836,10 @@ def any_table_relationships(self) -> bool:
return number_of_edges(self.graph) > 0

def debug_summary(self) -> dict[str, Any]:
max_depth = dag_longest_path_length(self.graph)
if len(self.foreign_key_cycles) > 0:
max_depth = "indeterminate (cycles in foreign keys)"
else:
max_depth = dag_longest_path_length(self.graph)
public_table_count = len(self.list_all_tables(Scope.PUBLIC))
invented_table_count = len(self.list_all_tables(Scope.INVENTED))

Expand Down
17 changes: 17 additions & 0 deletions src/gretel_trainer/relational/multi_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,13 @@ def __init__(
"or set `project` to run in an existing project."
)

if len(cycles := relational_data.foreign_key_cycles) > 0:
logger.warning(
f"Detected cyclic foreign key relationships in schema: {cycles}. "
"Support for cyclic table dependencies is limited. "
"You may need to remove some foreign keys to ensure no cycles exist."
)

self._strategy = _validate_strategy(strategy)
self._set_refresh_interval(refresh_interval)
self.relational_data = relational_data
Expand Down Expand Up @@ -611,6 +618,11 @@ def run_transforms(
an additional level of privacy at the cost of referential integrity between transformed and
original data.
"""
if encode_keys and len(self.relational_data.foreign_key_cycles) > 0:
raise MultiTableException(
"Cannot encode keys when schema includes cyclic foreign key relationships."
)

if data is not None:
unrunnable_tables = [
table
Expand Down Expand Up @@ -735,6 +747,11 @@ def train_synthetics(
Train synthetic data models for the tables in the tableset,
optionally scoped by either `only` or `ignore`.
"""
if len(self.relational_data.foreign_key_cycles) > 0:
raise MultiTableException(
"Cyclic foreign key relationships are not supported by relational synthetics."
)

if config is None:
config = self._strategy.default_config

Expand Down
25 changes: 25 additions & 0 deletions tests/relational/test_relational_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,31 @@ def in_order(col, t1, t2):
assert in_order(tables, "users", "order_items")


def test_detect_cycles(ecom):
assert ecom.foreign_key_cycles == []

ecom.add_foreign_key_constraint(
table="users",
constrained_columns=["first_name"],
referred_table="users",
referred_columns=["last_name"],
)
ecom.debug_summary()

assert ecom.foreign_key_cycles == [["users"]]
assert "indeterminate" in ecom.debug_summary()["max_depth"]

ecom.add_foreign_key_constraint(
table="users",
constrained_columns=["first_name"],
referred_table="events",
referred_columns=["user_id"],
)

sorted_cycles = sorted([sorted(cycle) for cycle in ecom.foreign_key_cycles])
assert sorted_cycles == [["events", "users"], ["users"]]


def test_debug_summary(ecom, mutagenesis):
assert ecom.debug_summary() == {
"foreign_key_count": 6,
Expand Down

0 comments on commit 35684de

Please sign in to comment.