diff --git a/train.py b/train.py index 4299aaa..65d57e0 100644 --- a/train.py +++ b/train.py @@ -66,7 +66,7 @@ def main() -> None: print("Getting data...") batch_size = data_config.batch_size dataset, x_shape, y_shape = prepare_dataset(task, data_config, training=True) - dataset = dataset.shuffle(5000, reshuffle_each_iteration=False) + dataset = dataset.shuffle(1000000, reshuffle_each_iteration=False) val_len = eval_config.val_len dataset = dataset.skip(val_len)