[Split] random split method improvement #353
Conversation
```
seed (int, optional): The seed for the randomness generator.

Returns:
    A np.ndarra object representing the split.
```
Typo: np.ndarra should be np.ndarray.
```python
split = generate_random_split(num_data, seed=42,
                              ratios=[train_ratio, val_ratio])
```
Why do we want this to be a list? I think the previous argument is clearer.
This format is more aligned with the split method in torch: https://pytorch.org/docs/stable/_modules/torch/utils/data/dataset.html#random_split. And to support a split with only train and validation data, the new format lets us do split(ratios=[0.8], ...), while with the old format we could only do split(train_ratio=0.8, val_ratio=None, ...), which is a bit weird because it indicates a train/test split instead of the more common train/val split.
Yeah, but here we are actually assigning TRAIN to the first ratio, VAL to the second ratio, and so on. So I still think the previous one is better.
torch's random_split makes sense since it just splits the dataset; there is no train/val/test assignment happening under the hood.
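For reference, a minimal sketch of that behavior (fractional lengths need torch >= 1.13; the toy data is just for illustration):

```python
import torch
from torch.utils.data import random_split

data = list(range(10))
# torch only partitions the data into unlabeled subsets; nothing marks
# one subset as "train" and another as "val".
train_subset, val_subset = random_split(
    data, [0.8, 0.2], generator=torch.Generator().manual_seed(42))
```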
In the previous format we always assume we want a test set. So:
- it does not have the flexibility to do something like (train: 0.3, val: 0.1, test: 0.2);
- when we do train: 0.8, val: 0.0, the inferred test ratio is 0.2, which is a bit weird: only a two-part split is well-defined here, i.e. we intend to split the data into two parts instead of three, yet we end up with an additional validation set that is completely empty (see the sketch below).

We could also keep the previous format but support the case where the val ratio can be None, which indicates there is no validation set rather than an empty one, if you have a strong opinion here.
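A sketch of how the list format covers both cases (num_data is a placeholder; argument names follow this PR):

```python
# Two parts only: train/val, no phantom empty set; val gets the
# remaining 20% of the rows.
split = generate_random_split(num_data, seed=42, ratios=[0.8])

# Three parts: train/val/test, with the test ratio inferred as the
# remainder 1 - 0.8 - 0.1 = 0.1.
split = generate_random_split(num_data, seed=42, ratios=[0.8, 0.1])
```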
I understand, but I still want to add features based on actual needs. I don't find this logic useful for now, and the new interface would break other parts. For instance, this PR does not modify the generate_random_split call used in data_frame_benchmark.py and elsewhere.
> For instance, this PR does not modify generate_random_split used in data_frame_benchmark.py and elsewhere

This is still a draft for now. I want to sync with you on this first before changing other parts of the code to make them compatible.
```python
def random_split(self, ratios: list[float] | None = None):
    split = generate_random_split(self.num_rows, ratios)
    self.split_col = 'split'
    self.df[self.split_col] = split
```
I think it'd be nice to supply an argument to control the random seed of the split.
Also, add a doc-string so that people know how to use it.
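A sketch of what the two suggestions could look like together (hypothetical: it assumes generate_random_split accepts seed as a keyword):

```python
def random_split(self, ratios: list[float] | None = None,
                 seed: int | None = None) -> None:
    r"""Randomly splits the rows into train/val(/test) sets.

    Args:
        ratios (list[float], optional): One ratio yields a train/val
            split; two ratios yield a train/val/test split. The last
            set always receives the remaining rows.
        seed (int, optional): The seed for the randomness generator,
            for reproducible splits.
    """
    split = generate_random_split(self.num_rows, ratios, seed=seed)
    self.split_col = 'split'
    self.df[self.split_col] = split
```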
```diff
@@ -28,6 +28,12 @@ def supports_task_type(self, task_type: 'TaskType') -> bool:
         return self in task_type.supported_metrics


+class TrainingStage(Enum):
```
Brief doc-string.
torch_frame/utils/split.py (outdated diff)
```python
validate_split_ratios(ratios)
ratios_length = len(ratios)
if length < ratios_length + 1:
    raise ValueError(
        f"We want to split the data into {ratios_length + 1} disjoint "
        f"sets. However, the data contains only {length} data points. "
        f"Consider increasing your data size.")

# Old implementation, kept for reference:
# train_num = int(length * train_ratio)
# val_num = int(length * val_ratio)
# test_num = length - train_num - val_num
train_num = math.floor(length * ratios[0])
val_num = math.floor(
    length * ratios[1]) if ratios_length == 2 else length - train_num
test_num = None
if ratios_length == 2:
    test_num = length - train_num - val_num

arr = np.concatenate([
    np.full(train_num, SPLIT_TO_NUM.get(TrainingStage.TRAIN)),
    np.full(val_num, SPLIT_TO_NUM.get(TrainingStage.VAL)),
])

if ratios_length == 2:
    arr = np.concatenate(
        [arr, np.full(test_num, SPLIT_TO_NUM.get(TrainingStage.TEST))])

np.random.seed(seed)
np.random.shuffle(arr)

return arr


def validate_split_ratios(ratio: list[float]):
    if len(ratio) > 2:
        raise ValueError("No more than three training splits are supported")
    if len(ratio) < 1:
        raise ValueError("At least two training splits are required")

    for val in ratio:
        if val < 0:
            raise ValueError("'ratio' cannot contain negative values")

    if sum(ratio) - 1 > 1e-2:
        raise ValueError("'ratio' sums to more than 100% of the data")
```
Could you clarify what's wrong with the old implementation? I think the old one is simpler.
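For comparison, the old code path (reconstructed from the commented-out lines kept in the diff above) was roughly:

```python
train_num = int(length * train_ratio)
val_num = int(length * val_ratio)
test_num = length - train_num - val_num

# A test set always exists, even when val_ratio is 0.
arr = np.concatenate([
    np.full(train_num, SPLIT_TO_NUM['train']),
    np.full(val_num, SPLIT_TO_NUM['val']),
    np.full(test_num, SPLIT_TO_NUM['test']),
])
np.random.seed(seed)
np.random.shuffle(arr)
```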
```diff
@@ -695,7 +696,7 @@ def col_select(self, cols: ColumnSelectType) -> Dataset:

         return dataset

-    def get_split(self, split: str) -> Dataset:
+    def get_split(self, split: TrainingStage) -> Dataset:
```
I think it's nicer to accept str as input here. It's sometimes troublesome to import TrainingStage.
Maybe we can support both? I.e., class TrainingStage(str, Enum). I feel like TrainingStage has its advantage in that no mistypes will occur that cause unexpected behavior, e.g. 'Train' vs 'train'. I feel like this is also the reason we support lots of Enum classes in the codebase.
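A minimal sketch of the proposed str-mixin pattern:

```python
from enum import Enum

class TrainingStage(str, Enum):
    TRAIN = 'train'
    VAL = 'val'
    TEST = 'test'

# The str mixin lets plain strings and enum members interoperate:
assert TrainingStage.TRAIN == 'train'
assert TrainingStage('val') is TrainingStage.VAL
```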
Yes, we can support both.
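Assuming the str-mixin TrainingStage from the sketch above, get_split could normalize at the boundary (a fragment; the body is elided):

```python
from typing import Union

def get_split(self, split: Union[str, TrainingStage]) -> 'Dataset':
    # Enum lookup returns the member unchanged when given a member,
    # so both 'train' and TrainingStage.TRAIN are accepted.
    stage = TrainingStage(split)
    ...
```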
> Union[str, TrainingStage]

Thanks for the effort. I am not sure the additional complexity is worthwhile, though. I think the generalization to ...
For #352