Skip to content

Commit

Permalink
Merge pull request #21 from dexhunter/fix/beta_models
Browse files Browse the repository at this point in the history
🐛 fix model issues with beta limitation
  • Loading branch information
dexhunter authored Nov 3, 2024
2 parents 21ab47d + d4ec913 commit 86339ec
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 22 deletions.
8 changes: 8 additions & 0 deletions aide/backend/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,14 @@ def query(
"max_tokens": max_tokens,
}

# Handle models with beta limitations
# ref: https://platform.openai.com/docs/guides/reasoning/beta-limitations
if model.startswith("o1-"):
if system_message:
user_message = system_message
system_message = None
model_kwargs["temperature"] = 1

query_func = backend_anthropic.query if "claude-" in model else backend_openai.query
output, req_time, in_tok_count, out_tok_count, info = query_func(
system_message=compile_prompt_to_md(system_message) if system_message else None,
Expand Down
27 changes: 17 additions & 10 deletions aide/backend/backend_anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,25 @@

import time

from anthropic import Anthropic, RateLimitError
from .utils import FunctionSpec, OutputType, opt_messages_to_list
from funcy import notnone, once, retry, select_values
from .utils import FunctionSpec, OutputType, opt_messages_to_list, backoff_create
from funcy import notnone, once, select_values
import anthropic

_client: Anthropic = None # type: ignore
_client: anthropic.Anthropic = None # type: ignore

RATELIMIT_RETRIES = 5
retry_exp = retry(RATELIMIT_RETRIES, errors=RateLimitError, timeout=lambda a: 2 ** (a + 1)) # type: ignore
ANTHROPIC_TIMEOUT_EXCEPTIONS = (
anthropic.RateLimitError,
anthropic.APIConnectionError,
anthropic.APITimeoutError,
anthropic.InternalServerError,
)


@once
def _setup_anthropic_client():
global _client
_client = Anthropic()
_client = anthropic.Anthropic(max_retries=0)


@retry_exp
def query(
system_message: str | None,
user_message: str | None,
Expand Down Expand Up @@ -48,7 +50,12 @@ def query(
messages = opt_messages_to_list(None, user_message)

t0 = time.time()
message = _client.messages.create(messages=messages, **filtered_kwargs) # type: ignore
message = backoff_create(
_client.messages.create,
ANTHROPIC_TIMEOUT_EXCEPTIONS,
messages=messages,
**filtered_kwargs,
)
req_time = time.time() - t0

assert len(message.content) == 1 and message.content[0].type == "text"
Expand Down
28 changes: 17 additions & 11 deletions aide/backend/backend_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,25 +4,26 @@
import logging
import time

from .utils import FunctionSpec, OutputType, opt_messages_to_list
from funcy import notnone, once, retry, select_values
from openai import OpenAI, RateLimitError
from .utils import FunctionSpec, OutputType, opt_messages_to_list, backoff_create
from funcy import notnone, once, select_values
import openai

logger = logging.getLogger("aide")

_client: OpenAI = None # type: ignore

RATELIMIT_RETRIES = 5
retry_exp = retry(RATELIMIT_RETRIES, errors=RateLimitError, timeout=lambda a: 2 ** (a + 1)) # type: ignore
_client: openai.OpenAI = None # type: ignore

OPENAI_TIMEOUT_EXCEPTIONS = (
openai.RateLimitError,
openai.APIConnectionError,
openai.APITimeoutError,
openai.InternalServerError,
)

@once
def _setup_openai_client():
global _client
_client = OpenAI(max_retries=3)

_client = openai.OpenAI(max_retries=0)

@retry_exp
def query(
system_message: str | None,
user_message: str | None,
Expand All @@ -40,7 +41,12 @@ def query(
filtered_kwargs["tool_choice"] = func_spec.openai_tool_choice_dict

t0 = time.time()
completion = _client.chat.completions.create(messages=messages, **filtered_kwargs) # type: ignore
completion = backoff_create(
_client.chat.completions.create,
OPENAI_TIMEOUT_EXCEPTIONS,
messages=messages,
**filtered_kwargs,
)
req_time = time.time() - t0

choice = completion.choices[0]
Expand Down
21 changes: 21 additions & 0 deletions aide/backend/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,27 @@
OutputType = str | FunctionCallType


import backoff
import logging
from typing import Callable

logger = logging.getLogger("aide")


@backoff.on_predicate(
wait_gen=backoff.expo,
max_value=60,
factor=1.5,
)
def backoff_create(
create_fn: Callable, retry_exceptions: list[Exception], *args, **kwargs
):
try:
return create_fn(*args, **kwargs)
except retry_exceptions as e:
logger.info(f"Backoff exception: {e}")
return False

def opt_messages_to_list(
system_message: str | None, user_message: str | None
) -> list[dict[str, str]]:
Expand Down
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -88,4 +88,5 @@ pdf2image
PyPDF
pyocr
pyarrow
xlrd
xlrd
backoff

0 comments on commit 86339ec

Please sign in to comment.