
[Split] random split method improvement #353

Closed

wants to merge 2 commits

Conversation

XinweiHe
Contributor

For #352

seed (int, optional): The seed for the randomness generator.

Returns:
A np.ndarra object representing the split.
Contributor

Typo: should be np.ndarray.

Comment on lines +14 to +15
split = generate_random_split(num_data, seed=42,
ratios=[train_ratio, val_ratio])
Contributor

why do we want this to be a list? I think the previous argument is clearer.

Contributor Author

This format is more aligned with the split method in torch: https://pytorch.org/docs/stable/_modules/torch/utils/data/dataset.html#random_split. Also, to support a split with only train and validation data, the new format lets us write split(ratios=[0.8], ...), whereas with the old format we could only do split(train_ratio=0.8, val_ratio=None, ...), which is a bit weird because it indicates a train/test split instead of the more common train/val split.
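To make the two call styles under discussion concrete, here is a minimal, hypothetical sketch. The names (generate_random_split, the 0/1/2 split codes) mirror the PR, but this implementation is an illustration only, not the project's actual code:

```python
import math
import numpy as np

# Assumed mapping from split name to integer code, for illustration.
SPLIT_TO_NUM = {'train': 0, 'val': 1, 'test': 2}

def generate_random_split(length, seed=42, ratios=None):
    # One ratio -> train/val split; two ratios -> train/val/test split.
    train_num = math.floor(length * ratios[0])
    if len(ratios) == 2:
        val_num = math.floor(length * ratios[1])
        test_num = length - train_num - val_num
    else:
        val_num = length - train_num
        test_num = 0
    arr = np.concatenate([
        np.full(train_num, SPLIT_TO_NUM['train']),
        np.full(val_num, SPLIT_TO_NUM['val']),
        np.full(test_num, SPLIT_TO_NUM['test']),
    ])
    np.random.default_rng(seed).shuffle(arr)
    return arr

# New format: a train/val split needs only one ratio...
split_tv = generate_random_split(10, ratios=[0.8])
# ...and a train/val/test split passes two.
split_tvt = generate_random_split(10, ratios=[0.7, 0.2])
```

With the old (train_ratio, val_ratio) signature, the first call would require an explicit val_ratio=None, which is the awkward case the author describes.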

Contributor

Yeah but here we are actually assigning TRAIN to the first ratio, VAL to the second ratio, and so on. So I still think the previous one is better.

Contributor

torch's random split makes sense since it just splits the dataset. There is no train/val/test assignment happening under the hood.

Contributor Author

In the previous format we always assume we want a test set. So

  1. it does not have the flexibility to do something like (train: 0.3, val: 0.1, test: 0.2)
  2. when we do train: 0.8, val: 0.0, the inferred test ratio is 0.2, which is a bit weird, because in this case only a train and val (test) split is well-defined, i.e. we only intend to split the data into 2 parts instead of 3, yet we end up with an additional validation set that is completely empty

We could also keep the previous format but let the val ratio be None to indicate there is no validation set (rather than an empty one), if you have a strong opinion here.

Contributor

I understand you, but I still want features to be driven by actual needs. I don't find this logic useful for now, and the new interface would break other parts. For instance, this PR does not modify generate_random_split as used in data_frame_benchmark.py and elsewhere.

Contributor Author (XinweiHe, Feb 27, 2024)

For instance, this PR does not modify generate_random_split used in data_frame_benchmark.py and elsewhere

This is still a draft. I wanted to sync with you on this first before changing other parts of the code to make them compatible.

Comment on lines +726 to +729
def random_split(self, ratios: list[float] | None = None):
split = generate_random_split(self.num_rows, ratios)
self.split_col = 'split'
self.df[self.split_col] = split
Contributor

I think it'd be nice to supply an argument to control the random seed of the split.
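The reviewer's suggestion could look like the following sketch: thread a seed argument through random_split so the same seed always reproduces the same assignment. Names mirror the PR; the body is a stdlib-only stand-in, not the project's implementation:

```python
import random

def random_split(num_rows, ratios, seed=None):
    """Assign each row a split label; `seed` makes the split reproducible."""
    train_num = int(num_rows * ratios[0])
    labels = [0] * train_num + [1] * (num_rows - train_num)
    # A local Random instance avoids mutating global RNG state, unlike
    # np.random.seed() in the current diff, which affects other callers.
    random.Random(seed).shuffle(labels)
    return labels

a = random_split(10, [0.8], seed=7)
b = random_split(10, [0.8], seed=7)
# a == b: identical seeds yield identical splits.
```

Using a locally constructed generator (rather than seeding the global RNG) is generally the safer design when the split runs inside a larger pipeline.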

Contributor

Also, add doc-string so that people know how to use it.

@@ -28,6 +28,12 @@ def supports_task_type(self, task_type: 'TaskType') -> bool:
return self in task_type.supported_metrics


class TrainingStage(Enum):
Contributor

Brief doc-string.
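A doc-string for the new enum could look like this sketch. The member names follow the TRAIN/VAL/TEST usage elsewhere in the diff; the string values are assumptions for illustration:

```python
from enum import Enum

class TrainingStage(Enum):
    """The stage of the training life cycle a data split is used for.

    Members (values assumed here for illustration):
        TRAIN: rows used to fit the model.
        VAL: rows used for validation / model selection.
        TEST: rows held out for final evaluation.
    """
    TRAIN = 'train'
    VAL = 'val'
    TEST = 'test'
```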

Comment on lines 40 to 84
validate_split_ratios(ratios)
ratios_length = len(ratios)
if length < ratios_length + 1:
raise ValueError(
    f"We want to split the data into {ratios_length + 1} disjoint "
    f"sets, but the data contains only {length} data points. "
    f"Consider increasing your data size.")

# train_num = int(length * train_ratio)
# val_num = int(length * val_ratio)
# test_num = length - train_num - val_num
train_num = math.floor(length * ratios[0])
val_num = math.floor(
length * ratios[1]) if ratios_length == 2 else length - train_num
test_num = None
if ratios_length == 2:
test_num = length - train_num - val_num

arr = np.concatenate([
    np.full(train_num, SPLIT_TO_NUM.get(TrainingStage.TRAIN)),
    np.full(val_num, SPLIT_TO_NUM.get(TrainingStage.VAL)),
])

if ratios_length == 2:
arr = np.concatenate(
[arr, np.full(test_num, SPLIT_TO_NUM.get(TrainingStage.TEST))])

np.random.seed(seed)
np.random.shuffle(arr)

return arr


def validate_split_ratios(ratio: list[float]):
    if len(ratio) > 2:
        raise ValueError("No more than three splits are supported")
    if len(ratio) < 1:
        raise ValueError("At least two splits are required")

    for val in ratio:
        if val < 0:
            raise ValueError("'ratio' cannot contain negative values")

    if sum(ratio) - 1 > 1e-2:
        raise ValueError("sum of 'ratio' exceeds 100% of the data")
Contributor

Could you clarify what's wrong with the old implementation? I think the old one is simpler.

@@ -695,7 +696,7 @@ def col_select(self, cols: ColumnSelectType) -> Dataset:

return dataset

def get_split(self, split: str) -> Dataset:
def get_split(self, split: TrainingStage) -> Dataset:
Contributor

I think it's nicer to accept str as input here. It's sometimes troublesome to import TrainingStage.

Contributor Author

Maybe we can support both? i.e. class TrainingStage(str, Enum). I feel TrainingStage has the advantage that no mistypes can cause unexpected behavior, e.g. 'Train' vs 'train'. I think this is also why we support lots of Enum classes in the codebase.
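The str-mixin idea can be sketched as follows: the mixin lets call sites pass either the enum member or a plain string, while the enum constructor still rejects mistypes like 'Train'. Member values are assumptions for illustration:

```python
from enum import Enum

class TrainingStage(str, Enum):
    TRAIN = 'train'
    VAL = 'val'
    TEST = 'test'

def get_split_name(split):
    # Accept both str and TrainingStage via one normalization step;
    # an unknown string raises ValueError instead of silently failing.
    return TrainingStage(split).value

assert TrainingStage.TRAIN == 'train'  # str comparison works
assert get_split_name('val') == get_split_name(TrainingStage.VAL)
```

So get_split could take Union[str, TrainingStage] and normalize at the top, keeping a single code path.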

Contributor

Yes, we can support both.

Contributor

Union[str, TrainingStage]

weihua916 (Contributor)

Thanks for the effort. I am not sure the additional complexity is worthwhile, though. I think the generalization to ratios: list[float] is not necessary here, since each element already has a fixed role as the train/val/test ratio.

@XinweiHe XinweiHe closed this May 14, 2024
@akihironitta akihironitta deleted the xinwei_split_v1 branch May 16, 2024 13:43