Add deprecation warnings for passing question to generate #71

Open: wants to merge 1 commit into base: main
105 changes: 69 additions & 36 deletions src/beyondllm/generator/generate.py
@@ -1,13 +1,18 @@
from beyondllm.llms import GeminiModel
from beyondllm.utils import CONTEXT_RELEVENCE,GROUNDEDNESS,ANSWER_RELEVENCE, GROUND_TRUTH

Review comment (Member): resolve merge conflict
import os,re
import os
import re
import warnings
import numpy as np
from typing import List
from typing import List, Dict
import pysbd
from .base import BaseGenerator,GeneratorConfig
from dataclasses import dataclass,field
from dataclasses import dataclass, field

from beyondllm.llms import GeminiModel
from beyondllm.utils import CONTEXT_RELEVENCE, GROUNDEDNESS, ANSWER_RELEVENCE, GROUND_TRUTH
from beyondllm.memory.base import BaseMemory
from .base import BaseGenerator, GeneratorConfig


warnings.filterwarnings('always', category=DeprecationWarning, module=__name__)
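
A note on the `'always'` filter above: since Python 3.2, `DeprecationWarning` is ignored by default except for code run directly under `__main__`, so without this filter most library users would never see the new warnings. A minimal, self-contained sketch of the same pattern (`old_api`/`new_api` are illustrative names, not part of this PR):

```python
import warnings

# Make DeprecationWarning from this module always visible; by default
# Python only shows it for code triggered directly from __main__.
warnings.filterwarnings('always', category=DeprecationWarning, module=__name__)

def old_api():
    # stacklevel=2 attributes the warning to the caller's line,
    # not to this function body.
    warnings.warn(
        "old_api() is deprecated; use new_api() instead.",
        DeprecationWarning,
        stacklevel=2,
    )

old_api()  # emits: DeprecationWarning: old_api() is deprecated; use new_api() instead.
```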

def default_llm():
api_key = os.getenv('GOOGLE_API_KEY')
@@ -34,34 +39,42 @@ def thresholdCheck(score):
@dataclass
class Generate:
"""
### Example:
### Example for new usage:
```
>>> from beyondllm.generator import generate
>>> pipeline = generate(retriever=retriever,question) #llm = ChatOpenAI
>>> from beyondllm.generator import Generate
>>> generator = Generate(retriever=retriever)
>>> response = generator.call("What is the capital of France?")
```
### for prediction:

### Example for old usage (deprecated):
```
>>> response = pipeline.call() # rag response
>>> rag_triad = pipeline.rag_triad_evals_report()
>>> context_relevancy = pipeline.get_context_relevancy_score()
>>> answer_relevancy = pipeline.get_answer_relevancy_score()
>>> groundness = pipeline.get_groundedness()
>>> from beyondllm.generator import Generate
>>> generator = Generate(retriever=retriever, question="What is the capital of France?")
>>> response = generator.call()
```
### for evaluation:
```
>>> rag_triad = generator.rag_triad_evals()
>>> context_relevancy = generator.get_context_relevancy()
>>> answer_relevancy = generator.get_answer_relevancy()
>>> groundness = generator.get_groundedness()
```
"""
question: str
question: str = None
system_prompt: str = None
retriever:str = ''
retriever: str = ''
llm: GeminiModel = field(default_factory=default_llm)
memory: BaseMemory = None

def __post_init__(self):
self.pipeline()

def pipeline(self):
self.CONTEXT = [node_with_score.node.text for node_with_score in self.retriever.retrieve(self.question)]
temp = ".".join(self.CONTEXT)

if self.question is not None:
warnings.warn(
"Initializing Generate with 'question' is deprecated and will be removed in a future version. "
"Use the new 'call' method with a question parameter instead.",
DeprecationWarning,
stacklevel=2
)
self._legacy_init()

if self.system_prompt is None:
self.system_prompt = """
You are an AI assistant who always answer to the user QUERY within the given CONTEXT \
@@ -72,6 +85,29 @@ def pipeline(self):
If you FAIL to execute this task, you will be fired and you will suffer
"""

def _legacy_init(self):
self.CONTEXT = [node_with_score.node.text for node_with_score in self.retriever.retrieve(self.question)]
self.RESPONSE = self._generate_response(self.question)

def call(self, question: str = None) -> str:
if question is None:
if self.question is None:
raise ValueError("No question provided. Either initialize with a question or pass one to call().")
warnings.warn(
"Using call() without a question parameter is deprecated and will be removed in a future version. "
"Please use call(question) instead.",
DeprecationWarning,
stacklevel=2
)
return self.RESPONSE

self.question = question
self.CONTEXT = [node_with_score.node.text for node_with_score in self.retriever.retrieve(self.question)]
self.RESPONSE = self._generate_response(question)
return self.RESPONSE
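
For reviewers, the two supported call paths after this change would look roughly like this (a usage sketch, not part of the diff; `retriever` stands in for any configured beyondllm retriever, and the default Gemini model assumes GOOGLE_API_KEY is set):

```python
from beyondllm.generator import Generate

# New style: construct without a question, pass it at call time.
generator = Generate(retriever=retriever)
answer = generator.call("What is the capital of France?")

# Old style: still works for now, but warns once at construction
# (question passed to __init__) and again if call() is used bare.
legacy = Generate(retriever=retriever, question="What is the capital of France?")
answer = legacy.call()
```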

def _generate_response(self, question: str) -> str:
temp = ".".join(self.CONTEXT)
memory_content = ""
if self.memory is not None:
memory_content = self.memory.get_memory()
@@ -82,17 +118,15 @@ def pipeline(self):
CONTEXT: {temp}
--------------------
CHAT HISTORY: {memory_content}
QUERY: {self.question}
QUERY: {question}
"""

self.RESPONSE = self.llm.predict(template)
# Store the question and response in memory
response = self.llm.predict(template)

if self.memory is not None:
self.memory.add_to_memory(question=self.question, response=self.RESPONSE)
return self.CONTEXT,self.RESPONSE

def call(self):
return self.RESPONSE
self.memory.add_to_memory(question=question, response=response)

return response
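
One way to verify the new warnings fire as intended is to assert on them directly; a test sketch, assuming hypothetical `retriever` and `stub_llm` fixtures that fake retrieval and prediction (none of this is part of the PR):

```python
import pytest
from beyondllm.generator import Generate

def test_question_in_init_warns(retriever, stub_llm):
    # Passing `question` at construction should emit the warning.
    with pytest.warns(DeprecationWarning):
        Generate(retriever=retriever, llm=stub_llm, question="What is 2 + 2?")

def test_bare_call_warns(retriever, stub_llm):
    gen = Generate(retriever=retriever, llm=stub_llm, question="What is 2 + 2?")
    # call() with no argument falls back to the stored question
    # and should warn as well.
    with pytest.warns(DeprecationWarning):
        gen.call()
```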

def get_rag_triad_evals(self, llm = None):
if llm is None:
@@ -127,7 +161,7 @@ def get_answer_relevancy(self, llm = None):
if llm is None:
llm = self.llm
try:
score_str = llm.predict(ANSWER_RELEVENCE.format(question=self.question, context= self.RESPONSE))
score_str = llm.predict(ANSWER_RELEVENCE.format(question=self.question, context=self.RESPONSE))
score = float(extract_number(score_str))

return f"Answer relevancy Score: {round(score, 1)}\n{thresholdCheck(score)}"
@@ -141,7 +175,6 @@ def get_groundedness(self, llm = None):
statements = sent_tokenize(self.RESPONSE)
scores = []
for statement in statements:

score_response = llm.predict(GROUNDEDNESS.format(statement=statement, context=self.CONTEXT))
score = extract_number(score_response)
scores.append(score)
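
The aggregation that follows this loop is collapsed in the diff; given the numpy import, a plausible reading is a simple mean over the per-sentence scores (an assumption about the hidden lines, not something the diff shows):

```python
import numpy as np

# Hypothetical per-sentence groundedness scores, one per sentence of
# the RESPONSE as split by sent_tokenize above.
scores = [8.0, 10.0, 9.0]

# Assumption: the collapsed lines reduce the list to one number,
# most likely an average, before thresholdCheck() is applied.
overall = float(np.mean(scores))
print(f"Groundedness score: {round(overall, 1)}")
```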
@@ -163,7 +196,7 @@ def get_ground_truth(self, answer:str, llm = None):
score = extract_number(score_str)

return f"Ground truth score: {round(score, 1)}\n{thresholdCheck(score)}"


@staticmethod
def load_from_kwargs(self,kwargs):