From 59482cd4140f3b7775f595f8db6d5cccb8bcf1c6 Mon Sep 17 00:00:00 2001
From: Valeryi
Date: Wed, 24 Jul 2024 21:31:39 +0100
Subject: [PATCH] n_choices support

---
 .../webgui/selfhost_fastapi_completions.py | 27 +++++++++----------
 setup.py                                   |  2 +-
 2 files changed, 13 insertions(+), 16 deletions(-)

diff --git a/refact_webgui/webgui/selfhost_fastapi_completions.py b/refact_webgui/webgui/selfhost_fastapi_completions.py
index 8d29f525..19cff522 100644
--- a/refact_webgui/webgui/selfhost_fastapi_completions.py
+++ b/refact_webgui/webgui/selfhost_fastapi_completions.py
@@ -511,7 +511,6 @@ def compose_usage_dict(model_dict, prompt_tokens_n, generated_tokens_n) -> Dict[
         model_dict = self._model_assigner.models_db_with_passthrough.get(post.model, {})
 
         async def litellm_streamer():
-            final_msg = {}
             generated_tokens_n = 0
             try:
                 self._integrations_env_setup()
@@ -521,7 +520,8 @@ async def litellm_streamer():
                     max_tokens=min(model_dict.get('T_out', post.max_tokens), post.max_tokens),
                     tools=post.tools,
                     tool_choice=post.tool_choice,
-                    stop=post.stop
+                    stop=post.stop,
+                    n=post.n,
                 )
                 finish_reason = None
                 async for model_response in response:
@@ -533,18 +533,14 @@ async def litellm_streamer():
                         if text := delta.get("content"):
                             generated_tokens_n += litellm.token_counter(model_name, text=text)
 
-                    if finish_reason:
-                        final_msg = data
-                        break
-
             except json.JSONDecodeError:
                 data = {"choices": [{"finish_reason": finish_reason}]}
                 yield prefix + json.dumps(data) + postfix
 
-            if final_msg:
-                usage_dict = compose_usage_dict(model_dict, prompt_tokens_n, generated_tokens_n)
-                final_msg.update(usage_dict)
-                yield prefix + json.dumps(final_msg) + postfix
+            final_msg = {"choices": []}
+            usage_dict = compose_usage_dict(model_dict, prompt_tokens_n, generated_tokens_n)
+            final_msg.update(usage_dict)
+            yield prefix + json.dumps(final_msg) + postfix
 
             # NOTE: DONE needed by refact-lsp server
             yield prefix + "[DONE]" + postfix
@@ -563,15 +559,16 @@ async def litellm_non_streamer():
                     max_tokens=min(model_dict.get('T_out', post.max_tokens), post.max_tokens),
                     tools=post.tools,
                     tool_choice=post.tool_choice,
-                    stop=post.stop
+                    stop=post.stop,
+                    n=post.n,
                 )
                 finish_reason = None
                 try:
                     data = model_response.dict()
-                    choice0 = data["choices"][0]
-                    if text := choice0.get("message", {}).get("content"):
-                        generated_tokens_n = litellm.token_counter(model_name, text=text)
-                    finish_reason = choice0["finish_reason"]
+                    for choice in data.get("choices", []):
+                        if text := choice.get("message", {}).get("content"):
+                            generated_tokens_n += litellm.token_counter(model_name, text=text)
+                        finish_reason = choice.get("finish_reason")
                     usage_dict = compose_usage_dict(model_dict, prompt_tokens_n, generated_tokens_n)
                     data.update(usage_dict)
                 except json.JSONDecodeError:
diff --git a/setup.py b/setup.py
index 0965cecd..d85efd7f 100644
--- a/setup.py
+++ b/setup.py
@@ -37,7 +37,7 @@ class PyPackage:
             "starlette==0.27.0", "uvicorn", "uvloop", "termcolor",
             "python-multipart", "more_itertools", "scyllapy==1.3.0", "pandas>=2.0.3",
             # NOTE: litellm has bug with anthropic streaming, so we're staying on this version for now
-            "litellm==1.34.42",
+            "litellm==1.42.0",
         ],
         requires_packages=["refact_known_models", "refact_utils"],
        data=["webgui/static/*", "webgui/static/components/modals/*",
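
Illustrative usage sketch (not part of the patch): once `n` is forwarded to
litellm.acompletion as above, the OpenAI-compatible chat endpoint can return
several alternative completions per request. A minimal client-side example in
Python, assuming the webgui is reachable at http://127.0.0.1:8008 and that
"gpt-4o" is one of your configured passthrough models (URL, port, model name,
and the auth header are placeholders, adjust for your deployment):

    # Sketch only: ask the self-hosted chat endpoint for three alternative completions.
    import requests

    resp = requests.post(
        "http://127.0.0.1:8008/v1/chat/completions",      # placeholder URL/port
        headers={"Authorization": "Bearer YOUR_API_KEY"},  # omit or adjust if auth is disabled
        json={
            "model": "gpt-4o",                             # placeholder passthrough model name
            "messages": [{"role": "user", "content": "Write a one-line greeting."}],
            "n": 3,                                        # forwarded downstream as n=post.n
            "max_tokens": 32,
            "stream": False,
        },
        timeout=60,
    )
    resp.raise_for_status()
    for choice in resp.json().get("choices", []):
        print(choice.get("message", {}).get("content"))

Note that the non-streaming handler now sums generated_tokens_n over every
returned choice, so the usage block in the response reflects all n outputs.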