[Split] random split method improvement #353
@@ -33,8 +33,9 @@
     IndexSelectType,
     TaskType,
     TensorData,
+    TrainingStage,
 )
-from torch_frame.utils.split import SPLIT_TO_NUM
+from torch_frame.utils.split import SPLIT_TO_NUM, generate_random_split

 COL_TO_PATTERN_STYPE_MAPPING = {
     "col_to_sep": torch_frame.multicategorical,
@@ -695,7 +696,7 @@ def col_select(self, cols: ColumnSelectType) -> Dataset:
         return dataset

-    def get_split(self, split: str) -> Dataset:
+    def get_split(self, split: TrainingStage) -> Dataset:

Review comments on this line:
"I think it's nicer to accept …"
"Maybe we can support both? i.e. …"
"Yes, we can support both."

         r"""Returns a subset of the dataset that belongs to a given training
         split (as defined in :obj:`split_col`).
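The "support both" idea from the thread above could look roughly like the sketch below. `_as_training_stage` is a hypothetical helper name (not part of this PR); it normalizes a plain string to the `TrainingStage` enum so that `SPLIT_TO_NUM` can always be indexed with the enum.

```python
from typing import Union

from torch_frame.typing import TrainingStage


def _as_training_stage(split: Union[str, TrainingStage]) -> TrainingStage:
    # Accept both the enum member and its string value ('train', 'val', 'test').
    if isinstance(split, TrainingStage):
        return split
    try:
        return TrainingStage(split)
    except ValueError:
        raise ValueError(f"The split named '{split}' is not available. "
                         f"Needs to be either 'train', 'val', or 'test'.") from None
```

`get_split` could then start with `split = _as_training_stage(split)` and keep the rest of its body unchanged.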
@@ -707,20 +708,25 @@ def get_split(self, split: str) -> Dataset:
             raise ValueError(
                 f"'get_split' is not supported for '{self}' since 'split_col' "
                 f"is not specified.")
-        if split not in ["train", "val", "test"]:
-            raise ValueError(f"The split named '{split}' is not available. "
-                             f"Needs to be either 'train', 'val', or 'test'.")
         indices = self.df.index[self.df[self.split_col] ==
                                 SPLIT_TO_NUM[split]].tolist()
         return self[indices]

-    def split(self) -> tuple[Dataset, Dataset, Dataset]:
-        r"""Splits the dataset into training, validation and test splits."""
-        return (
-            self.get_split("train"),
-            self.get_split("val"),
-            self.get_split("test"),
-        )
+    def split(self) -> tuple[Dataset, Dataset, Dataset | None]:
+        r"""Splits the dataset into training, validation and optionally
+        test splits.
+        """
+        train_set = self.get_split(TrainingStage.TRAIN)
+        val_set = self.get_split(TrainingStage.VAL)
+        test_set = self.get_split(TrainingStage.TEST)
+        if test_set.num_rows == 0:
+            test_set = None
+        return train_set, val_set, test_set
+
+    def random_split(self, ratios: list[float] | None = None):
+        split = generate_random_split(self.num_rows, ratios)
+        self.split_col = 'split'
+        self.df[self.split_col] = split

Review comments on lines +726 to +729 (the new random_split method):
"I think it'd be nice to supply an argument to control the random seed of the random split."
"Also, add a doc-string so that people know how to use it."

     @property
     @requires_post_materialization
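One possible shape for `random_split` that addresses both comments above (a seed argument plus a doc-string) is sketched below; the exact signature is not settled in this PR, and the only assumption is that `generate_random_split` accepts a seed, as in the version modified in this diff.

```python
    def random_split(self, ratios: list[float] | None = None,
                     seed: int = 0):
        r"""Randomly assigns each row to a train/val/test split.

        Args:
            ratios (list[float], optional): Split ratios, e.g. [0.8, 0.1]
                for an 80/10/10 train/val/test split.
            seed (int, optional): Seed forwarded to the split generator so
                that the assignment is reproducible.
        """
        split = generate_random_split(self.num_rows, ratios, seed=seed)
        self.split_col = 'split'
        self.df[self.split_col] = split
```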
@@ -28,6 +28,12 @@ def supports_task_type(self, task_type: 'TaskType') -> bool:
         return self in task_type.supported_metrics


+class TrainingStage(Enum):
+    TRAIN = 'train'
+    VAL = 'val'
+    TEST = 'test'
+
+
 class TaskType(Enum):
     r"""The type of the task.

Review comment on the new TrainingStage class: "Brief doc-string."
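The requested doc-string could read something like the following; the wording is only a suggestion and is not part of the PR yet.

```python
from enum import Enum


class TrainingStage(Enum):
    r"""The training stage a data split belongs to: train, val, or test."""
    TRAIN = 'train'
    VAL = 'val'
    TEST = 'test'
```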
@@ -1,30 +1,85 @@
 import math

 import numpy as np

+from typing import List
+from torch_frame.typing import TrainingStage
+
 # Mapping split name to integer.
-SPLIT_TO_NUM = {'train': 0, 'val': 1, 'test': 2}
+SPLIT_TO_NUM = {
+    TrainingStage.TRAIN: 0,
+    TrainingStage.VAL: 1,
+    TrainingStage.TEST: 2
+}


-def generate_random_split(length: int, seed: int, train_ratio: float = 0.8,
-                          val_ratio: float = 0.1) -> np.ndarray:
+def generate_random_split(
+    length: int,
+    ratios: List[float],
+    seed: int = 0,
+) -> np.ndarray:
r"""Generate a list of random split assignments of the specified length. | ||
The elements are either :obj:`0`, :obj:`1`, or :obj:`2`, representing | ||
train, val, test, respectively. Note that this function relies on the fact | ||
that numpy's shuffle is consistent across versions, which has been | ||
historically the case. | ||
|
||
Args: | ||
length (int): The length of the dataset. | ||
ratios (List[float]): Ratios for split assignment. When ratios | ||
contains 2 variables, we will generate train/val/test set | ||
respectively based on the split ratios (the 1st variable in | ||
the list will be the ratio for train set, the 2nd will be | ||
the ratio for val set and the remaining data will be used | ||
for test set). When ratios contains 1 variable, we will only | ||
generate train/val set. (the variable in) | ||
seed (int, optional): The seed for the randomness generator. | ||
|
||
Returns: | ||
A np.ndarra object representing the split. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
""" | ||
-    assert train_ratio + val_ratio < 1
-    assert train_ratio > 0
-    assert val_ratio > 0
-    train_num = int(length * train_ratio)
-    val_num = int(length * val_ratio)
-    test_num = length - train_num - val_num
+    validate_split_ratios(ratios)
+    ratios_length = len(ratios)
+    if length < ratios_length + 1:
+        raise ValueError(
+            f"We want to split the data into {ratios_length + 1} disjoint "
+            f"sets, but the data contains only {length} data points. "
+            f"Consider increasing your data size.")
+
+    train_num = math.floor(length * ratios[0])
+    val_num = math.floor(
+        length * ratios[1]) if ratios_length == 2 else length - train_num
+    test_num = None
+    if ratios_length == 2:
+        test_num = length - train_num - val_num

     arr = np.concatenate([
-        np.full(train_num, SPLIT_TO_NUM['train']),
-        np.full(val_num, SPLIT_TO_NUM['val']),
-        np.full(test_num, SPLIT_TO_NUM['test'])
+        np.full(train_num, SPLIT_TO_NUM.get(TrainingStage.TRAIN)),
+        np.full(val_num, SPLIT_TO_NUM.get(TrainingStage.VAL)),
     ])
+
+    if ratios_length == 2:
+        arr = np.concatenate(
+            [arr, np.full(test_num, SPLIT_TO_NUM.get(TrainingStage.TEST))])

     np.random.seed(seed)
     np.random.shuffle(arr)

     return arr
+
+
+def validate_split_ratios(ratio: List[float]):
+    if len(ratio) > 2:
+        raise ValueError("No more than three splits (two ratios) are "
+                         "supported")
+    if len(ratio) < 1:
+        raise ValueError("At least two splits (one ratio) are required")
+
+    for val in ratio:
+        if val < 0:
+            raise ValueError("'ratio' can not contain negative values")
+
+    if sum(ratio) - 1 > 1e-2:
+        raise ValueError("'ratio' sums to more than 100% of the data")
Review discussion on the ratios argument:

"why do we want this to be a list? I think the previous argument is clearer."

"This format is more aligned with the split method in torch: https://pytorch.org/docs/stable/_modules/torch/utils/data/dataset.html#random_split. Also, to support a split with only train and validation data, with the new format we are able to do split(ratios=[0.8], ...), while with the old format we could only do split(train_ratio=0.8, val_ratio=None, ...), which is a bit weird, because this indicates a train/test split instead of a train/val split, which is more common."

"Yeah, but here we are actually assigning TRAIN to the first ratio, VAL to the second ratio, and so on. So I still think the previous one is better."

"torch's random split makes sense since it just splits the dataset. There is no train/val/test assignment happening under the hood."

"In the previous format we always assume we want to have a test set. So we could also keep the previous format but support the case where the val ratio could be None, which would indicate there is no validation set instead of an empty validation set, if you have a strong opinion here."

"I understand you, but I still want features based on needs. I don't find this logic useful for now, and having a new interface would break other parts. For instance, this PR does not modify the generate_random_split usage in data_frame_benchmark.py and elsewhere."

"This is still a draft for now. I kind of want to sync with you on this first before changing other parts of the code to make them compatible."
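To make the trade-off concrete, here is a rough sketch of the compromise floated above: keep the previous keyword-based signature but let val_ratio=None mean there is no validation set. The function name generate_random_split_compat and the exact behavior are illustrative assumptions only, not part of this PR.

```python
import math
from typing import Optional

import numpy as np

# Previous string-keyed mapping, repeated here so the sketch is self-contained.
SPLIT_TO_NUM = {'train': 0, 'val': 1, 'test': 2}


def generate_random_split_compat(length: int, seed: int = 0,
                                 train_ratio: float = 0.8,
                                 val_ratio: Optional[float] = 0.1) -> np.ndarray:
    train_num = math.floor(length * train_ratio)
    if val_ratio is None:
        # No validation set: all remaining rows go to the test set.
        val_num, test_num = 0, length - train_num
    else:
        val_num = math.floor(length * val_ratio)
        test_num = length - train_num - val_num
    arr = np.concatenate([
        np.full(train_num, SPLIT_TO_NUM['train']),
        np.full(val_num, SPLIT_TO_NUM['val']),
        np.full(test_num, SPLIT_TO_NUM['test']),
    ])
    np.random.seed(seed)
    np.random.shuffle(arr)
    return arr
```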