Skip to content

Commit

Permalink
Materialize train test (#472)
Browse files Browse the repository at this point in the history
Closes #470.

Add a `col_stats` argument to the dataset `materialize` function, allowing the
column statistics calculated from the train dataset to be reused when
materializing the test data.

```python
train_df = ...
test_df = ...
train_dataset = Dataset(train_df, 
                 col_to_stype=col_to_stype, 
                 target_col="target")
train_dataset.materialize()
test_dataset = Dataset(test_df, 
                 col_to_stype=col_to_stype, 
                 target_col="target")
test_dataset.materialize(col_stats=train_dataset.col_stats)
```

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
HoustonJ2013 and pre-commit-ci[bot] authored Dec 27, 2024
1 parent ecff54e commit 896919e
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 17 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added support for materializing dataset for train and test dataframe separately ([#470](https://github.com/pyg-team/pytorch-frame/issues/470))
- Added support for PyTorch 2.5 ([#464](https://github.com/pyg-team/pytorch-frame/pull/464))
- Added a benchmark script to compare PyTorch Frame with PyTorch Tabular ([#398](https://github.com/pyg-team/pytorch-frame/pull/398), [#444](https://github.com/pyg-team/pytorch-frame/pull/444))
- Added `is_floating_point` method to `MultiNestedTensor` and `MultiEmbeddingTensor` ([#445](https://github.com/pyg-team/pytorch-frame/pull/445))
Expand Down
41 changes: 41 additions & 0 deletions test/data/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,3 +280,44 @@ def test_col_to_pattern_raise_error():
dataset = FakeDataset(num_rows=10, stypes=[torch_frame.timestamp])
Dataset(dataset.df, dataset.col_to_stype, dataset.target_col,
col_to_time_format=2)


def test_materialization_with_col_stats(tmpdir):
    """Materializing with externally supplied ``col_stats`` reuses the
    train-split statistics verbatim on the test split."""
    image_dir = str(tmpdir.mkdir("image"))
    # Both splits share the same embedder configuration and schema;
    # only the number of rows differs.
    embedder_cfgs = dict(
        col_to_text_embedder_cfg=TextEmbedderConfig(
            text_embedder=HashTextEmbedder(1),
            batch_size=8,
        ),
        col_to_image_embedder_cfg=ImageEmbedderConfig(
            image_embedder=RandomImageEmbedder(1),
            batch_size=8,
        ),
    )
    stypes = [
        torch_frame.categorical,
        torch_frame.numerical,
        torch_frame.multicategorical,
        torch_frame.sequence_numerical,
        torch_frame.timestamp,
        torch_frame.text_embedded,
        torch_frame.embedding,
        torch_frame.image_embedded,
    ]

    def make_dataset(num_rows):
        # One-line factory so train/test construction stays in lockstep.
        return FakeDataset(
            num_rows=num_rows,
            stypes=stypes,
            tmp_path=image_dir,
            **embedder_cfgs,
        )

    train_dataset = make_dataset(10)
    train_dataset.materialize()  # computes col_stats from the train frame
    test_dataset = make_dataset(5)
    # Reuse the statistics computed on the train split for the test split:
    test_dataset.materialize(col_stats=train_dataset.col_stats)

    assert train_dataset.col_stats == test_dataset.col_stats, \
        "col_stats should be the same for train and test datasets"
53 changes: 36 additions & 17 deletions torch_frame/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,6 +554,7 @@ def materialize(
self,
device: torch.device | None = None,
path: str | None = None,
col_stats: dict[str, dict[StatType, Any]] | None = None,
) -> Dataset:
r"""Materializes the dataset into a tensor representation. From this
point onwards, the dataset should be treated as read-only.
Expand All @@ -570,6 +571,10 @@ def materialize(
:obj:`path`. If :obj:`path` is :obj:`None`, this will
materialize the dataset without caching.
(default: :obj:`None`)
col_stats (Dict[str, Dict[StatType, Any]], optional): optional
col_stats provided by the user. If not provided, the statistics
is calculated from the dataframe itself. (default: :obj:`None`)
"""
if self.is_materialized:
# Materialized without specifying path at first and materialize
Expand All @@ -589,23 +594,37 @@ def materialize(
return self

# 1. Fill column statistics:
for col, stype in self.col_to_stype.items():
ser = self.df[col]
self._col_stats[col] = compute_col_stats(
ser,
stype,
sep=self.col_to_sep.get(col, None),
time_format=self.col_to_time_format.get(col, None),
)
# For a target column, sort categories lexicographically such that
# we do not accidentally swap labels in binary classification
# tasks.
if col == self.target_col and stype == torch_frame.categorical:
index, value = self._col_stats[col][StatType.COUNT]
if len(index) == 2:
ser = pd.Series(index=index, data=value).sort_index()
index, value = ser.index.tolist(), ser.values.tolist()
self._col_stats[col][StatType.COUNT] = (index, value)
if col_stats is None:
# calculate from data if col_stats is not provided
for col, stype in self.col_to_stype.items():
ser = self.df[col]
self._col_stats[col] = compute_col_stats(
ser,
stype,
sep=self.col_to_sep.get(col, None),
time_format=self.col_to_time_format.get(col, None),
)
# For a target column, sort categories lexicographically
# such that we do not accidentally swap labels in binary
# classification tasks.
if col == self.target_col and stype == torch_frame.categorical:
index, value = self._col_stats[col][StatType.COUNT]
if len(index) == 2:
ser = pd.Series(index=index, data=value).sort_index()
index, value = ser.index.tolist(), ser.values.tolist()
self._col_stats[col][StatType.COUNT] = (index, value)
else:
# basic validation for the col_stats provided by the user
for col_, stype_ in self.col_to_stype.items():
assert col_ in col_stats, \
f"{col_} is not specified in the provided col_stats"
stats_ = col_stats[col_]
assert all([key_ in stats_
for key_ in StatType.stats_for_stype(stype_)]), \
"not all required stats are calculated" \
f" in the provided col_stats for {col}"

self._col_stats = col_stats

# 2. Create the `TensorFrame`:
self._to_tensor_frame_converter = self._get_tensorframe_converter()
Expand Down

0 comments on commit 896919e

Please sign in to comment.