Skip to content

Commit

Permalink
fix: update import for openai errors (#196)
Browse files Browse the repository at this point in the history
  • Loading branch information
Sparkier authored Dec 20, 2023
1 parent eb42554 commit 1c4e623
Showing 1 changed file with 16 additions and 11 deletions.
27 changes: 16 additions & 11 deletions zeno_build/models/providers/openai_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,24 @@
import aiolimiter
import openai
from aiohttp import ClientSession
from openai import error
from openai import (
APIConnectionError,
APIError,
BadRequestError,
RateLimitError,
Timeout,
)
from tqdm.asyncio import tqdm_asyncio

from zeno_build.models import lm_config
from zeno_build.prompts import chat_prompt

ERROR_ERRORS_TO_MESSAGES = {
error.InvalidRequestError: "OpenAI API Invalid Request: Prompt was filtered",
error.RateLimitError: "OpenAI API rate limit exceeded. Sleeping for 10 seconds.",
error.APIConnectionError: "OpenAI API Connection Error: Error Communicating with OpenAI", # noqa E501
error.Timeout: "OpenAI APITimeout Error: OpenAI Timeout",
error.ServiceUnavailableError: "OpenAI service unavailable error: {e}",
error.APIError: "OpenAI API error: {e}",
BadRequestError: "OpenAI API Invalid Request: Prompt was filtered",
RateLimitError: "OpenAI API rate limit exceeded. Sleeping for 10 seconds.",
APIConnectionError: "OpenAI API Connection Error: Error Communicating with OpenAI", # noqa E501
Timeout: "OpenAI APITimeout Error: OpenAI Timeout",
APIError: "OpenAI API error: {e}",
}


Expand All @@ -45,9 +50,9 @@ async def _throttled_openai_completion_acreate(
n=n,
)
except tuple(ERROR_ERRORS_TO_MESSAGES.keys()) as e:
if isinstance(e, (error.ServiceUnavailableError, error.APIError)):
if isinstance(e, APIError):
logging.warning(ERROR_ERRORS_TO_MESSAGES[type(e)].format(e=e))
elif isinstance(e, error.InvalidRequestError):
elif isinstance(e, BadRequestError):
logging.warning(ERROR_ERRORS_TO_MESSAGES[type(e)])
return {
"choices": [
Expand Down Expand Up @@ -136,9 +141,9 @@ async def _throttled_openai_chat_completion_acreate(
n=n,
)
except tuple(ERROR_ERRORS_TO_MESSAGES.keys()) as e:
if isinstance(e, (error.ServiceUnavailableError, error.APIError)):
if isinstance(e, APIError):
logging.warning(ERROR_ERRORS_TO_MESSAGES[type(e)].format(e=e))
elif isinstance(e, error.InvalidRequestError):
elif isinstance(e, BadRequestError):
logging.warning(ERROR_ERRORS_TO_MESSAGES[type(e)])
return {
"choices": [
Expand Down

0 comments on commit 1c4e623

Please sign in to comment.