Commit

n_choices support
valaises committed Jul 24, 2024
1 parent 23ba0a7 commit 59482cd
Showing 2 changed files with 13 additions and 16 deletions.
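The change threads the request's `n` parameter (number of completion choices) through to `litellm.acompletion` in both the streaming and non-streaming passthrough paths, and makes token accounting cover all returned choices instead of only the first. A minimal sketch of the litellm behavior this relies on, assuming provider credentials in the environment; the model name and prompt are placeholders:

    import asyncio
    import litellm

    async def main():
        # n asks the upstream provider for several completions in one call;
        # the response then carries len(choices) == n.
        resp = await litellm.acompletion(
            model="gpt-3.5-turbo",  # placeholder; any litellm-routable model
            messages=[{"role": "user", "content": "Name a color."}],
            n=2,
        )
        for choice in resp.choices:
            print(choice.message.content)

    asyncio.run(main())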
27 changes: 12 additions & 15 deletions refact_webgui/webgui/selfhost_fastapi_completions.py
@@ -511,7 +511,6 @@ def compose_usage_dict(model_dict, prompt_tokens_n, generated_tokens_n) -> Dict[
         model_dict = self._model_assigner.models_db_with_passthrough.get(post.model, {})
 
         async def litellm_streamer():
-            final_msg = {}
             generated_tokens_n = 0
             try:
                 self._integrations_env_setup()
@@ -521,7 +520,8 @@ async def litellm_streamer():
                     max_tokens=min(model_dict.get('T_out', post.max_tokens), post.max_tokens),
                     tools=post.tools,
                     tool_choice=post.tool_choice,
-                    stop=post.stop
+                    stop=post.stop,
+                    n=post.n,
                 )
                 finish_reason = None
                 async for model_response in response:
@@ -533,18 +533,14 @@
                     if text := delta.get("content"):
                         generated_tokens_n += litellm.token_counter(model_name, text=text)
 
-                    if finish_reason:
-                        final_msg = data
-                        break
-
             except json.JSONDecodeError:
                 data = {"choices": [{"finish_reason": finish_reason}]}
                 yield prefix + json.dumps(data) + postfix
 
-            if final_msg:
-                usage_dict = compose_usage_dict(model_dict, prompt_tokens_n, generated_tokens_n)
-                final_msg.update(usage_dict)
-                yield prefix + json.dumps(final_msg) + postfix
+            final_msg = {"choices": []}
+            usage_dict = compose_usage_dict(model_dict, prompt_tokens_n, generated_tokens_n)
+            final_msg.update(usage_dict)
+            yield prefix + json.dumps(final_msg) + postfix
 
             # NOTE: DONE needed by refact-lsp server
             yield prefix + "[DONE]" + postfix
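With n > 1 there is no single "final" chunk to decorate with usage, so the streamer now always emits one trailing usage-only message with an empty choices list, then the [DONE] marker. A hypothetical client-side sketch of parsing that shape; it assumes the server's prefix/postfix wrap each message as an SSE "data:" line, which this diff does not show:

    import json

    def handle_sse_line(line: str) -> None:
        payload = line.removeprefix("data: ").strip()
        if payload == "[DONE]":          # end-of-stream marker kept for refact-lsp
            return
        data = json.loads(payload)
        if not data.get("choices"):      # trailing usage-only message
            print("usage:", {k: v for k, v in data.items() if k != "choices"})
            return
        for choice in data["choices"]:   # ordinary delta chunk, one entry per n
            if text := choice.get("delta", {}).get("content"):
                print(text, end="")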
@@ -563,15 +559,16 @@ async def litellm_non_streamer():
                 max_tokens=min(model_dict.get('T_out', post.max_tokens), post.max_tokens),
                 tools=post.tools,
                 tool_choice=post.tool_choice,
-                stop=post.stop
+                stop=post.stop,
+                n=post.n,
             )
             finish_reason = None
             try:
                 data = model_response.dict()
-                choice0 = data["choices"][0]
-                if text := choice0.get("message", {}).get("content"):
-                    generated_tokens_n = litellm.token_counter(model_name, text=text)
-                finish_reason = choice0["finish_reason"]
+                for choice in data.get("choices", []):
+                    if text := choice.get("message", {}).get("content"):
+                        generated_tokens_n += litellm.token_counter(model_name, text=text)
+                    finish_reason = choice.get("finish_reason")
                 usage_dict = compose_usage_dict(model_dict, prompt_tokens_n, generated_tokens_n)
                 data.update(usage_dict)
             except json.JSONDecodeError:
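In the non-streaming path, generated_tokens_n is now accumulated over every choice rather than taken from choices[0] alone, so usage numbers stay meaningful when n > 1. A small standalone sketch of that aggregation, with stubbed response data in place of model_response.dict():

    import litellm

    model_name = "gpt-3.5-turbo"  # placeholder model for the tokenizer
    data = {  # stub of model_response.dict(), two choices as with n=2
        "choices": [
            {"message": {"content": "Blue."}, "finish_reason": "stop"},
            {"message": {"content": "A deep forest green."}, "finish_reason": "stop"},
        ]
    }
    generated_tokens_n = 0
    for choice in data.get("choices", []):
        if text := choice.get("message", {}).get("content"):
            generated_tokens_n += litellm.token_counter(model_name, text=text)
    print(generated_tokens_n)  # tokens summed across all choices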
2 changes: 1 addition & 1 deletion setup.py
@@ -37,7 +37,7 @@ class PyPackage:
         "starlette==0.27.0", "uvicorn", "uvloop", "termcolor", "python-multipart", "more_itertools",
         "scyllapy==1.3.0", "pandas>=2.0.3",
         # NOTE: litellm has bug with anthropic streaming, so we're staying on this version for now
-        "litellm==1.34.42",
+        "litellm==1.42.0",
     ],
     requires_packages=["refact_known_models", "refact_utils"],
     data=["webgui/static/*", "webgui/static/components/modals/*",
