
feat: Enable streaming in data preprocessor #437

Open · willmj wants to merge 19 commits into main
Conversation

@willmj (Collaborator) commented Jan 14, 2025

Description of the change

These changes enable streaming and test streaming datasets.
Added:

  • Add streaming as an arg in DataSetConfig, similar to sampling
  • Add examples of DataSetConfig in tests/artifacts/predefined_data_configs/ for streaming
  • Add unit tests
  • Since IterableDatasets can't be indexed, use the first example where column names are needed
  • User must set max_steps instead of num_train_epochs when using streaming (see the sketch below)
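For context, a minimal sketch (not from this PR) of why streaming forces these choices: datasets.load_dataset(..., streaming=True) returns an IterableDataset, which has no length and cannot be indexed, so column names have to come from the first example and the Trainer needs an explicit max_steps. The file name below is illustrative.

    # Illustrative only; not code from this PR.
    from datasets import load_dataset

    ds = load_dataset("json", data_files="train.jsonl", streaming=True)["train"]

    # An IterableDataset cannot be indexed or len()-ed, so peek at the first
    # example to discover column names (the approach this PR takes).
    first_example = next(iter(ds))
    column_names = list(first_example.keys())

    # With no known dataset length, epochs are undefined; the Trainer must be
    # given an explicit max_steps instead of num_train_epochs.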

Related issue number

How to verify the PR

  • Run new unit tests, which verify that HF inference works and that passing streaming in the data config returns an IterableDataset
  • Run on single GPU error
  • Run on multi GPU without error

Was the PR tested

  • I have added >=1 unit test(s) for every new method I have added.
  • I have ensured all unit tests pass

willmj added 3 commits January 8, 2025 15:31
…r future tests, add streaming to config

Signed-off-by: Will Johnson <[email protected]>
Signed-off-by: Will Johnson <[email protected]>

Thanks for making a pull request! 😃
One of the maintainers will review and advise on the next steps.

github-actions bot added the feat label Jan 14, 2025
Comment on lines 352 to 354
    # Which one? Should it be for user to decide or set?
    # training_args.max_steps = training_args.num_train_epochs
    training_args.max_steps = 1
@willmj (Collaborator, Author):

Could do:

        if training_args.num_train_epochs:
            training_args.max_steps = training_args.num_train_epochs
        else:
            training_args.max_steps = 1

@willmj (Collaborator, Author):

Or, if max_steps = training_size / batch_size, it could be calculated in the else branch that way.

@Abhishek-TAMU (Collaborator) Jan 14, 2025:

I guess for an Iterable dataset we don't have a fixed length for the complete dataset, so the number of epochs doesn't work with it and hence it needs max_steps.
So yes, we could probably calculate max_steps using training_size, batch_size_per_device, GA, and num_train_epochs, though I guess finding training_size might be difficult when data is just coming in as a stream.

Would like to see what @dushyantbehl thinks...
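For reference, a minimal sketch of the calculation being discussed, assuming the total number of training examples is actually known up front (which, as noted above, it often is not for a pure stream). All names are illustrative, not part of this PR:

    import math

    # Illustrative only: derive max_steps when the dataset size is known.
    def estimate_max_steps(
        training_size: int,            # total number of training examples
        per_device_batch_size: int,
        gradient_accumulation: int,    # "GA" above
        num_train_epochs: float,
        world_size: int = 1,
    ) -> int:
        effective_batch = per_device_batch_size * gradient_accumulation * world_size
        steps_per_epoch = math.ceil(training_size / effective_batch)
        return math.ceil(steps_per_epoch * num_train_epochs)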

@Abhishek-TAMU (Collaborator) left a comment:

Thanks @willmj for integrating usage of Iterable datasets. Just some initial thoughts.

tuning/data/data_processors.py (comment resolved)
Comment on lines 274 to 277
    if isinstance(raw_dataset, Dataset):
        raw_datasets[splitName] = raw_dataset
    elif isinstance(raw_dataset, IterableDataset):
        raw_datasets[splitName] = raw_dataset
Collaborator:

We can probably merge the if/elif into:

if isinstance(raw_dataset, (Dataset, IterableDataset)):
    raw_datasets[splitName] = raw_dataset

Comment on lines 301 to 303
    if not d.streaming:
        if "num_proc" not in kwargs:
            kwargs["num_proc"] = os.cpu_count()
Collaborator:

I assume an IterableDataset is processed in a single process and there is no multi-process support for it?

@willmj (Collaborator, Author):
The map function for IterableDatasetDict does not take num_proc as an argument, so I'd assume yes.
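A small sketch of the guard being discussed, assuming a per-dataset streaming flag and a tokenization function (dataset_config, raw_dataset, and tokenize_fn are illustrative names): num_proc is only passed for map-style datasets, since IterableDataset.map() does not accept it.

    import os

    map_kwargs = {"batched": True}
    if not dataset_config.streaming:
        # Map-style Dataset: multiprocessing is available.
        map_kwargs.setdefault("num_proc", os.cpu_count())

    # For a streaming dataset this runs lazily in a single process.
    processed = raw_dataset.map(tokenize_fn, **map_kwargs)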


@ashokponkumar (Collaborator):
Shouldn't streaming be a top level object instead of a per dataset object? Is it possible to mix streaming and non-streaming datasets using concat?

@seshapad (Contributor):

@willmj Is this PR in a usable state? We need to run an EPT with large datasets. Without streaming, the data processing is failing. We want the streaming feature to address this issue.

@kmehant (Collaborator) commented Jan 26, 2025:

@willmj I would request your attention to this:

if "column_names" not in data or data.column_names is None:
if isinstance(data, IterableDataset):
if hasattr(data, "_resolve_features"):
data = data._resolve_features()
else:
raise ValueError(
"_resolve_features API is not available to fetch column names"
)
else:
raise ValueError(
f"not possible to fetch column names for the loaded dataset of type {type(data)}"
)
. iterabledatasets often loose out column information (sometimes on loading, or after map operations applied), so its good to be defensive on retrieving columns wherever necessary.

Signed-off-by: Will Johnson <[email protected]>
    [
        (
            [TWITTER_COMPLAINTS_DATA_DIR_JSON],
            DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
@Abhishek-TAMU (Collaborator) Jan 28, 2025:

I assume this YAML file is to be used for this test case: DATA_CONFIG_YAML_STREAMING

@willmj (Collaborator, Author) commented Jan 28, 2025:
@seshapad I have now had a successful tuning job with streaming on multi GPU. You should be able to try it out, let me know if you run into any errors.

Comment on lines 65 to 73
    if isinstance(data, IterableDataset):
        # For an IterableDataset, inspect the first element if possible
        try:
            first_example = next(iter(data))
            return (
                isinstance(first_example, dict)
                and "input_ids" in first_example
                and "labels" in first_example
            )
@willmj (Collaborator, Author):

@kmehant is this check sufficient for what you posted?

Collaborator:
Looks good to me.

Collaborator:
@willmj
Can we have a unit test case with pretokenized + packing + streaming (common EPT case to cover)? Thank you.
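A rough sketch of the shape such a test could take; the config constant and helper below (DATA_CONFIG_PRETOKENIZED_STREAMING_YAML, load_streaming_train_dataset) are placeholders, not the repo's actual test utilities:

    from datasets import IterableDataset

    def test_streaming_pretokenized_with_packing():
        # Placeholder helper: load a pretokenized data config with streaming
        # enabled and packing turned on.
        train_set = load_streaming_train_dataset(
            DATA_CONFIG_PRETOKENIZED_STREAMING_YAML, packing=True
        )
        assert isinstance(train_set, IterableDataset)
        first_example = next(iter(train_set))
        assert "input_ids" in first_example and "labels" in first_example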

Signed-off-by: Will Johnson <[email protected]>
@@ -332,6 +333,9 @@ def train(
        time.time() - data_preprocessing_time
    )

    if isinstance(formatted_train_dataset, IterableDataset):
        train_args.split_batches = True
Collaborator:

Suggested change:

-        train_args.split_batches = True
+        train_args.accelerator_config = {"split_batches": True}

Let's log this for the user's awareness. Thanks.

Comment on lines 103 to 104
    if train_args.num_train_epochs:
        logging.warning("`--num_train_epochs` will be overwritten by `--max_steps`")
Collaborator:

Since this is the default behaviour of HF transformers (https://github.com/huggingface/transformers/blob/ec7afad60909dd97d998c1f14681812d69a15728/src/transformers/trainer.py#L695), we can avoid this warning altogether.

Signed-off-by: Will Johnson <[email protected]>
Signed-off-by: Will Johnson <[email protected]>
@willmj (Collaborator, Author) commented Jan 29, 2025:

Tuning + inference works! Only 200 steps, so the equivalent of less than an epoch, which is why the result is wrong, but the format is right.
Config:

      {
          "model_name_or_path": "/llama3/hf/8b_pre_trained",
          "data_config_path": "/testing/tuning/input/apply-custom-template-streaming-data-config.yaml",
          "output_dir": "/testing/tuning/output/llama3-8b/ft/tone_20250129_1045-streaming-dataconfig",
          "save_model_dir": "/testing/tuning/output/llama3-8b/ft/tone_20250129_1045-streaming-dataconfig/save_model",
          "max_steps": 200,
          "per_device_train_batch_size": 4,
          "gradient_accumulation_steps": 1,
          "learning_rate": 1e-4,
          "response_template": "\n### Response:",
          "dataset_text_field": "output"
      }

Inference result on "Text: @sho_help @showtime your arrive is terrible streaming is stop and start every couple mins. Get it together it's xmas\n\n### Label:":

{
  "responses": [
    {
      "generatedTokenCount": 2,
      "text": " polite\u003c|end_of_text|\u003e",
      "inputTokenCount": 34,
      "stopReason": "EOS_TOKEN",
      "stopSequence": "\u003c|end_of_text|\u003e"
    }
  ]
}

@seshapad (Contributor) commented Jan 30, 2025:
@willmj The streaming option crashes. I have attached the log for debugging. Here is the data config:

dataprocessor:
    type: default
    sampling_stopping_strategy: all_exhausted
    seed: 66
    streaming: true
datasets:
  - name: pleias
    sampling: 1.0
    data_paths:
      - "/pleias_greek/"
    data_handlers:
      - name: apply_dataset_formatting
        arguments:
          remove_columns: ['source_directory', 'domain', 'document', 'subset', 'split', 'document_id', 'identifier', 'collection', 'license', '_meta_timestamp', '_meta_request_url', '_meta_final_url', '_meta_dataset', '_meta_job_id', '_meta_file_name', '_meta_json']
          fn_kwargs:
            dataset_text_field: "contents"

I can share the dataset with you if you wish to attempt reproducing this bug.
Configuration of the CLI used:

accelerate launch \
  --num_processes=8 \
  --dynamo_backend="no" \
  --fsdp_auto_wrap_policy="TRANSFORMER_BASED_WRAP" \
  --fsdp_cpu_ram_efficient_loading="true" \
  --fsdp_forward_prefetch="false" \
  --fsdp_offload_params="false" \
  --fsdp_sharding_strategy="HYBRID_SHARD" \
  --fsdp_state_dict_type="FULL_STATE_DICT" \
  --fsdp_sync_module_states="true" \
  --machine_rank="${RANK}" \
  --main_process_ip="${MASTER_ADDR}" \
  --main_process_port="${MASTER_PORT}" \
  --mixed_precision="no" \
  --num_machines="${WORLD_SIZE}" \
  --rdzv_backend="static" \
  --same_network \
  --use_fsdp \
  -m tuning.sft_trainer \
  --adam_beta2="0.95" \
  --aim_repo="${AIMSTACK_DB}" \
  --data_config="data_config.yaml" \
  --evaluation_strategy="no" \
  --experiment="train-nb-g8b-r18" \
  --gradient_accumulation_steps="1" \
  --gradient_checkpointing="true" \
  --include_tokens_per_second="true" \
  --learning_rate="0.0003" \
  --logging_steps="1" \
  --logging_strategy="steps" \
  --lr_scheduler_type="cosine" \
  --max_grad_norm="1" \
  --max_steps="100" \
  --model_name_or_path="ibm-granite/granite-3.1-8b-base" \
  --output_dir="/run18" \
  --packing="true" \
  --per_device_train_batch_size="8" \
  --save_steps="50" \
  --save_strategy="steps" \
  --split_batches="true" \
  --torch_dtype="bfloat16" \
  --tracker="aim" \
  --use_flash_attn="true" \
  --warmup_ratio="0.05" \
  --weight_decay="0.1" \
  2>&1 | tee -a "/run18/accelerate_launch_output.log"

cc: @ashokponkumar

… pretokenized case in data collator

Signed-off-by: Will Johnson <[email protected]>
@@ -51,6 +52,22 @@ def load_yaml_or_json(file_path: str) -> dict:
    return None


def resolve_iterable_dataset_features(data: IterableDataset):
@willmj (Collaborator, Author):
Would it be better to make this function part of the DataPreProcessor class?
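For context, a minimal sketch of what a helper like this might look like, pieced together from the snippet quoted earlier in the thread; the PR's actual implementation may differ:

    from datasets import IterableDataset

    def resolve_iterable_dataset_features(data: IterableDataset) -> IterableDataset:
        """Best-effort recovery of column names for a streaming dataset."""
        if data.column_names is None:
            if hasattr(data, "_resolve_features"):
                # Forces feature inference by reading a little data from the stream.
                data = data._resolve_features()
            else:
                raise ValueError(
                    "_resolve_features API is not available to fetch column names"
                )
        return data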

# Note that this automatically pads labels with -100
# TODO check if this is sufficient for preprocessed
# For EPT case where they need packing
if is_traindata_tokenized:
@willmj (Collaborator, Author):
For EPT case where we want tokenized with packing

@willmj (Collaborator, Author) commented Feb 4, 2025:
Model trained on single + multi GPU, getting decent results:

      {
          "model_name_or_path": "/llama3/hf/8b_pre_trained",
          "data_config_path": "/testing/tuning/input/apply-custom-template-streaming-data-config-10-datasets.yaml",
          "output_dir": "/testing/tuning/output/llama3-8b/ft/tone_20250204_1505-streaming-dataconfig-10-datasets",
          "save_model_dir": "/testing/tuning/output/llama3-8b/ft/tone_20250204_1505-streaming-dataconfig-10-datasets/save_model",
          "max_steps": 2200,
          "per_device_train_batch_size": 4,
          "gradient_accumulation_steps": 1,
          "learning_rate": 1e-4,
          "response_template": "\n### Response:",
          "dataset_text_field": "output"
      }

Model location: /testing/tuning/output/llama3-8b/ft/tone_20250204_1505-streaming-dataconfig-10-datasets/save_model

Micro F1 score:

"accuracy": 0.308,
"f1": {
    "macro": 0.29570179708154787,
    "micro": 0.4507888805409467
},

willmj closed this Feb 4, 2025
willmj reopened this Feb 4, 2025
willmj marked this pull request as ready for review February 4, 2025 21:58
@dushyantbehl (Contributor):
Closes #233

@@ -332,6 +333,10 @@ def train(
        time.time() - data_preprocessing_time
    )

    if isinstance(formatted_train_dataset, IterableDataset):
        train_args.accelerator_config = {"split_batches": True}
        logger.info("Setting split_batches to true - splitting batches among devices")
Collaborator:

@willmj an important point to capture in the log: since split_batches is being set to True, the message should also note that per_device_train_batch_size is now effectively the global batch size and no longer holds its per-device meaning.
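A sketch of what the expanded log message could look like; the wording is illustrative, not the PR's final text:

    if isinstance(formatted_train_dataset, IterableDataset):
        train_args.accelerator_config = {"split_batches": True}
        logger.info(
            "Streaming dataset detected: setting accelerator_config.split_batches=True. "
            "per_device_train_batch_size is now treated as the global batch size and "
            "is split across devices."
        )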

@@ -233,20 +255,29 @@ def _process_dataset_configs(
        )
        sample_datasets = False

        streaming = False
Collaborator:
Trying to understand, any reason for introducing this new variable instead of reusing self.processor_config.streaming?
