diff --git a/CHANGELOG.md b/CHANGELOG.md index 168bab1e1..c249fe297 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added a benchmark script to compare PyTorch Frame with PyTorch Tabular ([#398](https://github.com/pyg-team/pytorch-frame/pull/398), [#444](https://github.com/pyg-team/pytorch-frame/pull/444)) - Added `is_floating_point` method to `MultiNestedTensor` and `MultiEmbeddingTensor` ([#445](https://github.com/pyg-team/pytorch-frame/pull/445)) - Added support for inferring `stype.categorical` from boolean columns in `utils.infer_series_stype` ([#421](https://github.com/pyg-team/pytorch-frame/pull/421)) +- Added `pin_memory()` to `TensorFrame`, `MultiEmbeddingTensor`, and `MultiNestedTensor` ([#437](https://github.com/pyg-team/pytorch-frame/pull/437)) ### Changed diff --git a/test/data/test_multi_embedding_tensor.py b/test/data/test_multi_embedding_tensor.py index e5c29f968..01cb1aebe 100644 --- a/test/data/test_multi_embedding_tensor.py +++ b/test/data/test_multi_embedding_tensor.py @@ -5,7 +5,7 @@ import torch from torch_frame.data.multi_embedding_tensor import MultiEmbeddingTensor -from torch_frame.testing import withCUDA +from torch_frame.testing import onlyCUDA, withCUDA def assert_equal( @@ -487,3 +487,18 @@ def test_cat(device): # case: list of non-MultiEmbeddingTensor should raise error with pytest.raises(AssertionError): MultiEmbeddingTensor.cat([object()], dim=0) + + +@onlyCUDA +def test_pin_memory(): + met, _ = get_fake_multi_embedding_tensor( + num_rows=2, + num_cols=3, + ) + assert not met.is_pinned() + assert not met.values.is_pinned() + assert not met.offset.is_pinned() + met = met.pin_memory() + assert met.is_pinned() + assert met.values.is_pinned() + assert met.offset.is_pinned() diff --git a/test/data/test_multi_nested_tensor.py b/test/data/test_multi_nested_tensor.py index 29bd8bdea..8ed084f77 100644 --- a/test/data/test_multi_nested_tensor.py +++ b/test/data/test_multi_nested_tensor.py @@ -6,7 +6,7 @@ from torch import Tensor from torch_frame.data import MultiNestedTensor -from torch_frame.testing import withCUDA +from torch_frame.testing import onlyCUDA def assert_equal(tensor_mat: list[list[Tensor]], @@ -95,8 +95,8 @@ def test_fillna_col(): torch.tensor([100], dtype=torch.float32))) -@withCUDA -def test_multi_nested_tensor_basics(device): +@onlyCUDA +def test_basics(device): num_rows = 8 num_cols = 10 max_value = 100 @@ -326,7 +326,7 @@ def test_multi_nested_tensor_basics(device): cloned_multi_nested_tensor) -def test_multi_nested_tensor_different_num_rows(): +def test_different_num_rows(): tensor_mat = [ [torch.tensor([1, 2, 3]), torch.tensor([4, 5])], @@ -340,3 +340,20 @@ def test_multi_nested_tensor_different_num_rows(): match="The length of each row must be the same", ): MultiNestedTensor.from_tensor_mat(tensor_mat) + + +@onlyCUDA +def test_pin_memory(): + num_rows = 10 + num_cols = 3 + tensor = MultiNestedTensor.from_tensor_mat( + [[torch.randn(random.randint(0, 10)) for _ in range(num_cols)] + for _ in range(num_rows)]) + + assert not tensor.is_pinned() + assert not tensor.values.is_pinned() + assert not tensor.offset.is_pinned() + tensor = tensor.pin_memory() + assert tensor.is_pinned() + assert tensor.values.is_pinned() + assert tensor.offset.is_pinned() diff --git a/test/data/test_tensor_frame.py b/test/data/test_tensor_frame.py index a0645b183..bf6355f3f 100644 --- a/test/data/test_tensor_frame.py +++ b/test/data/test_tensor_frame.py @@ -7,6 +7,7 @@ from torch_frame import TensorFrame from torch_frame.data.multi_embedding_tensor import MultiEmbeddingTensor from torch_frame.data.multi_nested_tensor import MultiNestedTensor +from torch_frame.testing import onlyCUDA def test_tensor_frame_basics(get_fake_tensor_frame): @@ -253,3 +254,19 @@ def test_non_list_col_names_dict(): col_names_dict = {torch_frame.categorical: 'cat_1'} with pytest.raises(ValueError, match='must be a list of column names'): TensorFrame(feat_dict, col_names_dict) + + +@onlyCUDA +def test_pin_memory(get_fake_tensor_frame): + def assert_is_pinned(tf: TensorFrame, expected: bool) -> bool: + for value in tf.feat_dict.values(): + if isinstance(value, dict): + for v in value.values(): + assert v.is_pinned() is expected + else: + assert value.is_pinned() is expected + + tf = get_fake_tensor_frame(10) + assert_is_pinned(tf, expected=False) + tf = tf.pin_memory() + assert_is_pinned(tf, expected=True) diff --git a/torch_frame/data/multi_tensor.py b/torch_frame/data/multi_tensor.py index b6c201220..e812a8af8 100644 --- a/torch_frame/data/multi_tensor.py +++ b/torch_frame/data/multi_tensor.py @@ -97,6 +97,12 @@ def cpu(self, *args, **kwargs): def cuda(self, *args, **kwargs): return self._apply(lambda x: x.cuda(*args, **kwargs)) + def pin_memory(self, *args, **kwargs): + return self._apply(lambda x: x.pin_memory(*args, **kwargs)) + + def is_pinned(self) -> bool: + return self.values.is_pinned() and self.offset.is_pinned() + # Helper Functions ######################################################## def _apply(self, fn: Callable[[Tensor], Tensor]) -> _MultiTensor: diff --git a/torch_frame/data/tensor_frame.py b/torch_frame/data/tensor_frame.py index e88d8137e..00b026705 100644 --- a/torch_frame/data/tensor_frame.py +++ b/torch_frame/data/tensor_frame.py @@ -356,6 +356,17 @@ def fn(x): return self._apply(fn) + def pin_memory(self, *args, **kwargs): + def fn(x): + if isinstance(x, dict): + for key in x: + x[key] = x[key].pin_memory(*args, **kwargs) + else: + x = x.pin_memory(*args, **kwargs) + return x + + return self._apply(fn) + # Helper Functions ######################################################## def _apply(self, fn: Callable[[TensorData], TensorData]) -> TensorFrame: diff --git a/torch_frame/testing/__init__.py b/torch_frame/testing/__init__.py index 0003db38a..03d0f1a46 100644 --- a/torch_frame/testing/__init__.py +++ b/torch_frame/testing/__init__.py @@ -3,10 +3,12 @@ has_package, withPackage, withCUDA, + onlyCUDA, ) __all__ = [ 'has_package', 'withPackage', 'withCUDA', + 'onlyCUDA', ] diff --git a/torch_frame/testing/decorators.py b/torch_frame/testing/decorators.py index 7c4f945b7..7258eb521 100644 --- a/torch_frame/testing/decorators.py +++ b/torch_frame/testing/decorators.py @@ -48,3 +48,12 @@ def withCUDA(func: Callable): devices.append(pytest.param(torch.device('cuda:0'), id='cuda:0')) return pytest.mark.parametrize('device', devices)(func) + + +def onlyCUDA(func: Callable) -> Callable: + r"""A decorator to skip tests if CUDA is not found.""" + import pytest + return pytest.mark.skipif( + not torch.cuda.is_available(), + reason="CUDA not available", + )(func)