New Features (#14)
- RecallEM metric.
- Aggregation steps: filtering, column selection, tagging, value overwrite.
- Local inference step using vLLM; it can be used to generate synthetic datasets.
- Minor modifications to the QA system instructions.
- Ruff configuration file.
- Evaluation split in the training script.
danielfleischer authored Nov 12, 2024
1 parent f21cd32 commit b5ed97f
Showing 17 changed files with 272 additions and 20 deletions.
5 changes: 4 additions & 1 deletion .gitignore
@@ -1,4 +1,7 @@
/.python-version
/outputs/
__pycache__/
/site/
/site/
/multirun/
wandb
.ipynb_checkpoints
40 changes: 40 additions & 0 deletions configs/processing-nq.yaml
@@ -0,0 +1,40 @@
name: nq
cache: false
output_path: .
steps:
  - _target_: ragfit.processing.dataset_loaders.loaders.HFLoader
    inputs: train
    dataset_config:
      path: Tevatron/wikipedia-nq
      split: train

  - _target_: ragfit.processing.global_steps.sampling.ShuffleSelect
    inputs: train
    shuffle: 42
    limit: 10000

  - _target_: ragfit.processing.local_steps.prompter.TextPrompter
    inputs: train
    prompt_file: ragfit/processing/prompts/qa-short.txt
    output_key: prompt
    mapping:
      query: query

  - _target_: ragfit.processing.local_steps.inference.HFStep
    inputs: train
    input_key: prompt
    output_key: generated
    model_kwargs:
      model_name_or_path: meta-llama/Meta-Llama-3.1-8B-Instruct
      instruction: ragfit/processing/prompts/prompt_instructions/qa-short.txt
      num_gpus: 2
      llm_params:
        dtype: auto
        max_model_len: 4096
      generation:
        temperature: 0
        max_tokens: 50

  - _target_: ragfit.processing.global_steps.output.OutputData
    inputs: train
    prefix: nq-with-answers
1 change: 1 addition & 0 deletions docs/reference/processing/global_steps/filters.md
@@ -0,0 +1 @@
::: ragfit.processing.global_steps.filters
1 change: 1 addition & 0 deletions docs/reference/processing/local_steps/inference.md
@@ -0,0 +1 @@
::: ragfit.processing.local_steps.inference
11 changes: 7 additions & 4 deletions evaluation.py
@@ -1,6 +1,7 @@
import logging
import os
from collections import defaultdict
from pathlib import Path

import hydra
import torch
@@ -93,10 +94,12 @@ def map_load(example, idx):
    if args.use_wandb:
        run.log(results, step=0)

    if args.results_file:
        with open(args.results_file, "w") as f:
            yaml.dump(results, f, sort_keys=True)
        logging.info(f"Results saved to {args.results_file}")
    if args.results_file is None:
        args.results_file = Path(args.generated_file).stem + "-results.yaml"

    with open(args.results_file, "w") as f:
        yaml.dump(results, f, sort_keys=True)
    logging.info(f"Results saved to {args.results_file}")


if __name__ == "__main__":
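With this change, evaluation always writes a results file: when results_file is unset, the name is derived from the generated-output file. A minimal sketch of the naming rule in Python, using a hypothetical generated-file name:

from pathlib import Path

# Hypothetical generated-output file produced by an inference run.
generated_file = "nq-with-answers-dev-generated.jsonl"

# Same derivation as the new default in evaluation.py: strip the extension, append "-results.yaml".
results_file = Path(generated_file).stem + "-results.yaml"
print(results_file)  # nq-with-answers-dev-generated-results.yaml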
6 changes: 0 additions & 6 deletions pyproject.toml
@@ -40,9 +40,3 @@ haystack = [
"qdrant-haystack>=5.0.0",
]

[tool.ruff]
line-length = 90

[tool.ruff.lint]
select = ["E", "F", "W", "I", "N", "Q"]
ignore = ["E203", "F841", "E501"]
81 changes: 78 additions & 3 deletions ragfit/evaluation/metrics.py
@@ -1,7 +1,10 @@
import re
import string
import unicodedata
from collections import Counter, defaultdict

import regex

from .base import MetricBase


@@ -71,16 +74,21 @@ def __init__(
        self.precision_recall_fn = precision_recall_fscore_support
        self.accuracy_fn = accuracy_score

    def in_text(self, text):
        if "yes" in text:
            return 1
        if "no" in text:
            return 0
        return 2

    def measure(self, example: dict):
        inputs = example[self.field]
        targets = example[self.target]

        if isinstance(targets[0], list):
            targets = [t[0] for t in targets]

        inputs = [
            self.mapping.get(normalize_text(i).strip(), self.else_value) for i in inputs
        ]
        inputs = [self.in_text(normalize_text(i).strip()) for i in inputs]

        targets = [
            self.mapping.get(normalize_text(t).strip(), self.else_value) for t in targets
@@ -222,6 +230,73 @@ def measure(self, example: dict):
        return {"StringEM": sum(scores) / len(scores)}


class SimpleTokenizer(object):
    ALPHA_NUM = r"[\p{L}\p{N}\p{M}]+"
    NON_WS = r"[^\p{Z}\p{C}]"

    def __init__(self):
        """
        Args:
            annotators: None or empty set (only tokenizes).
        """
        self._regexp = regex.compile(
            "(%s)|(%s)" % (self.ALPHA_NUM, self.NON_WS),
            flags=regex.IGNORECASE + regex.UNICODE + regex.MULTILINE,
        )

    def tokenize(self, text, uncased=False):
        matches = [m for m in self._regexp.finditer(text)]
        if uncased:
            tokens = [m.group().lower() for m in matches]
        else:
            tokens = [m.group() for m in matches]
        return tokens


class RecallEM(MetricBase):
    """
    Implementing EM as in XRAG.
    """

    def __init__(self, key_names, **kwargs) -> None:
        """Initialize the Metrics class.
        Args:
            key_names (dict): A dictionary containing the field names.
        """
        super().__init__(key_names, **kwargs)
        self.local = True

    @staticmethod
    def _normalize(text):
        return unicodedata.normalize("NFD", text)

    def has_answer(self, answers, text, tokenizer=SimpleTokenizer()):
        """Check if a document contains an answer string."""
        text = self._normalize(text)
        text = tokenizer.tokenize(text, uncased=True)

        for answer in answers:
            answer = self._normalize(answer)
            answer = tokenizer.tokenize(answer, uncased=True)
            for i in range(0, len(text) - len(answer) + 1):
                if answer == text[i : i + len(answer)]:
                    return True
        return False

    def measure(self, example: dict):
        input = example[self.field]
        target = example[self.target]

        assert isinstance(input, str), f"Generated text should be a string: {input}"

        if not isinstance(target, list):
            target = [target]

        scores = self.has_answer(target, input)
        return {"recallEM": int(scores)}


class BERTScore(MetricBase):
    """
    BERTScore metric, based on the BERTScore library.
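At its core, RecallEM is a token-level containment test: the answers and the generated text are Unicode-normalized (NFD), tokenized with the regex tokenizer, lowercased, and the score is 1 if any answer occurs as a contiguous token span in the generation. A minimal standalone sketch of that check, independent of MetricBase; the example strings are made up:

import unicodedata

import regex

# Same token pattern as SimpleTokenizer: runs of letters/digits/marks, or any single non-space symbol.
TOKEN_RE = regex.compile(
    r"([\p{L}\p{N}\p{M}]+)|([^\p{Z}\p{C}])",
    flags=regex.IGNORECASE + regex.UNICODE + regex.MULTILINE,
)


def tokens(text):
    # NFD-normalize, tokenize, and lowercase, as SimpleTokenizer/RecallEM do.
    text = unicodedata.normalize("NFD", text)
    return [m.group().lower() for m in TOKEN_RE.finditer(text)]


def recall_em(answers, generated):
    # 1 if any answer appears as a contiguous token span in the generated text, else 0.
    text = tokens(generated)
    for answer in (tokens(a) for a in answers):
        for i in range(len(text) - len(answer) + 1):
            if text[i : i + len(answer)] == answer:
                return 1
    return 0


print(recall_em(["Yerba Buena"], "San Francisco was originally named Yerba Buena."))  # 1
print(recall_em(["Paris"], "The capital is Lyon."))  # 0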
60 changes: 60 additions & 0 deletions ragfit/processing/global_steps/aggregation.py
@@ -1,6 +1,42 @@
from datasets import concatenate_datasets

from ..step import GlobalStep
from .filters import filters


class FilterDataset(GlobalStep):
    """
    Step for filtering a dataset.
    """

    def __init__(self, filter_fn, **kwargs):
        """
        Args:
            filter_fn (function): Function to filter the dataset.
        """
        super().__init__(**kwargs)
        self.filter_fn = filters[filter_fn]

    def process(self, dataset_name, datasets, **kwargs):
        datasets[dataset_name] = datasets[dataset_name].filter(self.filter_fn)


class SelectColumns(GlobalStep):
    """
    Step for selecting specified columns in a dataset.
    """

    def __init__(self, columns: list[str], **kwargs):
        """
        Args:
            columns (list): List of keys to keep in the dataset.
        """
        super().__init__(**kwargs)
        assert isinstance(columns, list), "columns should be a list of strings."
        self.columns = columns

    def process(self, dataset_name, datasets, **kwargs):
        datasets[dataset_name] = datasets[dataset_name].select_columns(self.columns)


class MergeDatasets(GlobalStep):
@@ -29,3 +65,27 @@ def process(self, dataset_name, datasets, **kwargs):
            data = data.shuffle(self.shuffle)
        datasets[self.output] = data
        self.completed = True


class DatasetTagger(GlobalStep):
    """
    Class to tag each example with the dataset name. Useful when running aggregations.
    """

    def __init__(self, keyword="source", **kwargs):
        """
        Args:
            keyword (str): The key to use for tagging. Default is "source".
        """
        super().__init__(**kwargs)
        self.keyword = keyword

    def tag(self, item, dataset_name):
        item[self.keyword] = dataset_name
        return item

    def process(self, dataset_name, datasets, **kwargs):
        datasets[dataset_name] = datasets[dataset_name].map(
            lambda item: self.tag(item, dataset_name),
            load_from_cache_file=False,
        )
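These global steps are configured the same way as the steps in configs/processing-nq.yaml above. A hedged sketch of how the new filtering, column-selection and tagging steps could appear in a processing config, assuming a hypothetical MS MARCO pipeline whose dataset is named msmarco and has query/answers/passages columns:

# Keep only examples with at least one selected passage (registered "MSMARCO" filter).
- _target_: ragfit.processing.global_steps.aggregation.FilterDataset
  inputs: msmarco
  filter_fn: MSMARCO

# Drop everything except the columns needed downstream (column names are hypothetical).
- _target_: ragfit.processing.global_steps.aggregation.SelectColumns
  inputs: msmarco
  columns: [query, answers, passages]

# Tag every example with its dataset of origin before merging datasets.
- _target_: ragfit.processing.global_steps.aggregation.DatasetTagger
  inputs: msmarco
  keyword: source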
8 changes: 8 additions & 0 deletions ragfit/processing/global_steps/filters.py
@@ -0,0 +1,8 @@
"""Module containing filters"""


def msmarco_positive_filter(x):
    return 1 in x["passages"]["is_selected"]


filters = {"MSMARCO": msmarco_positive_filter}
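For reference, the registered MSMARCO filter keeps only examples where at least one retrieved passage is marked as selected. A toy illustration with made-up example dicts:

from ragfit.processing.global_steps.filters import msmarco_positive_filter

print(msmarco_positive_filter({"passages": {"is_selected": [0, 1, 0]}}))  # True
print(msmarco_positive_filter({"passages": {"is_selected": [0, 0, 0]}}))  # False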
2 changes: 1 addition & 1 deletion ragfit/processing/global_steps/sampling.py
@@ -24,7 +24,7 @@ def process_all(self, dataset, datasets, **kwargs):
        if self.shuffle:
            dataset = dataset.shuffle(seed=self.shuffle)
        if self.limit:
            dataset = dataset.select(range(self.limit))
            dataset = dataset.select(range(min(len(dataset), self.limit)))
        return dataset


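The new guard matters when the configured limit exceeds the dataset size, since selecting more rows than exist can fail; the limit is now capped at the dataset length. A toy illustration with made-up numbers:

limit, n_rows = 10_000, 3_000  # hypothetical: configured limit larger than the dataset
print(range(min(n_rows, limit)))  # range(0, 3000) instead of an out-of-bounds range(0, 10000)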
20 changes: 20 additions & 0 deletions ragfit/processing/local_steps/formatting.py
@@ -40,3 +40,23 @@ def __init__(self, input_key, output_key, string_join=", ", **kwargs):
    def process_item(self, item, index, datasets, **kwargs):
        item[self.output_key] = self.string_join.join(item[self.input_key])
        return item


class UpdateField(LocalStep):
    """
    Class to update a field in the dataset with a new value.
    """

    def __init__(self, input_key: str, value, **kwargs):
        """
        Args:
            input_key (str): example key to change.
            value: New value to set for the field.
        """
        super().__init__(**kwargs)
        self.input_key = input_key
        self.value = value

    def process_item(self, item, index, datasets, **kwargs):
        item[self.input_key] = self.value
        return item
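UpdateField is the value-overwrite step from the commit message. In a processing config it could look like the following sketch; the key and value here are hypothetical:

- _target_: ragfit.processing.local_steps.formatting.UpdateField
  inputs: train
  input_key: task
  value: qa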
33 changes: 33 additions & 0 deletions ragfit/processing/local_steps/inference.py
@@ -0,0 +1,33 @@
"""Module for inference steps, which can use LLM output to augment the data."""

from ragfit.models.vllm import VLLMInference

from ..step import LocalStep


class HFStep(LocalStep):
    """
    Class for running inference with a Hugging Face model based on the vLLM engine.
    """

    def __init__(self, input_key, output_key, model_kwargs, **kwargs):
        """
        Initialize the HFStep class.
        Args:
            input_key (str): The key for the input text to be used as the prompt.
            output_key (str): The key for saving the generated text.
            model_kwargs (dict): The keyword arguments to pass to the vLLM model.
            **kwargs: Additional keyword arguments to pass to the LocalStep.
        """
        super().__init__(**kwargs)
        self.input_key = input_key
        self.output_key = output_key
        self.model_kwargs = model_kwargs
        self.model = VLLMInference(**model_kwargs)

    def process_item(self, item, index, datasets, **kwargs):
        prompt = item[self.input_key]
        response = self.model.generate(prompt)
        item[self.output_key] = response
        return item
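HFStep is what the configs/processing-nq.yaml step shown earlier in this commit instantiates through Hydra. A hedged Python sketch of driving it directly with the same model_kwargs; it assumes the base LocalStep accepts the usual inputs keyword (as in the YAML configs) and that a GPU with vLLM installed is available:

from ragfit.processing.local_steps.inference import HFStep

step = HFStep(
    inputs="train",  # assumed LocalStep kwarg, mirroring the YAML configs
    input_key="prompt",
    output_key="generated",
    model_kwargs={
        "model_name_or_path": "meta-llama/Meta-Llama-3.1-8B-Instruct",
        "instruction": "ragfit/processing/prompts/prompt_instructions/qa-short.txt",
        "num_gpus": 1,  # the NQ config uses 2; reduced here for a single-GPU sketch
        "llm_params": {"dtype": "auto", "max_model_len": 4096},
        "generation": {"temperature": 0, "max_tokens": 50},
    },
)

# Run the step on a single toy item; in a pipeline this happens per example.
item = step.process_item({"prompt": "Question: Who wrote Hamlet?\nAnswer:"}, 0, {})
print(item["generated"])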
5 changes: 3 additions & 2 deletions ragfit/processing/pipeline.py
Expand Up @@ -57,8 +57,9 @@ def gen_cache_fn(self, step, index, dataset_name):
        Returns a string.
        """
        return (
            f"{self.output_path}/{self.name}"
            f"_{index}_{type(step).__name__}"
            f"{self.output_path}/cache"
            f"_{self.name}_{index}"
            f"_{type(step).__name__}"
            f"_{dataset_name}_{step.get_hash()}.json"
        )

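Cache files now carry an explicit cache prefix ahead of the pipeline name and step index. A small illustration of the resulting filename, using values from the NQ config above and a made-up step hash:

output_path, name, index = ".", "nq", 1
step_name, dataset_name, step_hash = "ShuffleSelect", "train", "3f9a0c"  # hash is illustrative
cache_fn = f"{output_path}/cache_{name}_{index}_{step_name}_{dataset_name}_{step_hash}.json"
print(cache_fn)  # ./cache_nq_1_ShuffleSelect_train_3f9a0c.json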
2 changes: 1 addition & 1 deletion ragfit/processing/prompts/prompt_instructions/qa-short.txt
@@ -1 +1 @@
You are a helpful question answerer who can provide an answer given a question and relevant context. Please answer shortly as possible and don't repeat the question.
You are a helpful question answerer who can provide an answer given a question and relevant context. Answer the following question with a short span. The answer needs to be just in a few words.
1 change: 1 addition & 0 deletions (new prompt instructions file)
@@ -0,0 +1 @@
You are a helpful question answerer who can provide an answer given a question and relevant context. Please answer with "yes", "no" or "maybe", if there is not enough information to answer the question.
6 changes: 6 additions & 0 deletions ruff.toml
@@ -0,0 +1,6 @@
line-length = 90

[lint]
select = ["E", "F", "W", "I", "N", "Q"]
ignore = ["E203", "F841", "E501", "F821"]
exclude = ["*.ipynb"]