From fcdf8fb7aaf9f3a66bcdd8ca26349a3dfc293bee Mon Sep 17 00:00:00 2001
From: taha-aiplanet <157125263+taha-aiplanet@users.noreply.github.com>
Date: Mon, 19 Aug 2024 20:53:14 +0500
Subject: [PATCH] Add deprecation warnings for passing `question` to Generate

---
 src/beyondllm/generator/generate.py | 105 ++++++++++++++++++----------
 1 file changed, 69 insertions(+), 36 deletions(-)

diff --git a/src/beyondllm/generator/generate.py b/src/beyondllm/generator/generate.py
index b9bf127..240cb29 100644
--- a/src/beyondllm/generator/generate.py
+++ b/src/beyondllm/generator/generate.py
@@ -1,13 +1,18 @@
-from beyondllm.llms import GeminiModel
-from beyondllm.utils import CONTEXT_RELEVENCE,GROUNDEDNESS,ANSWER_RELEVENCE, GROUND_TRUTH
-
-import os,re
+import os
+import re
+import warnings
 import numpy as np
-from typing import List
+from typing import List, Dict
 import pysbd
-from .base import BaseGenerator,GeneratorConfig
-from dataclasses import dataclass,field
+from dataclasses import dataclass, field
+
+from beyondllm.llms import GeminiModel
+from beyondllm.utils import CONTEXT_RELEVENCE, GROUNDEDNESS, ANSWER_RELEVENCE, GROUND_TRUTH
 from beyondllm.memory.base import BaseMemory
+from .base import BaseGenerator, GeneratorConfig
+
+
+warnings.filterwarnings('always', category=DeprecationWarning, module=__name__)
 
 def default_llm():
     api_key = os.getenv('GOOGLE_API_KEY')
@@ -34,34 +39,42 @@ def thresholdCheck(score):
 @dataclass
 class Generate:
     """
-    ### Example:
+    ### Example for new usage:
     ```
-    >>> from beyondllm.generator import generate
-    >>> pipeline = generate(retriever=retriever,question) #llm = ChatOpenAI
+    >>> from beyondllm.generator import Generate
+    >>> generator = Generate(retriever=retriever)
+    >>> response = generator.call("What is the capital of France?")
     ```
-    ### for prediction:
-
+    ### Example for old usage (deprecated):
     ```
-    >>> response = pipeline.call() # rag response
-    >>> rag_triad = pipeline.rag_triad_evals_report()
-    >>> context_relevancy = pipeline.get_context_relevancy_score()
-    >>> answer_relevancy = pipeline.get_answer_relevancy_score()
-    >>> groundness = pipeline.get_groundedness()
+    >>> from beyondllm.generator import Generate
+    >>> generator = Generate(retriever=retriever, question="What is the capital of France?")
+    >>> response = generator.call()
+    ```
+    ### Example for evaluation:
+    ```
+    >>> rag_triad = generator.get_rag_triad_evals()
+    >>> context_relevancy = generator.get_context_relevancy()
+    >>> answer_relevancy = generator.get_answer_relevancy()
+    >>> groundedness = generator.get_groundedness()
     ```
     """
-    question: str
+    question: str = None
     system_prompt: str = None
-    retriever:str = ''
+    retriever: str = ''
     llm: GeminiModel = field(default_factory=default_llm)
     memory: BaseMemory = None
 
     def __post_init__(self):
-        self.pipeline()
-
-    def pipeline(self):
-        self.CONTEXT = [node_with_score.node.text for node_with_score in self.retriever.retrieve(self.question)]
-        temp = ".".join(self.CONTEXT)
-
+        if self.question is not None:
+            warnings.warn(
+                "Initializing Generate with 'question' is deprecated and will be removed in a future version. "
" + "Use the new 'call' method with a question parameter instead.", + DeprecationWarning, + stacklevel=2 + ) + self._legacy_init() + if self.system_prompt is None: self.system_prompt = """ You are an AI assistant who always answer to the user QUERY within the given CONTEXT \ @@ -72,6 +85,29 @@ def pipeline(self): If you FAIL to execute this task, you will be fired and you will suffer """ + def _legacy_init(self): + self.CONTEXT = [node_with_score.node.text for node_with_score in self.retriever.retrieve(self.question)] + self.RESPONSE = self._generate_response(self.question) + + def call(self, question: str = None) -> str: + if question is None: + if self.question is None: + raise ValueError("No question provided. Either initialize with a question or pass one to call().") + warnings.warn( + "Using call() without a question parameter is deprecated and will be removed in a future version. " + "Please use call(question) instead.", + DeprecationWarning, + stacklevel=2 + ) + return self.RESPONSE + + self.question = question + self.CONTEXT = [node_with_score.node.text for node_with_score in self.retriever.retrieve(self.question)] + self.RESPONSE = self._generate_response(question) + return self.RESPONSE + + def _generate_response(self, question: str) -> str: + temp = ".".join(self.CONTEXT) memory_content = "" if self.memory is not None: memory_content = self.memory.get_memory() @@ -82,17 +118,15 @@ def pipeline(self): CONTEXT: {temp} -------------------- CHAT HISTORY: {memory_content} - QUERY: {self.question} + QUERY: {question} """ - self.RESPONSE = self.llm.predict(template) - # Store the question and response in memory + response = self.llm.predict(template) + if self.memory is not None: - self.memory.add_to_memory(question=self.question, response=self.RESPONSE) - return self.CONTEXT,self.RESPONSE - - def call(self): - return self.RESPONSE + self.memory.add_to_memory(question=question, response=response) + + return response def get_rag_triad_evals(self, llm = None): if llm is None: @@ -127,7 +161,7 @@ def get_answer_relevancy(self, llm = None): if llm is None: llm = self.llm try: - score_str = llm.predict(ANSWER_RELEVENCE.format(question=self.question, context= self.RESPONSE)) + score_str = llm.predict(ANSWER_RELEVENCE.format(question=self.question, context=self.RESPONSE)) score = float(extract_number(score_str)) return f"Answer relevancy Score: {round(score, 1)}\n{thresholdCheck(score)}" @@ -141,7 +175,6 @@ def get_groundedness(self, llm = None): statements = sent_tokenize(self.RESPONSE) scores = [] for statement in statements: - score_response = llm.predict(GROUNDEDNESS.format(statement=statement, context=self.CONTEXT)) score = extract_number(score_response) scores.append(score) @@ -163,7 +196,7 @@ def get_ground_truth(self, answer:str, llm = None): score = extract_number(score_str) return f"Ground truth score: {round(score, 1)}\n{thresholdCheck(score)}" - + @staticmethod def load_from_kwargs(self,kwargs):