Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add FaqGen Accuracy scripts & Refine Ragas #91

Merged
merged 5 commits into from
Sep 5, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 17 additions & 9 deletions evals/metrics/ragas/ragas.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,14 @@ def __init__(
self.embeddings = embeddings
self.metrics = metrics
self.validated_list = [
"answer_relevancy",
"faithfulness",
"answer_correctness",
"answer_relevancy",
"answer_similarity",
"context_precision",
"context_relevancy",
"context_recall",
"faithfulness",
"context_utilization",
"reference_free_rubrics_score",
]

async def a_measure(self, test_case: Dict):
Expand All @@ -55,8 +56,9 @@ def measure(self, test_case: Dict):
answer_similarity,
context_precision,
context_recall,
context_relevancy,
context_utilization,
faithfulness,
reference_free_rubrics_score,
)

except ModuleNotFoundError:
Expand All @@ -67,8 +69,14 @@ def measure(self, test_case: Dict):
except ModuleNotFoundError:
raise ModuleNotFoundError("Please install dataset")
self.metrics_instance = {
"answer_correctness": answer_correctness,
"answer_relevancy": answer_relevancy,
"answer_similarity": answer_similarity,
"context_precision": context_precision,
"context_recall": context_recall,
"faithfulness": faithfulness,
"context_utilization": context_utilization,
"reference_free_rubrics_score": reference_free_rubrics_score,
}

# Set LLM model
Expand Down Expand Up @@ -101,7 +109,7 @@ def measure(self, test_case: Dict):
else:
if metric == "answer_relevancy" and self.embeddings is None:
raise ValueError("answer_relevancy metric need provide embeddings model.")
tmp_metrics.append(metric)
tmp_metrics.append(self.metrics_instance[metric])
self.metrics = tmp_metrics
else:
self.metrics = [
Expand All @@ -115,10 +123,10 @@ def measure(self, test_case: Dict):
]

data = {
"question": test_case["input"],
"contexts": test_case["retrieval_context"],
"answer": test_case["actual_output"],
"ground_truth": test_case["expected_output"],
"question": test_case["question"],
"contexts": test_case["contexts"],
"answer": test_case["answer"],
"ground_truth": test_case["ground_truth"],
}
dataset = Dataset.from_dict(data)

