Skip to content

Commit

Permalink
PLAT-1018: Add support for nullable single-column foreign keys
Browse files Browse the repository at this point in the history
GitOrigin-RevId: b4a72e2c87800f76e9d8a5338a68681161ca7fc0
  • Loading branch information
mikeknep committed Dec 7, 2023
1 parent b1174f0 commit 3d11395
Show file tree
Hide file tree
Showing 4 changed files with 198 additions and 5 deletions.
3 changes: 3 additions & 0 deletions src/gretel_trainer/relational/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@ class ForeignKey:
parent_table_name: str
parent_columns: list[str]

def is_composite(self) -> bool:
return len(self.columns) > 1


UserFriendlyDataT = Union[pd.DataFrame, str, Path]
UserFriendlyPrimaryKeyT = Optional[Union[str, list[str]]]
Expand Down
28 changes: 28 additions & 0 deletions src/gretel_trainer/relational/strategies/common.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
import random

from dataclasses import dataclass
from typing import Optional

import pandas as pd
Expand Down Expand Up @@ -177,3 +178,30 @@ def _get_most_unique_column(pk: list[str], col_freqs: dict[str, list]) -> str:

def get_frequencies(table_data: pd.DataFrame, cols: list[str]) -> list[int]:
    """Return the occurrence count of each distinct value combination in `cols`.

    NOTE: pandas' groupby drops rows containing nulls in any of `cols` by
    default, so only fully-non-null combinations are counted.
    """
    group_sizes = table_data.groupby(cols).size()
    return list(group_sizes)


@dataclass
class FrequencyData:
    """Frequency metadata for a list of columns (typically a foreign key).

    Example, given:
        pd.DataFrame(data={
            "col_1": ["a", "a", "b"],
            "col_2": [100, 100, None],
        })

    null_percentages: [0.0, 0.33333]
        (col_1 has no null values; col_2 is 1/3 null)
    not_null_frequencies: [2]
        (["a", 100] occurs twice; there are no other non-null values)
    """

    # Fraction of null values per column, in the order the columns were given.
    null_percentages: list[float]
    # Occurrence counts of each distinct fully-non-null value combination.
    not_null_frequencies: list[int]

    @classmethod
    def for_columns(cls, table_data: pd.DataFrame, cols: list[str]) -> "FrequencyData":
        """Compute frequency metadata for `cols` in `table_data`.

        Assumes `table_data` is non-empty (len is used as a divisor).
        """
        null_percentages = (table_data[cols].isnull().sum() / len(table_data)).tolist()
        # groupby drops rows with a null in any of `cols`, so these are the
        # frequencies of fully-non-null combinations only.
        not_null_frequencies = table_data.groupby(cols).size().tolist()

        # Use `cls` (not the class name) so subclasses get the right type back.
        return cls(
            null_percentages=null_percentages,
            not_null_frequencies=not_null_frequencies,
        )
50 changes: 45 additions & 5 deletions src/gretel_trainer/relational/strategies/independent.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import gretel_trainer.relational.strategies.common as common

from gretel_trainer.relational.core import GretelModelConfig, RelationalData
from gretel_trainer.relational.core import ForeignKey, GretelModelConfig, RelationalData
from gretel_trainer.relational.output_handler import OutputHandler

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -225,12 +225,12 @@ def _synthesize_foreign_keys(
].values.tolist()

