Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Enable streaming in data preprocessor #437

Open
wants to merge 21 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 19 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions tests/artifacts/predefined_data_configs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,9 @@
# Config with several datasets combined via sampling ratios.
DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_YAML = os.path.join(
    PREDEFINED_DATA_CONFIGS, "multiple_datasets_with_sampling.yaml"
)
# Streaming (dataprocessor.streaming: true) variant of the
# tokenize-and-apply-input-masking data config.
DATA_CONFIG_YAML_STREAMING_INPUT_OUTPUT = os.path.join(
    PREDEFINED_DATA_CONFIGS, "tokenize_and_apply_input_masking_streaming.yaml"
)
# Streaming variant of the pretokenized JSON data config.
DATA_CONFIG_YAML_STREAMING_PRETOKENIZED = os.path.join(
    PREDEFINED_DATA_CONFIGS, "pretokenized_json_data_streaming.yaml"
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Streaming data config: applies a custom formatting template to a raw dataset.
dataprocessor:
  type: default
  streaming: true  # load lazily as an iterable dataset instead of materializing it
datasets:
  - name: apply_custom_data_template
    data_paths:
      - "FILE_PATH"  # placeholder — presumably substituted with a real data file by tests; verify against callers
    data_handlers:
      - name: apply_custom_data_formatting_template
        arguments:
          remove_columns: all   # drop the original columns after formatting
          batched: false        # handler is applied per-record
          fn_kwargs:
            # placeholders — presumably substituted with the real field name and template by tests
            dataset_text_field: "dataset_text_field"
            template: "dataset_template"
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Streaming data config for data that is already tokenized: no data handlers needed.
dataprocessor:
  type: default
  streaming: true  # load lazily as an iterable dataset instead of materializing it
datasets:
  - name: pretokenized_dataset
    data_paths:
      - "FILE_PATH"  # placeholder — presumably substituted with a real data file by tests; verify against callers
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Streaming data config: tokenizes input/output text pairs and masks the input
# portion of the labels.
dataprocessor:
  type: default
  streaming: true  # load lazily as an iterable dataset instead of materializing it
datasets:
  - name: text_dataset_input_output_masking
    data_paths:
      - "FILE_PATH"  # placeholder — presumably substituted with a real data file by tests; verify against callers
    data_handlers:
      - name: tokenize_and_apply_input_masking
        arguments:
          remove_columns: all   # drop the original columns after tokenization
          batched: false        # handler is applied per-record
          fn_kwargs:
            # column names in the source data holding the prompt and the completion
            input_field_name: input
            output_field_name: output
29 changes: 29 additions & 0 deletions tests/data/test_data_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,35 @@ def test_apply_custom_formatting_template():
assert formatted_dataset["train"][0][formatted_dataset_field] == expected_response


def test_apply_custom_formatting_template_iterable():
    """Tests that the custom template handler also works on a streamed
    (iterable) dataset, creating the new text field on each yielded sample."""
    output_field = "formatted_data_field"
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

    # streaming=True makes load_dataset return an IterableDataset, so map()
    # is applied lazily as records are consumed.
    streamed = datasets.load_dataset(
        "json", data_files=TWITTER_COMPLAINTS_DATA_JSONL, streaming=True
    )
    formatted = streamed.map(
        apply_custom_data_formatting_template,
        fn_kwargs={
            "tokenizer": tokenizer,
            "dataset_text_field": output_field,
            "template": "### Input: {{Tweet text}} \n\n ### Response: {{text_label}}",
        },
    )

    # First record of the data file, rendered through the template, with the
    # tokenizer's EOS token appended by the handler.
    expected = (
        "### Input: @HMRCcustomers No this is my first job"
        + " \n\n ### Response: no complaint"
        + tokenizer.eos_token
    )

    sample = next(iter(formatted["train"]))

    # The handler must add the new text field to each streamed sample.
    assert output_field in sample
    assert sample[output_field] == expected


def test_apply_custom_formatting_template_gives_error_with_wrong_keys():
"""Tests that the formatting function will throw error if wrong keys are passed to template"""
json_dataset = datasets.load_dataset(
Expand Down
Loading