chore: remove ollama patch (#956)
# Description

Tool calls are now included in streaming responses as of [ollama
0.4.6](https://github.com/ollama/ollama/releases/tag/v0.4.6), so the
ollama patch is removed.
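
With the patch gone, a client can keep `stream: true` even when tools are attached. Below is a minimal client-side sketch of the behavior this unlocks, pointed at the `api_endpoint` and model from `example/10-ai/zipper.yaml`; the tool definition is illustrative, and `github.com/sashabaranov/go-openai` is assumed as the client library (it matches the `openai.*` types in the diff below).

```go
package main

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"

	openai "github.com/sashabaranov/go-openai"
)

func main() {
	// Ollama exposes an OpenAI-compatible API; the token is a placeholder.
	cfg := openai.DefaultConfig("ollama")
	cfg.BaseURL = "http://localhost:11434/v1"
	client := openai.NewClientWithConfig(cfg)

	req := openai.ChatCompletionRequest{
		Model:  "mistral",
		Stream: true, // works together with Tools as of ollama 0.4.6
		Tools: []openai.Tool{{
			Type: openai.ToolTypeFunction,
			Function: &openai.FunctionDefinition{
				Name:        "get-weather", // illustrative tool
				Description: "Get the current weather for a city",
				Parameters: json.RawMessage(`{
					"type": "object",
					"properties": {"city_name": {"type": "string"}},
					"required": ["city_name"]
				}`),
			},
		}},
		Messages: []openai.ChatCompletionMessage{
			{Role: openai.ChatMessageRoleUser, Content: "What's the weather in Paris?"},
		},
	}

	stream, err := client.CreateChatCompletionStream(context.Background(), req)
	if err != nil {
		panic(err)
	}
	defer stream.Close()

	for {
		chunk, err := stream.Recv()
		if errors.Is(err, io.EOF) {
			return
		}
		if err != nil {
			panic(err)
		}
		// Tool-call deltas now arrive inside the stream itself,
		// with no blocking fallback request needed.
		for _, c := range chunk.Choices {
			if len(c.Delta.ToolCalls) > 0 {
				fmt.Printf("tool call delta: %+v\n", c.Delta.ToolCalls)
			}
		}
	}
}
```

Before 0.4.6, such a request had to be silently downgraded to a blocking call, which is what the patch removed below did.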
venjiang authored Dec 13, 2024
1 parent b8f33fc commit ca0e0cf
Showing 4 changed files with 3 additions and 80 deletions.
2 changes: 1 addition & 1 deletion example/10-ai/llm-sfn-get-weather/app.go
@@ -13,7 +13,7 @@ import (
var tag uint32 = 0x11

type Parameter struct {
-CityName string `json:"city_name" jsonschema:"description=The name of the city to be queried"`
+CityName string `json:"city_name" jsonschema:"description=The name of a city to be queried"`
}

// Description returns the description of this AI function.
2 changes: 1 addition & 1 deletion example/10-ai/zipper.yaml
@@ -35,7 +35,7 @@ bridge:

ollama:
api_endpoint: http://localhost:11434/v1
-model: llama3.1
+model: mistral

cerebras:
api_key:
2 changes: 1 addition & 1 deletion pkg/bridge/ai/provider/ollama/Readme.md
@@ -11,7 +11,7 @@ Follow the Ollama doc:
## 2. Run the model

```sh
-ollama run llama3.1
+ollama run mistral
```

## 3. Start YoMo Zipper
77 changes: 0 additions & 77 deletions pkg/bridge/ai/service.go
@@ -256,10 +252,6 @@ func (srv *Service) GetChatCompletions(ctx context.Context, req openai.ChatCompl
toolCalls = []openai.ToolCall{}
assistantMessage = openai.ChatCompletionMessage{}
)
-rawReq := req
-// ollama request patch
-// WARN: this is a temporary solution for ollama provider
-req = srv.patchOllamaRequest(req)

// 4. request first chat for getting tools
if req.Stream {
@@ -372,56 +368,6 @@ func (srv *Service) GetChatCompletions(ctx context.Context, req openai.ChatCompl
assistantMessage = resp.Choices[0].Message
firstCallSpan.End()
reqSpan.End()
-} else if rawReq.Stream {
-// if raw request is stream mode, we should return the stream response
-// WARN: this is a temporary solution for ollama provider
-w.SetStreamHeader()
-// choices
-choices := make([]openai.ChatCompletionStreamChoice, 0)
-for _, choice := range resp.Choices {
-delta := openai.ChatCompletionStreamChoiceDelta{
-Content: choice.Message.Content,
-Role: choice.Message.Role,
-FunctionCall: choice.Message.FunctionCall,
-ToolCalls: choice.Message.ToolCalls,
-}
-choices = append(choices, openai.ChatCompletionStreamChoice{
-Index: choice.Index,
-Delta: delta,
-FinishReason: choice.FinishReason,
-// ContentFilterResults
-})
-}
-// chunk response
-streamRes := openai.ChatCompletionStreamResponse{
-ID: resp.ID,
-Object: "chat.completion.chunk",
-Created: resp.Created,
-Model: resp.Model,
-Choices: choices,
-SystemFingerprint: resp.SystemFingerprint,
-// PromptAnnotations:
-// PromptFilterResults:
-}
-_ = w.WriteStreamEvent(streamRes)
-// usage
-if req.StreamOptions != nil && req.StreamOptions.IncludeUsage {
-streamRes = openai.ChatCompletionStreamResponse{
-ID: resp.ID,
-Object: "chat.completion.chunk",
-Created: resp.Created,
-Model: resp.Model,
-SystemFingerprint: resp.SystemFingerprint,
-Usage: &openai.Usage{
-PromptTokens: promptUsage,
-CompletionTokens: completionUsage,
-TotalTokens: totalUsage,
-},
-}
-_ = w.WriteStreamEvent(streamRes)
-}
-// done
-return w.WriteStreamDone()
} else {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
@@ -452,8 +398,6 @@ func (srv *Service) GetChatCompletions(ctx context.Context, req openai.ChatCompl
if srv.provider.Name() != "anthropic" {
req.Tools = nil // reset tools field
}
-// restore the original request stream field
-req.Stream = rawReq.Stream

srv.logger.Debug(" #2 second call", "request", fmt.Sprintf("%+v", req))

@@ -692,24 +636,3 @@ func recordTTFT(ctx context.Context, tracer trace.Tracer, w *ResponseWriter) {
w.TTFT = time.Now()
time.Sleep(time.Millisecond)
}

-// patchOllamaRequest patch the request for ollama provider(ollama function calling unsupported in stream mode)
-func (srv *Service) patchOllamaRequest(req openai.ChatCompletionRequest) openai.ChatCompletionRequest {
-srv.logger.Debug("before request",
-"stream", req.Stream,
-"provider", srv.provider.Name(),
-fmt.Sprintf("tools[%d]", len(req.Tools)), req.Tools,
-)
-if !req.Stream {
-return req
-}
-if srv.provider.Name() == "ollama" && len(req.Tools) > 0 {
-req.Stream = false
-}
-srv.logger.Debug("patch request",
-"stream", req.Stream,
-"provider", srv.provider.Name(),
-fmt.Sprintf("tools[%d]", len(req.Tools)), req.Tools,
-)
-return req
-}
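
For context, the deleted workaround condenses to the sketch below (same semantics as the removed code above, but simplified into a free function for illustration; it is not drop-in bridge code).

```go
package main

import (
	"fmt"

	openai "github.com/sashabaranov/go-openai"
)

// patchForOllama condenses the deleted patchOllamaRequest above: older
// ollama releases could not stream function calls, so a request combining
// Stream with Tools was downgraded to a blocking call, and the blocking
// response was later re-wrapped as a single stream chunk (the large
// deleted block in GetChatCompletions).
func patchForOllama(providerName string, req openai.ChatCompletionRequest) openai.ChatCompletionRequest {
	if req.Stream && providerName == "ollama" && len(req.Tools) > 0 {
		req.Stream = false // the caller's intent was kept separately in rawReq
	}
	return req
}

func main() {
	req := openai.ChatCompletionRequest{
		Stream: true,
		Tools:  []openai.Tool{{Type: openai.ToolTypeFunction}},
	}
	// Before ollama 0.4.6 the bridge would flip Stream to false here.
	fmt.Println(patchForOllama("ollama", req).Stream) // prints: false
}
```

With ollama 0.4.6 streaming tool calls natively, both the downgrade and the single-chunk re-wrap are dead code.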