original_table_data = rel_data.get_table_data(table_name)
fk_frequencies = common.get_frequencies(
fk_frequency_data = common.FrequencyData.for_columns(
original_table_data, foreign_key.columns
)

new_fk_values = _collect_values(
synth_parent_values, fk_frequencies, len(out_df)
new_fk_values = _collect_fk_values(
synth_parent_values, fk_frequency_data, len(out_df), foreign_key
)

out_df[foreign_key.columns] = new_fk_values
Expand All @@ -240,10 +240,51 @@ def _synthesize_foreign_keys(
return processed


def _collect_fk_values(
    values: list,
    freq_data: common.FrequencyData,
    total: int,
    foreign_key: ForeignKey,
) -> list:
    """Build `total` synthetic foreign key values drawn from `values`.

    `values` holds candidate parent key values; `freq_data` carries the source
    table's null-percentage and frequency statistics for the key columns.
    """
    # Support for (and restrictions on) null values in composite foreign keys
    # varies across database dialects. The simplest and safest policy is to
    # produce composite foreign key values that contain no NULLs at all.
    if foreign_key.is_composite():
        return _collect_values(values, freq_data.not_null_frequencies, total, [])

    # Single-column key: seed the output with the same proportion of NULL
    # values observed in the source column.
    null_count = round(freq_data.null_percentages[0] * total)
    seeded = [(None,)] * null_count

    # Dedupe the candidate values and drop None if present:
    # 1. Duplicates should not exist, because foreign keys must reference
    #    columns with unique values. A referred PRIMARY KEY *will not* contain
    #    duplicates given how we synthesize primary keys, but a key referencing
    #    a non-PK column with a UNIQUE constraint may have been synthesized
    #    with repeats (the model never saw the constraint).
    # 2. NULLs are already accounted for by `seeded`; we don't want to
    #    synthesize more.
    candidates = {tuple(value) for value in values}
    candidates.discard((None,))

    # Fill out the remainder of the output with non-null values.
    return _collect_values(
        list(candidates),
        freq_data.not_null_frequencies,
        total,
        seeded,
    )


def _collect_values(
values: list,
frequencies: list[int],
total: int,
new_values: list,
) -> list:
freqs = sorted(frequencies)

Expand All @@ -252,7 +293,6 @@ def _collect_values(
# to the output collection
v = 0
f = 0
new_values = []
while len(new_values) < total:
fk_value = values[v]

Expand Down
122 changes: 122 additions & 0 deletions tests/relational/test_independent_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,3 +244,125 @@ def test_post_processing_fks_to_non_pks(tmpdir):

for table in rel_data.list_all_tables():
assert set(processed[table].columns) == set(rel_data.get_table_columns(table))


def test_post_processing_null_foreign_key(tmpdir):
    """A single-column FK's source null percentage is reproduced in output."""
    rel_data = RelationalData(directory=tmpdir)

    customers_df = pd.DataFrame(data={"id": [1, 2], "name": ["Xavier", "Yesenia"]})
    events_df = pd.DataFrame(
        data={
            "id": [1, 2],
            "customer_id": [1, None],
            "total": [42, 43],
        }
    )
    rel_data.add_table(name="customers", primary_key="id", data=customers_df)
    rel_data.add_table(name="events", primary_key="id", data=events_df)
    rel_data.add_foreign_key_constraint(
        table="events",
        constrained_columns=["customer_id"],
        referred_table="customers",
        referred_columns=["id"],
    )

    strategy = IndependentStrategy()

    synthetic_tables = {
        "events": pd.DataFrame(data={"total": [55, 56, 57, 58]}),
        "customers": pd.DataFrame(
            data={"name": ["Alice", "Bob", "Christina", "David"]}
        ),
    }

    # Patch shuffle for deterministic testing, but don't swap in `sorted`
    # because that function raises TypeError when comparing against `None`.
    with patch("random.shuffle", wraps=lambda x: x):
        processed = strategy.post_process_synthetic_results(
            synthetic_tables, [], rel_data, 2
        )

    # Half of the source FK values are null and record_size_ratio=2, so we
    # expect 2 of the 4 synthesized customer_ids to be null.
    expected_events = pd.DataFrame(
        data={
            "total": [55, 56, 57, 58],
            "id": [0, 1, 2, 3],
            "customer_id": [None, None, 0, 1],
        }
    )
    pdtest.assert_frame_equal(processed["events"], expected_events)


def test_post_processing_null_composite_foreign_key(tmpdir):
    """Composite FK values are synthesized without nulls, even if the source had them."""
    rel_data = RelationalData(directory=tmpdir)

    customers_df = pd.DataFrame(
        data={
            "id": [1, 2],
            "first": ["Albert", "Betsy"],
            "last": ["Anderson", "Bond"],
        }
    )
    events_df = pd.DataFrame(
        data={
            "id": [1, 2, 3, 4, 5],
            "customer_first": ["Albert", "Betsy", None, "Betsy", None],
            "customer_last": ["Anderson", "Bond", None, None, "Bond"],
            "total": [42, 43, 44, 45, 46],
        }
    )
    rel_data.add_table(name="customers", primary_key="id", data=customers_df)
    rel_data.add_table(name="events", primary_key="id", data=events_df)
    rel_data.add_foreign_key_constraint(
        table="events",
        constrained_columns=["customer_first", "customer_last"],
        referred_table="customers",
        referred_columns=["first", "last"],
    )

    strategy = IndependentStrategy()

    synthetic_tables = {
        "events": pd.DataFrame(data={"total": [55, 56, 57, 58, 59]}),
        "customers": pd.DataFrame(
            data={
                "first": ["Herbert", "Isabella", "Jack", "Kevin", "Louise"],
                "last": ["Hoover", "Irvin", "Johnson", "Knight", "Lane"],
            }
        ),
    }

    # Patch shuffle for deterministic testing.
    with patch("random.shuffle", wraps=sorted):
        processed = strategy.post_process_synthetic_results(
            synthetic_tables, [], rel_data, 1
        )

    # We do not create composite foreign key values with nulls,
    # even though some existed in the source data.
    expected_events = pd.DataFrame(
        data={
            "total": [55, 56, 57, 58, 59],
            "id": [0, 1, 2, 3, 4],
            "customer_first": ["Herbert", "Isabella", "Jack", "Kevin", "Louise"],
            "customer_last": ["Hoover", "Irvin", "Johnson", "Knight", "Lane"],
        }
    )
    pdtest.assert_frame_equal(processed["events"], expected_events)

0 comments on commit 3d11395

Please sign in to comment.