feat: Enable streaming in data preprocessor #437
base: main
Conversation
…r future tests, add streaming to config Signed-off-by: Will Johnson <[email protected]>
Signed-off-by: Will Johnson <[email protected]>
Signed-off-by: Will Johnson <[email protected]>
Thanks for making a pull request! 😃
Signed-off-by: Will Johnson <[email protected]>
tuning/sft_trainer.py
Outdated
# Which one? Should it be for user to decide or set?
# training_args.max_steps = training_args.num_train_epochs
training_args.max_steps = 1
Could do:

    if training_args.num_train_epochs:
        training_args.max_steps = training_args.num_train_epochs
    else:
        training_args.max_steps = 1
Or, if max_steps = training_size / batch_size, we could calculate it in the else branch like that.
Guess for an IterableDataset we don't have a fixed length for the complete dataset, so the number of epochs doesn't work with it and it needs max_steps instead.
So yeah, we could probably calculate max_steps using training_size, batch_size_per_device, GA, and num_train_epochs. Though I guess finding training_size might be difficult when data is just coming in as a stream.
Would like to see what @dushyantbehl thinks...
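For illustration, a rough sketch of that calculation (not the PR's code; it assumes training_size is known up front, which, as noted, may not hold for a true stream):

```python
import math

def estimate_max_steps(
    training_size: int,
    per_device_train_batch_size: int,
    gradient_accumulation_steps: int,
    num_train_epochs: float,
    world_size: int = 1,
) -> int:
    # Samples consumed per optimizer step across all devices.
    samples_per_step = (
        per_device_train_batch_size * gradient_accumulation_steps * world_size
    )
    # Steps needed to cover the dataset num_train_epochs times.
    return math.ceil(training_size * num_train_epochs / samples_per_step)

# e.g. 10_000 samples, batch 8 per device, GA 4, 2 GPUs, 3 epochs -> 469 steps
print(estimate_max_steps(10_000, 8, 4, 3, world_size=2))
```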
Thanks @willmj for integrating usage of Iterable datasets. Just some initial thoughts.
tuning/data/data_processors.py
Outdated
if isinstance(raw_dataset, Dataset):
    raw_datasets[splitName] = raw_dataset
elif isinstance(raw_dataset, IterableDataset):
    raw_datasets[splitName] = raw_dataset
Probably we can merge the if/elif branches into:

    if isinstance(raw_dataset, (Dataset, IterableDataset)):
        raw_datasets[splitName] = raw_dataset
tuning/data/data_processors.py
Outdated
if not d.streaming:
    if "num_proc" not in kwargs:
        kwargs["num_proc"] = os.cpu_count()
I assume IterableDatasets are processed in a single process and there is no multi-process processing support for IterableDatasets?
The map function for IterableDatasetDict does not take num_proc as an argument, so I'd assume yes.
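That matches the datasets API: Dataset.map accepts num_proc, while a streaming IterableDataset maps lazily in a single process and does not take it. A minimal sketch of the guard being reviewed (build_map_kwargs and the streaming flag are illustrative names, not the repo's):

```python
import os

from datasets import Dataset


def build_map_kwargs(dataset, streaming: bool, **kwargs):
    # num_proc is only valid for map() on a regular (non-streaming) Dataset;
    # a streaming IterableDataset applies map lazily in a single process
    # and rejects the argument.
    if not streaming and isinstance(dataset, Dataset) and "num_proc" not in kwargs:
        kwargs["num_proc"] = os.cpu_count()
    return kwargs
```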
Shouldn't streaming be a top-level option instead of a per-dataset one? Is it possible to mix streaming and non-streaming datasets using concat?
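For what it's worth, a hedged sketch of one way such mixing could work, assuming a recent datasets release (concatenate_datasets expects all inputs to be of the same type, and Dataset.to_iterable_dataset converts the in-memory one; data.jsonl is a placeholder path):

```python
from datasets import Dataset, concatenate_datasets, load_dataset

regular = Dataset.from_dict({"text": ["a", "b"]})
streamed = load_dataset(
    "json", data_files="data.jsonl", split="train", streaming=True
)

# Convert the map-style dataset to a stream so both inputs are
# IterableDatasets before concatenating.
mixed = concatenate_datasets([regular.to_iterable_dataset(), streamed])
```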
…nstead of dataset config Signed-off-by: Will Johnson <[email protected]>
Signed-off-by: Will Johnson <[email protected]>
@willmj Is this PR in a usable state? We need to run an EPT with large datasets. Without streaming, the data processing is failing. We want the streaming feature to address this issue.
@willmj I would request your attention to this: fms-hf-tuning/tuning/utils/preprocessing_utils.py, lines 49 to 60 in 224f35b.
… map operations applied), so it's good to be defensive about retrieving columns wherever necessary.
Signed-off-by: Will Johnson <[email protected]>
tests/test_sft_trainer.py
Outdated
[
    (
        [TWITTER_COMPLAINTS_DATA_DIR_JSON],
        DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
I assume this yaml file is to be used for this test case: DATA_CONFIG_YAML_STREAMING
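If so, a hypothetical parametrized case might pair the streaming config with the same data (DATA_CONFIG_YAML_STREAMING, the other constants, and the test name below are assumptions about the test module, not confirmed names):

```python
import pytest

# Hypothetical parametrization: reuse the Twitter complaints dataset with
# both the existing non-streaming config and the new streaming config.
@pytest.mark.parametrize(
    "data_paths, data_config_path",
    [
        ([TWITTER_COMPLAINTS_DATA_DIR_JSON], DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML),
        ([TWITTER_COMPLAINTS_DATA_DIR_JSON], DATA_CONFIG_YAML_STREAMING),
    ],
)
def test_process_dataconfig_streaming(data_paths, data_config_path):
    ...  # run the data preprocessor and assert on the resulting dataset type
```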
Signed-off-by: Will Johnson <[email protected]>
Signed-off-by: Will Johnson <[email protected]>
Signed-off-by: Will Johnson <[email protected]>
@seshapad I have now had a successful tuning job with streaming on multi-GPU. You should be able to try it out; let me know if you run into any errors.
tuning/data/setup_dataprocessor.py
Outdated
if isinstance(data, IterableDataset):
    # For an IterableDataset, inspect the first element if possible
    try:
        first_example = next(iter(data))
        return (
            isinstance(first_example, dict)
            and "input_ids" in first_example
            and "labels" in first_example
        )
@kmehant is this check sufficient for what you posted?
Looks good to me.
@willmj Can we have a unit test case with pretokenized + packing + streaming (a common EPT case to cover)? Thank you.
Signed-off-by: Will Johnson <[email protected]>
tuning/sft_trainer.py
Outdated
@@ -332,6 +333,9 @@ def train(
        time.time() - data_preprocessing_time
    )

    if isinstance(formatted_train_dataset, IterableDataset):
        train_args.split_batches = True
Suggested change:

-    train_args.split_batches = True
+    train_args.accelerator_config = {"split_batches": True}

Let's log this for the user's awareness. Thanks
tuning/data/setup_dataprocessor.py
Outdated
if train_args.num_train_epochs:
    logging.warning("`--num_train_epochs` will be overwritten by `--max_steps`")
Since this is the default behaviour of HF transformers - https://github.com/huggingface/transformers/blob/ec7afad60909dd97d998c1f14681812d69a15728/src/transformers/trainer.py#L695 - we can avoid this warning altogether.
Signed-off-by: Will Johnson <[email protected]>
Signed-off-by: Will Johnson <[email protected]>
Tuning + inference works! Only 200 steps, so the equivalent of less than an epoch, which is why the result is wrong - but the format is right.
Inference result on
Signed-off-by: Will Johnson <[email protected]>
Signed-off-by: Will Johnson <[email protected]>
@willmj The
I can share the dataset with you if you wish to attempt reproducing this bug.
cc: @ashokponkumar
… pretokenized case in data collator Signed-off-by: Will Johnson <[email protected]>
@@ -51,6 +52,22 @@ def load_yaml_or_json(file_path: str) -> dict:
    return None


def resolve_iterable_dataset_features(data: IterableDataset):
Would it be better to make this function part of the DataPreProcessor class?
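For context, a minimal sketch of what such a helper might do (an assumption, not necessarily the PR's exact implementation): a streaming dataset may report column_names as None until its features are resolved from the first examples.

```python
from datasets import IterableDataset


def resolve_iterable_dataset_features(data: IterableDataset) -> IterableDataset:
    # Streaming datasets may not know their columns until data is read;
    # _resolve_features() (a private helper in recent datasets releases)
    # peeks at the stream and fills in the features/column_names metadata.
    if data.column_names is None and hasattr(data, "_resolve_features"):
        data = data._resolve_features()
    return data
```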
Signed-off-by: Will Johnson <[email protected]>
# Note that this automatically pads labels with -100
# TODO check if this is sufficient for preprocessed
# For EPT case where they need packing
if is_traindata_tokenized:
For EPT case where we want tokenized with packing
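For reference, a hedged sketch of the kind of collator being discussed for the pretokenized case (the model name is a placeholder): DataCollatorForSeq2Seq pads the inputs and fills padded label positions with -100 so the loss ignores them.

```python
from transformers import AutoTokenizer, DataCollatorForSeq2Seq

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder model
tokenizer.pad_token = tokenizer.eos_token          # gpt2 has no pad token by default

collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    padding=True,             # pad input_ids/attention_mask to the longest in the batch
    label_pad_token_id=-100,  # padded label positions are ignored by the loss
)
```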
Signed-off-by: Will Johnson <[email protected]>
Signed-off-by: Will Johnson <[email protected]>
Model trained on single + multi-GPU, getting decent results:
Model location:
Micro F1 score:
Closes #233
@@ -332,6 +333,10 @@ def train(
        time.time() - data_preprocessing_time
    )

    if isinstance(formatted_train_dataset, IterableDataset):
        train_args.accelerator_config = {"split_batches": True}
        logger.info("Setting split_batches to true - splitting batches among devices")
@willmj an important point to be captured in the log is that, since split_batches is being set to True, the message should also note that per_device_train_batch_size now acts as the global batch size and no longer holds its per-device meaning.
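A sketch of an expanded log message along those lines (train_args, formatted_train_dataset, and logger are the names from the diff above):

```python
from datasets import IterableDataset

if isinstance(formatted_train_dataset, IterableDataset):
    train_args.accelerator_config = {"split_batches": True}
    logger.info(
        "Setting split_batches=True for the streaming dataset; "
        "per_device_train_batch_size is now treated as the global batch size "
        "split across devices, not a per-device batch size."
    )
```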
@@ -233,20 +255,29 @@ def _process_dataset_configs(
        )
        sample_datasets = False

        streaming = False
Trying to understand: any reason for introducing this new variable instead of reusing self.processor_config.streaming?
Description of the change
These changes enable streaming and test streaming datasets.
Added:
- streaming as an arg in DataSetConfig, similarly to sampling
- IterableDatasets can't be indexed, so use the first example where column names are needed
- max_steps instead of num_train_epochs if using streaming

Related issue number

How to verify the PR
- streaming in the dataconfig returns an IterableDataset

Was the PR tested