Skip to content

Commit

Permalink
urllib to join ollama url domain to endpoint in place of string format
Browse files Browse the repository at this point in the history
  • Loading branch information
Kieran-Sears committed Nov 12, 2024
1 parent 59a46d0 commit 8f01a46
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 4 deletions.
5 changes: 3 additions & 2 deletions goldenverba/components/embedding/OllamaEmbedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import requests
from wasabi import msg
import aiohttp
from urllib.parse import urljoin

from goldenverba.components.interfaces import Embedding
from goldenverba.components.types import InputConfig
Expand Down Expand Up @@ -33,7 +34,7 @@ async def vectorize(self, config: dict, content: list[str]) -> list[float]:
data = {"model": model, "input": content}

async with aiohttp.ClientSession() as session:
async with session.post(self.url + "/api/embed", json=data) as response:
async with session.post(urljoin(self.url, "/api/embed"), json=data) as response:
response.raise_for_status()
data = await response.json()
embeddings = data.get("embeddings", [])
Expand All @@ -42,7 +43,7 @@ async def vectorize(self, config: dict, content: list[str]) -> list[float]:

def get_models(url: str):
try:
response = requests.get(url + "/api/tags")
response = requests.get(urljoin(url, "/api/tags"))
models = [model.get("name") for model in response.json().get("models")]
if len(models) > 0:
return models
Expand Down
4 changes: 2 additions & 2 deletions goldenverba/components/generation/OllamaGenerator.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import json
import aiohttp
from urllib.parse import urljoin
from typing import List, Dict, AsyncGenerator

from goldenverba.components.interfaces import Generator
Expand Down Expand Up @@ -35,7 +36,6 @@ async def generate_stream(
conversation: List[Dict] = [],
) -> AsyncGenerator[Dict, None]:
model = config.get("Model").value
url = f"{self.url}/api/chat"
system_message = config.get("System Message").value

if not self.url:
Expand All @@ -47,7 +47,7 @@ async def generate_stream(

try:
async with aiohttp.ClientSession() as session:
async with session.post(url, json=data) as response:
async with session.post(urljoin(self.url, "/api/chat"), json=data) as response:
async for line in response.content:
if line.strip():
yield self._process_response(line)
Expand Down

0 comments on commit 8f01a46

Please sign in to comment.