Introduce new LLM client architecture #570

Draft · wants to merge 1 commit into main
156 changes: 48 additions & 108 deletions src/codegate/llm_utils/llmclient.py
@@ -1,26 +1,32 @@
import json
from typing import Any, Dict, Optional

import litellm
import structlog
from litellm import acompletion
from ollama import Client as OllamaClient

from codegate.config import Config
from codegate.inference import LlamaCppInferenceEngine
from codegate.llmclient.base import Message, LLMProvider
from codegate.providers.litellmshim.bridge import LiteLLMBridgeProvider

logger = structlog.get_logger("codegate")

litellm.drop_params = True


class LLMClient:
"""
Base class for LLM interactions handling both local and cloud providers.

This is a kludge before we refactor our providers a bit to be able to pass
in all the parameters we need.
"""
"""Base class for LLM interactions handling both local and cloud providers."""

@staticmethod
def _create_provider(
provider: str,
model: Optional[str] = None,
api_key: Optional[str] = None,
base_url: Optional[str] = None,
) -> Optional[LLMProvider]:
if provider == "llamacpp":
return None # Handled separately for now
return LiteLLMBridgeProvider(
api_key=api_key or "",
base_url=base_url,
default_model=model
)

@staticmethod
async def complete(
@@ -33,42 +39,41 @@ async def complete(
extra_headers: Optional[Dict[str, str]] = None,
**kwargs,
) -> Dict[str, Any]:
"""
Send a completion request to either local or cloud LLM.

Args:
content: The user message content
system_prompt: The system prompt to use
provider: "local" or "litellm"
model: Model identifier
api_key: API key for cloud providers
base_url: Base URL for cloud providers
**kwargs: Additional arguments for the completion request

Returns:
Parsed response from the LLM
"""
if provider == "llamacpp":
return await LLMClient._complete_local(content, system_prompt, model, **kwargs)
return await LLMClient._complete_litellm(
content,
system_prompt,
provider,
model,
api_key,
base_url,
extra_headers,
**kwargs,
)

llm_provider = LLMClient._create_provider(provider, model, api_key, base_url)

try:
messages = [
Message(role="system", content=system_prompt),
Message(role="user", content=content)
]

response = await llm_provider.chat(
messages=messages,
temperature=kwargs.pop("temperature", 0),  # pop so **kwargs does not pass temperature twice
stream=False,
extra_headers=extra_headers,
**kwargs
)

return json.loads(response.message.content)

except Exception as e:
logger.error(f"LLM completion failed {model} ({content}): {e}")
raise e
finally:
await llm_provider.close()

@staticmethod
async def _create_request(
content: str, system_prompt: str, model: str, **kwargs
async def _complete_local(
content: str,
system_prompt: str,
model: str,
**kwargs,
) -> Dict[str, Any]:
"""
Private method to create a request dictionary for LLM completion.
"""
return {
request = {
"messages": [
{"role": "system", "content": system_prompt},
{"role": "user", "content": content},
@@ -79,16 +84,6 @@ async def _create_request(
"temperature": kwargs.get("temperature", 0),
}

@staticmethod
async def _complete_local(
content: str,
system_prompt: str,
model: str,
**kwargs,
) -> Dict[str, Any]:
# Use the private method to create the request
request = await LLMClient._create_request(content, system_prompt, model, **kwargs)

inference_engine = LlamaCppInferenceEngine()
result = await inference_engine.chat(
f"{Config.get_config().model_base_path}/{request['model']}.gguf",
@@ -98,58 +93,3 @@ async def _complete_local(
)

return json.loads(result["choices"][0]["message"]["content"])

@staticmethod
async def _complete_litellm(
content: str,
system_prompt: str,
provider: str,
model: str,
api_key: str,
base_url: Optional[str] = None,
extra_headers: Optional[Dict[str, str]] = None,
**kwargs,
) -> Dict[str, Any]:
# Use the private method to create the request
request = await LLMClient._create_request(content, system_prompt, model, **kwargs)

# We should reuse the same logic in the provider
# but let's do that later
if provider == "vllm":
if not base_url.endswith("/v1"):
base_url = f"{base_url}/v1"
else:
if not model.startswith(f"{provider}/"):
model = f"{provider}/{model}"

try:
if provider == "ollama":
model = model.split("/")[-1]
response = OllamaClient(host=base_url).chat(
model=model,
messages=request["messages"],
format="json",
options={"temperature": request["temperature"]},
)
content = response.message.content
else:
response = await acompletion(
model=model,
messages=request["messages"],
api_key=api_key,
temperature=request["temperature"],
base_url=base_url,
response_format=request["response_format"],
extra_headers=extra_headers,
)
content = response["choices"][0]["message"]["content"]

# Clean up code blocks if present
if content.startswith("```"):
content = content.split("\n", 1)[1].rsplit("```", 1)[0].strip()

return json.loads(content)

except Exception as e:
logger.error(f"LiteLLM completion failed {model} ({content}): {e}")
raise e
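
For orientation, a minimal caller-side sketch of the new path that replaces the deleted _complete_litellm helper. It is not part of the PR: the full signature of complete() is collapsed in the diff above, so the argument names are taken from its old docstring, and the provider, model, and API-key values below are placeholders.

import asyncio

from codegate.llm_utils.llmclient import LLMClient


async def main() -> None:
    # Every non-llamacpp provider is now routed through the LiteLLM bridge
    # built by LLMClient._create_provider().
    result = await LLMClient.complete(
        content='{"question": "ping"}',
        system_prompt="Reply with a single JSON object.",
        provider="openai",        # placeholder provider name
        model="gpt-4o-mini",      # placeholder model name
        api_key="sk-...",         # a real key is required for cloud providers
        temperature=0,
    )
    print(result)  # complete() returns the parsed JSON dict


asyncio.run(main())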
94 changes: 94 additions & 0 deletions src/codegate/llmclient/base.py
@@ -0,0 +1,94 @@
from abc import ABC, abstractmethod
from typing import AsyncIterator, Dict, List, Optional, Union
from dataclasses import dataclass

@dataclass
class Message:
"""Represents a chat message."""
role: str
content: str

@dataclass
class CompletionResponse:
"""Represents a completion response from an LLM."""
text: str
model: str
usage: Dict[str, int]

@dataclass
class ChatResponse:
"""Represents a chat response from an LLM."""
message: Message
model: str
usage: Dict[str, int]

class LLMProvider(ABC):
"""Abstract base class for LLM providers."""

def __init__(
self,
api_key: str,
base_url: Optional[str] = None,
default_model: Optional[str] = None
):
"""Initialize the LLM provider.

Args:
api_key: API key for authentication
base_url: Optional custom base URL for the API
default_model: Optional default model to use
"""
self.api_key = api_key
self.base_url = base_url
self.default_model = default_model

@abstractmethod
async def chat(
self,
messages: List[Message],
model: Optional[str] = None,
temperature: float = 0.7,
stream: bool = False,
**kwargs
) -> Union[ChatResponse, AsyncIterator[ChatResponse]]:
"""Send a chat request to the LLM.

Args:
messages: List of messages in the conversation
model: Optional model override
temperature: Sampling temperature
stream: Whether to stream the response
**kwargs: Additional provider-specific parameters

Returns:
ChatResponse or AsyncIterator[ChatResponse] if streaming
"""
pass

@abstractmethod
async def complete(
self,
prompt: str,
model: Optional[str] = None,
temperature: float = 0.7,
stream: bool = False,
**kwargs
) -> Union[CompletionResponse, AsyncIterator[CompletionResponse]]:
"""Send a completion request to the LLM.

Args:
prompt: The text prompt
model: Optional model override
temperature: Sampling temperature
stream: Whether to stream the response
**kwargs: Additional provider-specific parameters

Returns:
CompletionResponse or AsyncIterator[CompletionResponse] if streaming
"""
pass

@abstractmethod
async def close(self) -> None:
"""Close any open connections."""
pass
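
To illustrate the contract, here is a hypothetical minimal subclass: an echo provider suitable for unit tests. It is not part of the PR and ignores stream=True for brevity.

from typing import AsyncIterator, List, Optional, Union

from codegate.llmclient.base import (
    ChatResponse,
    CompletionResponse,
    LLMProvider,
    Message,
)


class EchoProvider(LLMProvider):
    """Toy provider that echoes its input; useful as a test double."""

    async def chat(
        self,
        messages: List[Message],
        model: Optional[str] = None,
        temperature: float = 0.7,
        stream: bool = False,
        **kwargs
    ) -> Union[ChatResponse, AsyncIterator[ChatResponse]]:
        # Echo the last message back as the assistant reply.
        reply = Message(role="assistant", content=messages[-1].content)
        return ChatResponse(
            message=reply,
            model=model or self.default_model or "echo",
            usage={"prompt_tokens": 0, "completion_tokens": 0},
        )

    async def complete(
        self,
        prompt: str,
        model: Optional[str] = None,
        temperature: float = 0.7,
        stream: bool = False,
        **kwargs
    ) -> Union[CompletionResponse, AsyncIterator[CompletionResponse]]:
        return CompletionResponse(
            text=prompt,
            model=model or self.default_model or "echo",
            usage={"prompt_tokens": 0, "completion_tokens": 0},
        )

    async def close(self) -> None:
        # Nothing to release for the toy provider.
        return None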
Empty file.
42 changes: 42 additions & 0 deletions src/codegate/llmclient/normalizers/base.py
@@ -0,0 +1,42 @@
from abc import ABC, abstractmethod
from typing import Any, AsyncIterator, Dict

from ..types import Message, NormalizedRequest, ChatResponse

class ModelInputNormalizer(ABC):
@abstractmethod
def normalize(self, data: Dict[str, Any]) -> NormalizedRequest:
"""Convert provider-specific request format to SimpleModelRouter format."""
pass

@abstractmethod
def denormalize(self, data: NormalizedRequest) -> Dict[str, Any]:
"""Convert SimpleModelRouter format back to provider-specific request format."""
pass

class ModelOutputNormalizer(ABC):
@abstractmethod
def normalize_streaming(
self,
model_reply: AsyncIterator[Any]
) -> AsyncIterator[ChatResponse]:
"""Convert provider-specific streaming response to SimpleModelRouter format."""
pass

@abstractmethod
def normalize(self, model_reply: Any) -> ChatResponse:
"""Convert provider-specific response to SimpleModelRouter format."""
pass

@abstractmethod
def denormalize(self, normalized_reply: ChatResponse) -> Dict[str, Any]:
"""Convert SimpleModelRouter format back to provider-specific response format."""
pass

@abstractmethod
def denormalize_streaming(
self,
normalized_reply: AsyncIterator[ChatResponse]
) -> AsyncIterator[Any]:
"""Convert SimpleModelRouter streaming response back to provider-specific format."""
pass
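
As a sketch of how these hooks might be implemented, here is a hypothetical output normalizer for an OpenAI-style chat completion dict. NormalizedRequest lives in a types module that is not part of this diff, so the input side is omitted, and the sketch assumes the Message and ChatResponse imported there match the dataclasses in llmclient/base.py.

from typing import Any, AsyncIterator, Dict

from codegate.llmclient.base import ChatResponse, Message
from codegate.llmclient.normalizers.base import ModelOutputNormalizer


class OpenAIStyleOutputNormalizer(ModelOutputNormalizer):
    """Maps OpenAI-style chat completion dicts to and from ChatResponse."""

    def normalize(self, model_reply: Any) -> ChatResponse:
        choice = model_reply["choices"][0]["message"]
        return ChatResponse(
            message=Message(role=choice["role"], content=choice["content"]),
            model=model_reply.get("model", ""),
            usage=model_reply.get("usage", {}),
        )

    async def normalize_streaming(
        self, model_reply: AsyncIterator[Any]
    ) -> AsyncIterator[ChatResponse]:
        # Apply the per-chunk mapping lazily over the stream.
        async for chunk in model_reply:
            yield self.normalize(chunk)

    def denormalize(self, normalized_reply: ChatResponse) -> Dict[str, Any]:
        return {
            "model": normalized_reply.model,
            "usage": normalized_reply.usage,
            "choices": [
                {
                    "message": {
                        "role": normalized_reply.message.role,
                        "content": normalized_reply.message.content,
                    }
                }
            ],
        }

    async def denormalize_streaming(
        self, normalized_reply: AsyncIterator[ChatResponse]
    ) -> AsyncIterator[Any]:
        async for item in normalized_reply:
            yield self.denormalize(item)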