Skip to content

Commit

Permalink
feat(corpora_ai): support azure_endpoint for openai as well (#56)
Browse files Browse the repository at this point in the history
* try to build the binaries, Azure OpenAI

* punt on xcompile - our other pipeline is better alpine
  • Loading branch information
skyl authored Nov 19, 2024
1 parent ac277bc commit 6379ca3
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 7 deletions.
11 changes: 7 additions & 4 deletions py/packages/corpora_ai/provider_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@
from typing import Optional
from corpora_ai.llm_interface import LLMBaseInterface

# Import provider-specific clients
try:
from corpora_ai_openai.llm_client import OpenAIClient
except ImportError:
OpenAIClient = None # This handles cases where the OpenAI client isn't installed
OpenAIClient = None

# Future imports for other providers, e.g., Anthropic or Cohere, would follow the same pattern
# Future imports for other providers,
# e.g., Anthropic or Cohere, would follow the same pattern


def load_llm_provider() -> Optional[LLMBaseInterface]:
Expand All @@ -25,7 +25,10 @@ def load_llm_provider() -> Optional[LLMBaseInterface]:
api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
raise ValueError("OPENAI_API_KEY environment variable is not set.")
return OpenAIClient(api_key=api_key)
return OpenAIClient(
api_key=api_key,
azure_endpoint=os.getenv("OPENAI_AZURE_ENDPOINT", None),
)

# Placeholder for additional providers (e.g., Anthropic)
# elif provider_name == "anthropic" and AnthropicClient:
Expand Down
6 changes: 5 additions & 1 deletion py/packages/corpora_ai/test_provider_loader.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import os
import unittest
from unittest.mock import patch, MagicMock

from openai import azure_endpoint
from corpora_ai.provider_loader import load_llm_provider
from corpora_ai.llm_interface import LLMBaseInterface

Expand All @@ -20,7 +22,9 @@ def test_load_openai_provider_success(self, MockOpenAIClient):

provider = load_llm_provider()

MockOpenAIClient.assert_called_once_with(api_key="test_api_key")
MockOpenAIClient.assert_called_once_with(
api_key="test_api_key", azure_endpoint=None
)
self.assertIsInstance(provider, LLMBaseInterface)
self.assertEqual(provider, mock_client_instance)

Expand Down
13 changes: 11 additions & 2 deletions py/packages/corpora_ai_openai/llm_client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import json
from typing import List, Type, TypeVar
from openai import OpenAI, OpenAIError
from openai import OpenAI, OpenAIError, AzureOpenAI
from pydantic import BaseModel

from corpora_ai.llm_interface import LLMBaseInterface, ChatCompletionTextMessage
Expand All @@ -15,8 +15,17 @@ def __init__(
api_key: str,
completion_model: str = "gpt-4o",
embedding_model: str = "text-embedding-3-small",
azure_endpoint: str = None,
):
self.client = OpenAI(api_key=api_key)
if azure_endpoint:
self.client = AzureOpenAI(
api_key=api_key,
azure_endpoint=azure_endpoint,
# What's the behavior of not pinning the API version?
# api_version="2024-10-01-preview",
)
else:
self.client = OpenAI(api_key=api_key)
self.completion_model = completion_model
self.embedding_model = embedding_model

Expand Down

0 comments on commit 6379ca3

Please sign in to comment.