Expand Down
61 changes: 61 additions & 0 deletions examples/FaqGen/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
## Dataset
We evaluate performance on QA dataset [Squad_v2](https://huggingface.co/datasets/rajpurkar/squad_v2). Generate FAQs on "context" columns in validation dataset, which contains 1204 unique records.

First download dataset and put at "./data".

Extract unique "context" columns, which will be save to 'data/sqv2_context.json':
```
python get_context.py
```

## Generate FAQs

### Launch FaQGen microservice
Please refer to [FaQGen microservice](https://github.com/opea-project/GenAIComps/tree/main/comps/llms/faq-generation/tgi), set up an microservice endpoint.
```
export FAQ_ENDPOINT = "http://${your_ip}:9000/v1/faqgen"
```

### Generate FAQs with microservice
Use the microservice endpoint to generate FAQs for dataset.
```
python generate_FAQ.py
```

Post-process the output to get the right data, which will be save to 'data/sqv2_faq.json'.
```
python post_process_FAQ.py
```

## Evaluate with Ragas

### Launch TGI service
We use "mistralai/Mixtral-8x7B-Instruct-v0.1" as LLM referee to evaluate the model. First we need to launch a LLM endpoint on Gaudi.
```
export HUGGING_FACE_HUB_TOKEN="your_huggingface_token"
bash launch_tgi.sh
```
Get the endpoint:
```
export LLM_ENDPOINT = "http://${ip_address}:8082"
```

Verify the service:
```bash
curl http://${ip_address}:8082/generate \
-X POST \
-d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":128}}' \
-H 'Content-Type: application/json'
```

### Evaluate
evaluate the performance with the LLM:
```
python evaluate.py
```

### Performance Result
Here is the tested result for your reference
| answer_relevancy | faithfulness | context_utilization | reference_free_rubrics_score |
| ---- | ---- |---- |---- |
| 0.7191 | 0.9681 | 0.8964 | 4.4125|
45 changes: 45 additions & 0 deletions examples/FaqGen/evaluate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import json
import os

from langchain_community.embeddings import HuggingFaceBgeEmbeddings

from evals.metrics.ragas import RagasMetric

llm_endpoint = os.getenv("LLM_ENDPOINT", "http://0.0.0.0:8082")

f = open("data/sqv2_context.json", "r")
sqv2_context = json.load(f)

f = open("data/sqv2_faq.json", "r")
sqv2_faq = json.load(f)

templ = """Create a concise FAQs (frequently asked questions and answers) for following text:
TEXT: {text}
Do not use any prefix or suffix to the FAQ.
"""

number = 1204
question = []
answer = []
ground_truth = ["None"] * number
contexts = []
for i in range(number):
inputs = sqv2_context[str(i)]
inputs_faq = templ.format_map({"text": inputs})
actual_output = sqv2_faq[str(i)]

question.append(inputs_faq)
answer.append(actual_output)
contexts.append([inputs_faq])

embeddings = HuggingFaceBgeEmbeddings(model_name="BAAI/bge-base-en-v1.5")
metrics_faq = ["answer_relevancy", "faithfulness", "context_utilization", "reference_free_rubrics_score"]
metric = RagasMetric(threshold=0.5, model=llm_endpoint, embeddings=embeddings, metrics=metrics_faq)

test_case = {"question": question, "answer": answer, "ground_truth": ground_truth, "contexts": contexts}

metric.measure(test_case)
print(metric.score)
28 changes: 28 additions & 0 deletions examples/FaqGen/generate_FAQ.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import json
import os
import time

import requests

llm_endpoint = os.getenv("FAQ_ENDPOINT", "http://0.0.0.0:9000/v1/faqgen")

f = open("data/sqv2_context.json", "r")
sqv2_context = json.load(f)

start_time = time.time()
headers = {"Content-Type": "application/json"}
for i in range(1204):
start_time_tmp = time.time()
print(i)
inputs = sqv2_context[str(i)]
data = {"query": inputs, "max_new_tokens": 128}
response = requests.post(llm_endpoint, json=data, headers=headers)
f = open(f"data/result/sqv2_faq_{i}", "w")
f.write(inputs)
f.write(str(response.content, encoding="utf-8"))
f.close()
print(f"Cost {time.time()-start_time_tmp} seconds")
print(f"\n Finished! \n Totally Cost {time.time()-start_time} seconds\n")
17 changes: 17 additions & 0 deletions examples/FaqGen/get_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import json
import os

import pandas as pd

data_path = "./data"
data = pd.read_parquet(os.path.join(data_path, "squad_v2/squad_v2/validation-00000-of-00001.parquet"))
sq_context = list(data["context"].unique())
sq_context_d = dict()
for i in range(len(sq_context)):
sq_context_d[i] = sq_context[i]

with open(os.path.join(data_path, "sqv2_context.json"), "w") as outfile:
json.dump(sq_context_d, outfile)
28 changes: 28 additions & 0 deletions examples/FaqGen/launch_tgi.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

max_input_tokens=3072
max_total_tokens=4096
port_number=8082
model_name="mistralai/Mixtral-8x7B-Instruct-v0.1"
volume="./data"
docker run -it --rm \
--name="tgi_Mixtral" \
-p $port_number:80 \
-v $volume:/data \
--runtime=habana \
--restart always \
-e HUGGING_FACE_HUB_TOKEN=$HUGGING_FACE_HUB_TOKEN \
-e HABANA_VISIBLE_DEVICES=all \
-e OMPI_MCA_btl_vader_single_copy_mechanism=none \
-e PT_HPU_ENABLE_LAZY_COLLECTIVES=true \
--cap-add=sys_nice \
--ipc=host \
-e HTTPS_PROXY=$https_proxy \
-e HTTP_PROXY=$https_proxy \
ghcr.io/huggingface/tgi-gaudi:2.0.1 \
--model-id $model_name \
--max-input-tokens $max_input_tokens \
--max-total-tokens $max_total_tokens \
--sharded true \
--num-shard 2
27 changes: 27 additions & 0 deletions examples/FaqGen/post_process_FAQ.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import json

faq_dict = {}
fails = []
for i in range(1204):
data = open(f"data/result/sqv2_faq_{i}", "r").readlines()
result = data[-6][6:]
# print(result)
if "LLMChain/final_output" not in result:
print(f"error1: fail for {i}")
fails.append(i)
continue
try:
result2 = json.loads(result)
result3 = result2["ops"][0]["value"]["text"]
faq_dict[str(i)] = result3
except:
print(f"error2: fail for {i}")
fails.append(i)
continue
with open("data/sqv2_faq.json", "w") as outfile:
json.dump(faq_dict, outfile)
print("Failure index:")
print(fails)
8 changes: 4 additions & 4 deletions tests/test_ragas.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@ def test_ragas(self):
embeddings = HuggingFaceBgeEmbeddings(model_name="BAAI/bge-base-en-v1.5")
metric = RagasMetric(threshold=0.5, model="http://localhost:8008", embeddings=embeddings)
test_case = {
"input": ["What if these shoes don't fit?"],
"actual_output": [actual_output],
"expected_output": [expected_output],
"retrieval_context": [retrieval_context],
"question": ["What if these shoes don't fit?"],
"answer": [actual_output],
"ground_truth": [expected_output],
"contexts": [retrieval_context],
}

metric.measure(test_case)
Expand Down