From e278402e0c23d01e76778023ec0a6f9b9b8b8475 Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Wed, 14 Aug 2024 03:39:13 +0900 Subject: [PATCH] Don't create the same tensor every iteration in N/A handling (#434) See title. --- CHANGELOG.md | 2 +- torch_frame/nn/encoder/stype_encoder.py | 88 ++++++++++++------------- 2 files changed, 43 insertions(+), 47 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a903868e..6e9c335f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,7 +19,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Fixed - Removed CUDA synchronizations from `nn.LinearEmbeddingEncoder` ([#432](https://github.com/pyg-team/pytorch-frame/pull/432)) -- Removed CUDA synchronizations from N/A imputation logic in `nn.StypeEncoder` ([#433](https://github.com/pyg-team/pytorch-frame/pull/433)) +- Removed CUDA synchronizations from N/A imputation logic in `nn.StypeEncoder` ([#433](https://github.com/pyg-team/pytorch-frame/pull/433), [#434](https://github.com/pyg-team/pytorch-frame/pull/434)) ## [0.2.3] - 2024-07-08 diff --git a/torch_frame/nn/encoder/stype_encoder.py b/torch_frame/nn/encoder/stype_encoder.py index 86f1d723..88a0e5be 100644 --- a/torch_frame/nn/encoder/stype_encoder.py +++ b/torch_frame/nn/encoder/stype_encoder.py @@ -109,6 +109,35 @@ def init_modules(self): f"can be used on {self.stype} columns, but " f"{self.na_strategy} is given.") + fill_values = [] + for col in range(len(self.stats_list)): + if self.na_strategy == NAStrategy.MOST_FREQUENT: + # Categorical index is sorted based on count, + # so 0-th index is always the most frequent. + fill_value = 0 + elif self.na_strategy == NAStrategy.MEAN: + fill_value = self.stats_list[col][StatType.MEAN] + elif self.na_strategy == NAStrategy.ZEROS: + fill_value = 0 + elif self.na_strategy == NAStrategy.NEWEST_TIMESTAMP: + fill_value = self.stats_list[col][StatType.NEWEST_TIME] + elif self.na_strategy == NAStrategy.OLDEST_TIMESTAMP: + fill_value = self.stats_list[col][StatType.OLDEST_TIME] + elif self.na_strategy == NAStrategy.MEDIAN_TIMESTAMP: + fill_value = self.stats_list[col][StatType.MEDIAN_TIME] + else: + raise ValueError( + f"Unsupported NA strategy {self.na_strategy}") + fill_values.append(fill_value) + + if (isinstance(fill_values[0], Tensor) + and fill_values[0].size(0) > 1): + fill_values = torch.stack(fill_values) + else: + fill_values = torch.tensor(fill_values) + + self.register_buffer("fill_values", fill_values) + @abstractmethod def reset_parameters(self): r"""Initialize the parameters of `post_module`.""" @@ -190,66 +219,33 @@ def na_forward(self, feat: TensorData) -> TensorData: if isinstance(feat, Tensor): # cache for future use na_mask = get_na_mask(feat) - if na_mask.any(): - feat = feat.clone() - else: - return feat + feat = feat.clone() elif isinstance(feat, MultiEmbeddingTensor): - if get_na_mask(feat.values).any(): - feat = MultiEmbeddingTensor(num_rows=feat.num_rows, - num_cols=feat.num_cols, - values=feat.values.clone(), - offset=feat.offset) - else: - return feat + feat = MultiEmbeddingTensor(num_rows=feat.num_rows, + num_cols=feat.num_cols, + values=feat.values.clone(), + offset=feat.offset) elif isinstance(feat, MultiNestedTensor): - if get_na_mask(feat.values).any(): - feat = MultiNestedTensor(num_rows=feat.num_rows, - num_cols=feat.num_cols, - values=feat.values.clone(), - offset=feat.offset) - else: - return feat + feat = MultiNestedTensor(num_rows=feat.num_rows, + num_cols=feat.num_cols, + values=feat.values.clone(), + offset=feat.offset) else: raise ValueError(f"Unrecognized type {type(feat)} in na_forward.") - fill_values = [] - for col in range(feat.size(1)): - if self.na_strategy == NAStrategy.MOST_FREQUENT: - # Categorical index is sorted based on count, - # so 0-th index is always the most frequent. - fill_value = 0 - elif self.na_strategy == NAStrategy.MEAN: - fill_value = self.stats_list[col][StatType.MEAN] - elif self.na_strategy == NAStrategy.ZEROS: - fill_value = 0 - elif self.na_strategy == NAStrategy.NEWEST_TIMESTAMP: - fill_value = self.stats_list[col][StatType.NEWEST_TIME].to( - feat.device) - elif self.na_strategy == NAStrategy.OLDEST_TIMESTAMP: - fill_value = self.stats_list[col][StatType.OLDEST_TIME].to( - feat.device) - elif self.na_strategy == NAStrategy.MEDIAN_TIMESTAMP: - fill_value = self.stats_list[col][StatType.MEDIAN_TIME].to( - feat.device) - else: - raise ValueError(f"Unsupported NA strategy {self.na_strategy}") - fill_values.append(fill_value) - if isinstance(feat, _MultiTensor): - for col, fill_value in enumerate(fill_values): + for col, fill_value in enumerate(self.fill_values): feat.fillna_col(col, fill_value) else: if na_mask.ndim == 3: # when feat is 3D, it is faster to iterate over columns - for col, fill_value in enumerate(fill_values): + for col, fill_value in enumerate(self.fill_values): col_data = feat[:, col] col_na_mask = na_mask[:, col].any(dim=-1) col_data[col_na_mask] = fill_value else: # na_mask.ndim == 2 - fill_values = torch.tensor(fill_values, device=feat.device) - assert feat.size(-1) == fill_values.size(-1) - feat = torch.where(na_mask, fill_values, feat) + assert feat.size(-1) == self.fill_values.size(-1) + feat = torch.where(na_mask, self.fill_values, feat) return feat