added inference kwargs to control temperature, max new tokens, etc...
joaomsimoes committed Jun 21, 2024
1 parent c5ec185 commit dfb8ba9
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions keybert/llm/_textgenerationinference.py
@@ -82,16 +82,18 @@ class Keywords(BaseModel):
     def __init__(self,
                  client: InferenceClient,
                  prompt: str = None,
-                 client_kwargs: Mapping[str, Any] = {},
                  json_schema: BaseModel = Keywords
                  ):
         self.client = client
         self.prompt = prompt if prompt is not None else DEFAULT_PROMPT
         self.default_prompt_ = DEFAULT_PROMPT
-        self.client_kwargs = client_kwargs
         self.json_schema = json_schema
 
-    def extract_keywords(self, documents: List[str], candidate_keywords: List[List[str]] = None):
+    def extract_keywords(
+            self,
+            documents: List[str], candidate_keywords: List[List[str]] = None,
+            inference_kwargs: Mapping[str, Any] = {}
+    ):
         """ Extract topics
 
         Arguments:
@@ -116,7 +118,7 @@ def extract_keywords(self, documents: List[str], candidate_keywords: List[List[str]] = None):
         response = self.client.text_generation(
             prompt=prompt,
             grammar={"type": "json", "value": self.json_schema.schema()},
-            **self.client_kwargs
+            **inference_kwargs
         )
         all_keywords = json.loads(response)["keywords"]

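For context, a minimal usage sketch of the new inference_kwargs argument follows. It assumes the class in keybert/llm/_textgenerationinference.py is exposed as keybert.llm.TextGenerationInference, that the endpoint URL is a placeholder for a running Text Generation Inference server, and that extract_keywords returns the extracted keywords; per this diff, whatever is passed in inference_kwargs is forwarded to InferenceClient.text_generation, so standard generation parameters such as temperature and max_new_tokens apply.

# Hypothetical usage sketch for this commit's inference_kwargs argument.
# Assumes keybert.llm.TextGenerationInference exposes the class in this file
# and that a Text Generation Inference server is running at the placeholder URL.
from huggingface_hub import InferenceClient
from keybert.llm import TextGenerationInference

# Client pointed at a TGI endpoint (placeholder URL).
client = InferenceClient("http://localhost:8080")
llm = TextGenerationInference(client)

docs = ["KeyBERT is a minimal keyword extraction library built on top of BERT embeddings."]

# inference_kwargs is splatted into InferenceClient.text_generation (see the hunk above),
# so per-call generation settings like temperature and max_new_tokens can be tuned here.
keywords = llm.extract_keywords(
    docs,
    inference_kwargs={"temperature": 0.1, "max_new_tokens": 250},
)
print(keywords)

Moving these settings from the constructor (the removed client_kwargs) to extract_keywords means generation parameters can differ per call without rebuilding the LLM wrapper.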
