diff --git a/CHANGELOG.md b/CHANGELOG.md
index 2df44aca..89f8bbff 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -9,6 +9,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Added
 
 ### Changed
+- Removed implicit clones in `StypeEncoder` ([#286](https://github.com/pyg-team/pytorch-frame/pull/286))
 
 ### Deprecated
diff --git a/test/nn/encoder/test_stype_encoder.py b/test/nn/encoder/test_stype_encoder.py
index ff4cff50..5038d67f 100644
--- a/test/nn/encoder/test_stype_encoder.py
+++ b/test/nn/encoder/test_stype_encoder.py
@@ -1,3 +1,5 @@
+import copy
+
 import pytest
 import torch
 from torch.nn import ReLU
@@ -5,7 +7,6 @@
 import torch_frame
 from torch_frame import NAStrategy, stype
 from torch_frame.config import ModelConfig
-from torch_frame.config.text_embedder import TextEmbedderConfig
 from torch_frame.config.text_tokenizer import TextTokenizerConfig
 from torch_frame.data.dataset import Dataset
 from torch_frame.data.stats import StatType
@@ -22,7 +23,6 @@
     StackEncoder,
     TimestampEncoder,
 )
-from torch_frame.testing.text_embedder import HashTextEmbedder
 from torch_frame.testing.text_tokenizer import (
     RandomTextModel,
     WhiteSpaceHashTokenizer,
@@ -44,10 +44,12 @@ def test_categorical_feature_encoder(encoder_cls_kwargs):
         stype=stype.categorical,
         **encoder_cls_kwargs[1],
     )
-    feat_cat = tensor_frame.feat_dict[stype.categorical]
+    feat_cat = tensor_frame.feat_dict[stype.categorical].clone()
     col_names = tensor_frame.col_names_dict[stype.categorical]
     x = encoder(feat_cat, col_names)
     assert x.shape == (feat_cat.size(0), feat_cat.size(1), 8)
+    # Make sure no in-place modification
+    assert torch.allclose(feat_cat, tensor_frame.feat_dict[stype.categorical])
 
     # Perturb the first column
     num_categories = len(stats_list[0][StatType.COUNT])
@@ -96,10 +98,12 @@ def test_numerical_feature_encoder(encoder_cls_kwargs):
         stype=stype.numerical,
         **encoder_cls_kwargs[1],
     )
-    feat_num = tensor_frame.feat_dict[stype.numerical]
+    feat_num = tensor_frame.feat_dict[stype.numerical].clone()
     col_names = tensor_frame.col_names_dict[stype.numerical]
     x = encoder(feat_num, col_names)
     assert x.shape == (feat_num.size(0), feat_num.size(1), 8)
+    # Make sure no in-place modification
+    assert torch.allclose(feat_num, tensor_frame.feat_dict[stype.numerical])
     if "post_module" in encoder_cls_kwargs[1]:
         assert encoder.post_module is not None
     else:
@@ -142,9 +146,16 @@ def test_multicategorical_feature_encoder(encoder_cls_kwargs):
         stype=stype.multicategorical,
         **encoder_cls_kwargs[1],
     )
-    feat_multicat = tensor_frame.feat_dict[stype.multicategorical]
+    feat_multicat = tensor_frame.feat_dict[stype.multicategorical].clone()
     col_names = tensor_frame.col_names_dict[stype.multicategorical]
     x = encoder(feat_multicat, col_names)
+    # Make sure no in-place modification
+    assert torch.allclose(
+        feat_multicat.values,
+        tensor_frame.feat_dict[stype.multicategorical].values)
+    assert torch.allclose(
+        feat_multicat.offset,
+        tensor_frame.feat_dict[stype.multicategorical].offset)
     assert x.shape == (feat_multicat.size(0), feat_multicat.size(1), 8)
 
     # Perturb the first column
@@ -178,9 +189,12 @@ def test_timestamp_feature_encoder(encoder_cls_kwargs):
         stype=stype.timestamp,
         **encoder_cls_kwargs[1],
    )
-    feat_timestamp = tensor_frame.feat_dict[stype.timestamp]
+    feat_timestamp = tensor_frame.feat_dict[stype.timestamp].clone()
     col_names = tensor_frame.col_names_dict[stype.timestamp]
     x = encoder(feat_timestamp, col_names)
+    # Make sure no in-place modification
+    assert torch.allclose(feat_timestamp,
+                          tensor_frame.feat_dict[stype.timestamp])
     assert x.shape == (feat_timestamp.size(0), feat_timestamp.size(1), 8)
 
 
@@ -324,40 +338,6 @@ def test_timestamp_feature_encoder_with_nan(encoder_cls_kwargs):
     assert (~torch.isnan(x)).all()
 
 
-def test_text_embedded_encoder():
-    num_rows = 20
-    text_emb_channels = 10
-    out_channels = 5
-    dataset = FakeDataset(
-        num_rows=num_rows,
-        stypes=[
-            torch_frame.text_embedded,
-        ],
-        col_to_text_embedder_cfg=TextEmbedderConfig(
-            text_embedder=HashTextEmbedder(text_emb_channels),
-            batch_size=None),
-    )
-    dataset.materialize()
-    tensor_frame = dataset.tensor_frame
-    stats_list = [
-        dataset.col_stats[col_name]
-        for col_name in tensor_frame.col_names_dict[stype.embedding]
-    ]
-    encoder = LinearEmbeddingEncoder(
-        out_channels=out_channels,
-        stats_list=stats_list,
-        stype=stype.embedding,
-    )
-    feat_text = tensor_frame.feat_dict[stype.embedding]
-    col_names = tensor_frame.col_names_dict[stype.embedding]
-    feat = encoder(feat_text, col_names)
-    assert feat.shape == (
-        num_rows,
-        len(tensor_frame.col_names_dict[stype.embedding]),
-        out_channels,
-    )
-
-
 def test_embedding_encoder():
     num_rows = 20
     out_channels = 5
@@ -378,9 +358,14 @@ def test_embedding_encoder():
         stats_list=stats_list,
         stype=stype.embedding,
     )
-    feat_text = tensor_frame.feat_dict[stype.embedding]
+    feat_emb = tensor_frame.feat_dict[stype.embedding].clone()
     col_names = tensor_frame.col_names_dict[stype.embedding]
-    x = encoder(feat_text, col_names)
+    x = encoder(feat_emb, col_names)
+    # Make sure no in-place modification
+    assert torch.allclose(feat_emb.values,
+                          tensor_frame.feat_dict[stype.embedding].values)
+    assert torch.allclose(feat_emb.offset,
+                          tensor_frame.feat_dict[stype.embedding].offset)
     assert x.shape == (
         num_rows,
         len(tensor_frame.col_names_dict[stype.embedding]),
@@ -421,7 +406,7 @@ def test_text_tokenized_encoder():
         stype=stype.text_tokenized,
         col_to_model_cfg=col_to_model_cfg,
     )
-    feat_text = tensor_frame.feat_dict[stype.text_tokenized]
+    feat_text = copy.deepcopy(tensor_frame.feat_dict[stype.text_tokenized])
     col_names = tensor_frame.col_names_dict[stype.text_tokenized]
     x = encoder(feat_text, col_names)
     assert x.shape == (
@@ -429,3 +414,15 @@
         len(tensor_frame.col_names_dict[stype.text_tokenized]),
         out_channels,
     )
+    # Make sure no in-place modification
+    assert isinstance(feat_text, dict) and isinstance(
+        tensor_frame.feat_dict[stype.text_tokenized], dict)
+    assert feat_text.keys() == tensor_frame.feat_dict[
+        stype.text_tokenized].keys()
+    for key in feat_text.keys():
+        assert torch.allclose(
+            feat_text[key].values,
+            tensor_frame.feat_dict[stype.text_tokenized][key].values)
+        assert torch.allclose(
+            feat_text[key].offset,
+            tensor_frame.feat_dict[stype.text_tokenized][key].offset)
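The tests above repeat the same clone, encode, compare pattern once per stype. A minimal sketch of that pattern as a standalone helper follows; the helper name is hypothetical (not part of this PR) and it covers plain `Tensor` inputs only, since the multi-tensor stypes compare `.values` and `.offset` separately, as the tests do:

```python
import torch


def assert_encoder_preserves_input(encoder, feat: torch.Tensor,
                                   col_names) -> None:
    # Hypothetical helper: snapshot the input, run the encoder, then verify
    # the caller-visible tensor is unchanged. `equal_nan=True` lets NaN
    # missing-value markers compare equal to themselves.
    snapshot = feat.clone()
    encoder(feat, col_names)
    assert torch.allclose(snapshot, feat, equal_nan=True)
```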
diff --git a/test/nn/models/test_compile.py b/test/nn/models/test_compile.py
index d14754ae..ec53c0d7 100644
--- a/test/nn/models/test_compile.py
+++ b/test/nn/models/test_compile.py
@@ -34,7 +34,7 @@
             gamma=0.1,
         ),
         None,
-        4,
+        7,
         id="TabNet",
     ),
     pytest.param(
@@ -54,7 +54,7 @@
         Trompt,
         dict(channels=8, num_prompts=2),
         None,
-        11,
+        16,
         id="Trompt",
     ),
     pytest.param(
diff --git a/torch_frame/nn/encoder/stype_encoder.py b/torch_frame/nn/encoder/stype_encoder.py
index 65cf046f..d896340b 100644
--- a/torch_frame/nn/encoder/stype_encoder.py
+++ b/torch_frame/nn/encoder/stype_encoder.py
@@ -37,6 +37,19 @@ def reset_parameters_soft(module: Module):
         module.reset_parameters()
 
 
+def get_na_mask(tensor: Tensor) -> Tensor:
+    r"""Obtains the NA mask of the input :obj:`Tensor`.
+
+    Args:
+        tensor (Tensor): Input :obj:`Tensor`.
+    """
+    if tensor.is_floating_point():
+        na_mask = torch.isnan(tensor)
+    else:
+        na_mask = tensor == -1
+    return na_mask
+
+
 class StypeEncoder(Module, ABC):
     r"""Base class for stype encoder. This module transforms tensor of a
     specific stype, i.e., `TensorFrame.feat_dict[stype.xxx]` into 3-dimensional
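To make the helper's two missing-value conventions concrete, a quick illustrative check (assuming `get_na_mask` remains importable from its defining module, as added above):

```python
import torch

from torch_frame.nn.encoder.stype_encoder import get_na_mask

# Floating-point tensors mark missing values with NaN ...
assert get_na_mask(torch.tensor([1.0, float("nan")])).tolist() == [False, True]
# ... while integer (e.g., categorical) tensors use the sentinel value -1.
assert get_na_mask(torch.tensor([3, -1])).tolist() == [False, True]
```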
@@ -121,11 +134,6 @@ def forward(
             f"The number of columns in feat and the length of "
             f"col_names must match (got {num_cols} and "
             f"{len(col_names)}, respectively.)")
-        # Clone the tensor to avoid in-place modification
-        if not isinstance(feat, dict):
-            feat = feat.clone()
-        else:
-            feat = {key: value.clone() for key, value in feat.items()}
         # NaN handling of the input Tensor
         feat = self.na_forward(feat)
         # Main encoding into column embeddings
@@ -174,20 +182,36 @@ def na_forward(self, feat: TensorData) -> TensorData:
         """
         if self.na_strategy is None:
             return feat
-        for col in range(feat.size(1)):
-            column_data = feat[:, col]
-            if isinstance(feat, _MultiTensor):
-                column_data = column_data.values
-            if column_data.is_floating_point():
-                nan_mask = torch.isnan(column_data)
+
+        # Since we are not changing the number of items in each column, it's
+        # faster to just clone the values while reusing the same offset
+        # object.
+        if isinstance(feat, Tensor):
+            if get_na_mask(feat).any():
+                feat = feat.clone()
+            else:
+                return feat
+        elif isinstance(feat, MultiEmbeddingTensor):
+            if get_na_mask(feat.values).any():
+                feat = MultiEmbeddingTensor(num_rows=feat.num_rows,
+                                            num_cols=feat.num_cols,
+                                            values=feat.values.clone(),
+                                            offset=feat.offset)
+            else:
+                return feat
+        elif isinstance(feat, MultiNestedTensor):
+            if get_na_mask(feat.values).any():
+                feat = MultiNestedTensor(num_rows=feat.num_rows,
+                                         num_cols=feat.num_cols,
+                                         values=feat.values.clone(),
+                                         offset=feat.offset)
             else:
-                nan_mask = column_data == -1
-            if nan_mask.ndim == 2:
-                nan_mask = nan_mask.any(dim=-1)
-            assert nan_mask.ndim == 1
-            assert len(nan_mask) == len(column_data)
-            if not nan_mask.any():
-                continue
+                return feat
+        else:
+            raise ValueError(f"Unrecognized type {type(feat)} in na_forward.")
+
+        # TODO: Remove for-loop over columns
+        for col in range(feat.size(1)):
             if self.na_strategy == NAStrategy.MOST_FREQUENT:
                 # Categorical index is sorted based on count,
                 # so 0-th index is always the most frequent.
@@ -210,7 +234,13 @@ def na_forward(self, feat: TensorData) -> TensorData:
             if isinstance(feat, _MultiTensor):
                 feat.fillna_col(col, fill_value)
             else:
-                column_data[nan_mask] = fill_value
+                column_data = feat[:, col]
+                na_mask = get_na_mask(column_data)
+                if na_mask.ndim == 2:
+                    na_mask = na_mask.any(dim=-1)
+                assert na_mask.ndim == 1
+                assert len(na_mask) == len(column_data)
+                column_data[na_mask] = fill_value
             # Add better safeguard here to make sure nans are actually
             # replaced, especially when nans are represented as -1's. They are
             # very hard to catch as they won't error out.
@@ -339,11 +369,10 @@ def encode_forward(
         # Increment the index by one so that NaN index (-1) becomes 0
         # (padding_idx)
         # feat: [batch_size, num_cols]
-        feat.values = feat.values + 1
         xs = []
         for i, emb in enumerate(self.embs):
             col_feat = feat[:, i]
-            xs.append(emb(col_feat.values, col_feat.offset[:-1]))
+            xs.append(emb(col_feat.values + 1, col_feat.offset[:-1]))
         # [batch_size, num_cols, hidden_channels]
         x = torch.stack(xs, dim=1)
         return x
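The `encode_forward` hunk above is the subtle part of this PR: with the implicit clone gone, the old `feat.values = feat.values + 1` would rebind the `values` attribute of the `MultiNestedTensor` the caller still holds, shifting the stored indices on every forward pass. Computing `col_feat.values + 1` per column leaves the input untouched. A self-contained sketch of the hazard, using an illustrative stand-in container rather than the real `MultiNestedTensor`:

```python
import torch


class Holder:
    # Stand-in (not the real MultiNestedTensor) for a container whose
    # `values` attribute is shared between the caller and the module.
    def __init__(self, values: torch.Tensor) -> None:
        self.values = values


def encode_mutating(feat: Holder) -> torch.Tensor:
    # Old pattern: rebinding feat.values is visible to the caller, so
    # repeated calls keep shifting the stored indices.
    feat.values = feat.values + 1
    return feat.values


def encode_pure(feat: Holder) -> torch.Tensor:
    # New pattern: shift on the fly; the caller's container is untouched.
    return feat.values + 1


feat = Holder(torch.tensor([0, 1, -1]))
encode_mutating(feat)
assert feat.values.tolist() == [1, 2, 0]   # caller's data changed!

feat = Holder(torch.tensor([0, 1, -1]))
encode_pure(feat)
assert feat.values.tolist() == [0, 1, -1]  # caller's data preserved
```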