diff --git a/docs/source/modules/transforms.rst b/docs/source/modules/transforms.rst index 3ebd38f2..be4dc244 100644 --- a/docs/source/modules/transforms.rst +++ b/docs/source/modules/transforms.rst @@ -18,16 +18,17 @@ Let's look an example, where we apply `CatToNumTransform >> ['C_feature_0', 'C_feature_1', 'C_feature_2', 'C_feature_3', 'C_feature_4', 'C_feature_5', 'C_feature_6', 'C_feature_7'] - test_dataset = dataset.get_split('test') + test_dataset = dataset.get_split(TrainingStage.TEST) transform.fit(train_dataset.tensor_frame, dataset.col_stats) transformed_col_stats = transform.transformed_stats diff --git a/test/transforms/test_mutual_information_sort.py b/test/transforms/test_mutual_information_sort.py index 1a06e4d7..98cf0f2e 100644 --- a/test/transforms/test_mutual_information_sort.py +++ b/test/transforms/test_mutual_information_sort.py @@ -5,6 +5,7 @@ from torch_frame.data import Dataset from torch_frame.datasets.fake import FakeDataset from torch_frame.transforms import MutualInformationSort +from torch_frame.typing import TrainingStage @pytest.mark.parametrize('with_nan', [True, False]) @@ -19,7 +20,7 @@ def test_mutual_information_sort(with_nan): dataset.materialize() tensor_frame: TensorFrame = dataset.tensor_frame - train_dataset = dataset.get_split('train') + train_dataset = dataset.get_split(TrainingStage.TRAIN) transform = MutualInformationSort(task_type) transform.fit(train_dataset.tensor_frame, train_dataset.col_stats) out = transform(tensor_frame) diff --git a/test/utils/test_split.py b/test/utils/test_split.py index 889a0de3..764d1bf5 100644 --- a/test/utils/test_split.py +++ b/test/utils/test_split.py @@ -1,5 +1,7 @@ import numpy as np +from torch_frame.datasets import FakeDataset +from torch_frame.typing import TrainingStage from torch_frame.utils.split import SPLIT_TO_NUM, generate_random_split @@ -9,13 +11,26 @@ def test_generate_random_split(): val_ratio = 0.1 test_ratio = 0.1 - split = generate_random_split(num_data, seed=42, train_ratio=train_ratio, - val_ratio=val_ratio) - assert (split == SPLIT_TO_NUM['train']).sum() == int(num_data * - train_ratio) - assert (split == SPLIT_TO_NUM['val']).sum() == int(num_data * val_ratio) - assert (split == SPLIT_TO_NUM['test']).sum() == int(num_data * test_ratio) + split = generate_random_split(num_data, seed=42, + ratios=[train_ratio, val_ratio]) + assert (split == SPLIT_TO_NUM.get(TrainingStage.TRAIN)).sum() == int( + num_data * train_ratio) + assert (split == SPLIT_TO_NUM.get(TrainingStage.VAL)).sum() == int( + num_data * val_ratio) + assert (split == SPLIT_TO_NUM.get(TrainingStage.TEST)).sum() == int( + num_data * test_ratio) assert np.allclose( split, np.array([0, 1, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0]), ) + + +def test_split_e2e_basic(): + # TODO: Add several more test cases using @pytest.mark.parametrize + num_rows = 10 + dataset = FakeDataset(num_rows=num_rows).materialize() + dataset.random_split([0.5, 0.2]) + train_set, val_set, test_set = dataset.split() + train_set.num_rows, val_set.num_rows == (int(10 * 0.5), int(10 * 0.2)) + if test_set is not None: + test_set.num_rows == 10 - int(10 * 0.5) - int(10 * 0.2) diff --git a/torch_frame/data/dataset.py b/torch_frame/data/dataset.py index 04b2f80b..63ecb6a2 100644 --- a/torch_frame/data/dataset.py +++ b/torch_frame/data/dataset.py @@ -33,8 +33,9 @@ IndexSelectType, TaskType, TensorData, + TrainingStage, ) -from torch_frame.utils.split import SPLIT_TO_NUM +from torch_frame.utils.split import SPLIT_TO_NUM, generate_random_split COL_TO_PATTERN_STYPE_MAPPING = { "col_to_sep": torch_frame.multicategorical, @@ -695,7 +696,7 @@ def col_select(self, cols: ColumnSelectType) -> Dataset: return dataset - def get_split(self, split: str) -> Dataset: + def get_split(self, split: TrainingStage) -> Dataset: r"""Returns a subset of the dataset that belongs to a given training split (as defined in :obj:`split_col`). @@ -707,20 +708,25 @@ def get_split(self, split: str) -> Dataset: raise ValueError( f"'get_split' is not supported for '{self}' since 'split_col' " f"is not specified.") - if split not in ["train", "val", "test"]: - raise ValueError(f"The split named '{split}' is not available. " - f"Needs to be either 'train', 'val', or 'test'.") indices = self.df.index[self.df[self.split_col] == SPLIT_TO_NUM[split]].tolist() return self[indices] - def split(self) -> tuple[Dataset, Dataset, Dataset]: - r"""Splits the dataset into training, validation and test splits.""" - return ( - self.get_split("train"), - self.get_split("val"), - self.get_split("test"), - ) + def split(self) -> tuple[Dataset, Dataset, Dataset | None]: + r"""Splits the dataset into training, validation and optionally + test splits. + """ + train_set = self.get_split(TrainingStage.TRAIN) + val_set = self.get_split(TrainingStage.VAL) + test_set = self.get_split(TrainingStage.TEST) + if test_set.num_rows == 0: + test_set = None + return train_set, val_set, test_set + + def random_split(self, ratios: list[float] | None = None): + split = generate_random_split(self.num_rows, ratios) + self.split_col = 'split' + self.df[self.split_col] = split @property @requires_post_materialization diff --git a/torch_frame/datasets/fake.py b/torch_frame/datasets/fake.py index b09e9425..c4510d00 100644 --- a/torch_frame/datasets/fake.py +++ b/torch_frame/datasets/fake.py @@ -11,7 +11,7 @@ from torch_frame import stype from torch_frame.config.text_embedder import TextEmbedderConfig from torch_frame.config.text_tokenizer import TextTokenizerConfig -from torch_frame.typing import TaskType +from torch_frame.typing import TaskType, TrainingStage from torch_frame.utils.split import SPLIT_TO_NUM TIME_FORMATS = ['%Y-%m-%d %H:%M:%S', '%Y-%m-%d', '%Y/%m/%d'] @@ -189,9 +189,9 @@ def __init__( if num_rows < 3: raise ValueError("Dataframe needs at least 3 rows to include" " each of train, val and test split.") - split = [SPLIT_TO_NUM['train']] * num_rows - split[1] = SPLIT_TO_NUM['val'] - split[2] = SPLIT_TO_NUM['test'] + split = [SPLIT_TO_NUM.get(TrainingStage.TRAIN)] * num_rows + split[1] = SPLIT_TO_NUM.get(TrainingStage.VAL) + split[2] = SPLIT_TO_NUM.get(TrainingStage.TEST) df['split'] = split super().__init__( diff --git a/torch_frame/datasets/huggingface_dataset.py b/torch_frame/datasets/huggingface_dataset.py index 27ba71ff..401eefcc 100644 --- a/torch_frame/datasets/huggingface_dataset.py +++ b/torch_frame/datasets/huggingface_dataset.py @@ -4,6 +4,7 @@ import torch_frame from torch_frame import stype +from torch_frame.typing import TrainingStage from torch_frame.utils.infer_stype import infer_df_stype from torch_frame.utils.split import SPLIT_TO_NUM @@ -105,13 +106,13 @@ def __init__( # Transform HF dataset split to `SPLIT_TO_NUM` accepted one: if "train" in split_name: - split_names.append("train") + split_names.append(TrainingStage.TRAIN) elif "val" in split_name: # Some datasets have val split name as `"validation"`, # here we transform it to `"val"`: - split_names.append("val") + split_names.append(TrainingStage.VAL) elif "test" in split_name: - split_names.append("test") + split_names.append(TrainingStage.TEST) else: raise ValueError(f"Invalid split name: '{split_name}'. " f"Expected one of the following PyTorch " diff --git a/torch_frame/datasets/mercari.py b/torch_frame/datasets/mercari.py index 5a70448b..becd0b68 100644 --- a/torch_frame/datasets/mercari.py +++ b/torch_frame/datasets/mercari.py @@ -6,6 +6,7 @@ import torch_frame from torch_frame.config.text_embedder import TextEmbedderConfig +from torch_frame.typing import TrainingStage from torch_frame.utils.split import SPLIT_TO_NUM SPLIT_COL = 'split_col' @@ -64,8 +65,8 @@ def __init__( test_path = osp.join(self.base_url, 'test_stg2.csv') self.download_url(test_path, root) df_test = pd.read_csv(test_path) - df_train[SPLIT_COL] = SPLIT_TO_NUM['train'] - df_test[SPLIT_COL] = SPLIT_TO_NUM['test'] + df_train[SPLIT_COL] = SPLIT_TO_NUM[TrainingStage.TRAIN] + df_test[SPLIT_COL] = SPLIT_TO_NUM[TrainingStage.TEST] df = pd.concat([df_train, df_test], axis=0, ignore_index=True) if num_rows is not None: df = df.head(num_rows) diff --git a/torch_frame/typing.py b/torch_frame/typing.py index 8a7b5dba..f717af1e 100644 --- a/torch_frame/typing.py +++ b/torch_frame/typing.py @@ -28,6 +28,12 @@ def supports_task_type(self, task_type: 'TaskType') -> bool: return self in task_type.supported_metrics +class TrainingStage(Enum): + TRAIN = 'train' + VAL = 'val' + TEST = 'test' + + class TaskType(Enum): r"""The type of the task. diff --git a/torch_frame/utils/split.py b/torch_frame/utils/split.py index d152854a..44951694 100644 --- a/torch_frame/utils/split.py +++ b/torch_frame/utils/split.py @@ -1,30 +1,85 @@ +import math + import numpy as np +from typing import List +from torch_frame.typing import TrainingStage + # Mapping split name to integer. -SPLIT_TO_NUM = {'train': 0, 'val': 1, 'test': 2} +SPLIT_TO_NUM = { + TrainingStage.TRAIN: 0, + TrainingStage.VAL: 1, + TrainingStage.TEST: 2 +} -def generate_random_split(length: int, seed: int, train_ratio: float = 0.8, - val_ratio: float = 0.1) -> np.ndarray: +def generate_random_split( + length: int, + ratios: List[float], + seed: int = 0, +) -> np.ndarray: r"""Generate a list of random split assignments of the specified length. The elements are either :obj:`0`, :obj:`1`, or :obj:`2`, representing train, val, test, respectively. Note that this function relies on the fact that numpy's shuffle is consistent across versions, which has been historically the case. + + Args: + length (int): The length of the dataset. + ratios (List[float]): Ratios for split assignment. When ratios + contains 2 variables, we will generate train/val/test set + respectively based on the split ratios (the 1st variable in + the list will be the ratio for train set, the 2nd will be + the ratio for val set and the remaining data will be used + for test set). When ratios contains 1 variable, we will only + generate train/val set. (the variable in) + seed (int, optional): The seed for the randomness generator. + + Returns: + A np.ndarra object representing the split. """ - assert train_ratio + val_ratio < 1 - assert train_ratio > 0 - assert val_ratio > 0 - train_num = int(length * train_ratio) - val_num = int(length * val_ratio) - test_num = length - train_num - val_num + validate_split_ratios(ratios) + ratios_length = len(ratios) + if length < ratios_length + 1: + raise ValueError( + f"We want to split data into {ratios_length + 1} disjoint set. " + f"However data contains {length} data point. Consider " + f"increase your data size.") + + # train_num = int(length * train_ratio) + # val_num = int(length * val_ratio) + # test_num = length - train_num - val_num + train_num = math.floor(length * ratios[0]) + val_num = math.floor( + length * ratios[1]) if ratios_length == 2 else length - train_num + test_num = None + if ratios_length == 2: + test_num = length - train_num - val_num arr = np.concatenate([ - np.full(train_num, SPLIT_TO_NUM['train']), - np.full(val_num, SPLIT_TO_NUM['val']), - np.full(test_num, SPLIT_TO_NUM['test']) + np.full(train_num, SPLIT_TO_NUM.get(TrainingStage.TRAIN)), + np.full(val_num, SPLIT_TO_NUM.get(TrainingStage.VAL)), ]) + + if ratios_length == 2: + arr = np.concatenate( + [arr, np.full(test_num, SPLIT_TO_NUM.get(TrainingStage.TEST))]) + np.random.seed(seed) np.random.shuffle(arr) return arr + + +def validate_split_ratios(ratio: List[float]): + if len(ratio) > 2: + raise ValueError("No more than three training splits is supported") + if len(ratio) < 1: + raise ValueError("At least two training splits are required") + + for val in ratio: + if val < 0: + raise ValueError("'ratio' can not contain negative values") + + if sum(ratio) - 1 > 1e-2: + raise ValueError("'ratio' exceeds more than 100% of the data")