From c65e20231940417296b41a01647997175277fe9f Mon Sep 17 00:00:00 2001
From: Daniel Fleischer
Date: Mon, 19 Aug 2024 15:26:05 +0300
Subject: [PATCH] Global step for merging datasets (#9)

Step uses the `inputs` key as the datasets to be combined, `output` as the
new dataset name and can optionally shuffle the resulting dataset.
---
 .../processing/global_steps/aggregation.md    |  1 +
 mkdocs.yml                                    |  1 +
 .../processing/global_steps/aggregation.py    | 31 +++++++++++++++++++
 3 files changed, 33 insertions(+)
 create mode 100644 docs/reference/processing/global_steps/aggregation.md
 create mode 100644 ragfoundry/processing/global_steps/aggregation.py

diff --git a/docs/reference/processing/global_steps/aggregation.md b/docs/reference/processing/global_steps/aggregation.md
new file mode 100644
index 0000000..737cc99
--- /dev/null
+++ b/docs/reference/processing/global_steps/aggregation.md
@@ -0,0 +1 @@
+::: ragfoundry.processing.global_steps.aggregation
\ No newline at end of file
diff --git a/mkdocs.yml b/mkdocs.yml
index 5324c13..3c55786 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -177,6 +177,7 @@ nav:
       - Prompt Creation: "reference/processing/local_steps/prompter.md"
       - RAFT: "reference/processing/local_steps/raft.md"
     - Global Steps:
+      - Aggregation and merging: "reference/processing/global_steps/aggregation.md"
       - Sampling and Fewshot: "reference/processing/global_steps/sampling.md"
       - Output: "reference/processing/global_steps/output.md"
     - Answer Processors:
diff --git a/ragfoundry/processing/global_steps/aggregation.py b/ragfoundry/processing/global_steps/aggregation.py
new file mode 100644
index 0000000..b53f873
--- /dev/null
+++ b/ragfoundry/processing/global_steps/aggregation.py
@@ -0,0 +1,31 @@
+from datasets import concatenate_datasets
+
+from ..step import GlobalStep
+
+
+class MergeDatasets(GlobalStep):
+    """
+    Step for merging datasets.
+
+    Merge is done using concatenation. Optional shuffling by providing a seed.
+    """
+
+    def __init__(self, output, shuffle=None, **kwargs):
+        """
+        Args:
+            output (str): Name of the output dataset. Should be unique.
+            shuffle (int, optional): seed for shuffling. Default is None.
+        """
+        super().__init__(**kwargs)
+        self.output = output
+        self.shuffle = shuffle
+        self.completed = False
+        self.cache_step = False
+
+    def process(self, dataset_name, datasets, **kwargs):
+        if not self.completed:
+            data = concatenate_datasets([datasets[name] for name in self.inputs])
+            if self.shuffle:
+                data = data.shuffle(self.shuffle)
+            datasets[self.output] = data
+            self.completed = True