-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathfactchecker.py
46 lines (37 loc) · 2.03 KB
/
factchecker.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
from langchain_community.llms import OpenAI
from langchain.chains import LLMChain, LLMSummarizationCheckerChain, SequentialChain
from langchain_core.prompts import PromptTemplate
from pymongo import MongoClient
import os
from dotenv import load_dotenv
def fact_check(answer):
    """Run a three-stage LLM pipeline that re-examines *answer*.

    The stages are chained with a ``SequentialChain``:

    1. list the assumptions made in ``answer`` (output key ``assertions``);
    2. judge each assertion true or false (output key ``facts``);
    3. produce a revised answer in light of those judgements
       (output key ``output``).

    Args:
        answer: The answer text to check; it is substituted into the first
            prompt as ``{answer}``.

    Returns:
        Whatever calling the ``SequentialChain`` returns. The chain is
        configured to expose ``assertions``, ``facts`` and ``output``.
        NOTE(review): the exact mapping shape depends on the installed
        LangChain version -- confirm.
    """
    model = OpenAI(temperature=0)

    # Stage 1: extract the assumptions behind the answer.
    assumptions_prompt = PromptTemplate(
        input_variables=["answer"],
        template=(
            "Here is the answer: {answer} to the question. "
            "Make a bullet point list of the assumptions in the answer.\n\n"
        ),
    )
    assumptions_chain = LLMChain(
        llm=model,
        prompt=assumptions_prompt,
        output_key="assertions",
    )

    # Stage 2: judge every assertion produced by stage 1.
    verification_prompt = PromptTemplate(
        input_variables=["assertions"],
        template=(
            "Here is a bullet point list of assertions:\n"
            "{assertions}\n"
            "For each assertion, determine whether it is true or false. "
            "If it is false, explain why.\n\n"
        ),
    )
    fact_checker_chain = LLMChain(
        llm=model,
        prompt=verification_prompt,
        output_key="facts",
    )

    # Stage 3: answer again, with the verified facts placed first.
    revision_prompt = PromptTemplate(
        input_variables=["facts"],
        template=(
            "{facts}\n"
            "In light of the above facts, how would you answer the question "
            "and explain your answer."
        ),
    )
    answer_chain = LLMChain(
        llm=model,
        prompt=revision_prompt,
        output_key="output",
    )

    pipeline = SequentialChain(
        chains=[assumptions_chain, fact_checker_chain, answer_chain],
        input_variables=["answer"],
        output_variables=["assertions", "facts", "output"],
        verbose=True,
    )
    return pipeline({"answer": answer})
def checker_chain(answer):
    """Check *answer* with an LLM summarization checker, store and print it.

    The function loads environment variables from a ``.env`` file, runs
    ``LLMSummarizationCheckerChain`` (at most 3 check passes) over
    ``answer``, inserts ``{"answer": ..., "checks": ...}`` into the
    ``checks`` collection of the ``answer_checks`` MongoDB database, and
    prints the checker output.

    Args:
        answer: The text to be checked; passed to the checker chain's
            ``run`` method unchanged.

    Returns:
        None. The checker output is printed and persisted, not returned.
    """
    # Load .env first so MONGO_CONNECTION_STRING (and presumably the OpenAI
    # credentials) are present in the environment before they are read.
    load_dotenv()

    # NOTE(review): if MONGO_CONNECTION_STRING is unset, os.getenv returns
    # None and MongoClient is constructed without an explicit URI -- confirm
    # that falling back to the driver's default host is intended.
    #
    # The client is used as a context manager so its connections are closed
    # when the function finishes, including when the LLM call or the insert
    # raises; previously the client was never closed.
    with MongoClient(os.getenv("MONGO_CONNECTION_STRING")) as client:
        collection = client["answer_checks"]["checks"]

        llm = OpenAI(temperature=0)
        summary_checker = LLMSummarizationCheckerChain.from_llm(
            llm, max_checks=3, verbose=True
        )
        checks = summary_checker.run(answer)

        collection.insert_one({"answer": answer, "checks": checks})

    print(checks)