-
Notifications
You must be signed in to change notification settings - Fork 129
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
## Related Docs - https://platform.openai.com/docs/api-reference/chat/create ## How to use this provider ```yaml bridge: ai: server: addr: localhost:8000 provider: openai providers: openai: api_key: <your-api-key> model: <gpt-3.5-turbo-1106> ``` Co-authored-by: C.C <[email protected]>
- Loading branch information
1 parent
8d8e24a
commit cad974b
Showing
4 changed files
with
277 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -21,7 +21,7 @@ bridge: | |
|
||
openai: | ||
api_key: | ||
api_endpoint: | ||
model: | ||
|
||
gemini: | ||
api_key: |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,270 @@ | ||
package openai | ||
|
||
import ( | ||
"bytes" | ||
"encoding/json" | ||
"fmt" | ||
"io" | ||
"net/http" | ||
"os" | ||
"sync" | ||
|
||
_ "github.com/joho/godotenv/autoload" | ||
"github.com/yomorun/yomo/ai" | ||
"github.com/yomorun/yomo/core/ylog" | ||
) | ||
|
||
const APIEndpoint = "https://api.openai.com/v1/chat/completions" | ||
|
||
var fns sync.Map | ||
|
||
// Message | ||
type ChatCompletionMessage struct { | ||
Role string `json:"role"` | ||
Content string `json:"content"` | ||
// - https://github.com/openai/openai-python/blob/main/chatml.md | ||
// - https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb | ||
Name string `json:"name,omitempty"` | ||
// MultiContent []ChatMessagePart | ||
// For Role=assistant prompts this may be set to the tool calls generated by the model, such as function calls. | ||
ToolCalls []ai.ToolCall `json:"tool_calls,omitempty"` | ||
ToolCallID string `json:"tool_call_id,omitempty"` | ||
} | ||
|
||
// RequestBody is the request body | ||
type ReqBody struct { | ||
Model string `json:"model"` | ||
Messages []ChatCompletionMessage `json:"messages"` | ||
Tools []ai.ToolCall `json:"tools"` // chatCompletionTool | ||
// ToolChoice string `json:"tool_choice"` // chatCompletionFunction | ||
} | ||
|
||
// Resp is the response body | ||
type RespBody struct { | ||
ID string `json:"id"` | ||
Object string `json:"object"` | ||
Created int `json:"created"` | ||
Model string `json:"model"` | ||
Choices []RespChoice `json:"choices"` | ||
Usage RespUsage `json:"usage"` | ||
SystemFingerprint string `json:"system_fingerprint"` | ||
} | ||
|
||
// RespMessage is the message in Response | ||
type RespMessage struct { | ||
Role string `json:"role"` | ||
Content string `json:"content"` | ||
ToolCalls []ai.ToolCall `json:"tool_calls"` | ||
} | ||
|
||
// RespChoice is used to indicate the choice in Response by `FinishReason` | ||
type RespChoice struct { | ||
FinishReason string `json:"finish_reason"` | ||
Index int `json:"index"` | ||
Message ChatCompletionMessage `json:"message"` | ||
} | ||
|
||
// RespUsage is the token usage in Response | ||
type RespUsage struct { | ||
PromptTokens int `json:"prompt_tokens"` | ||
CompletionTokens int `json:"completion_tokens"` | ||
TotalTokens int `json:"total_tokens"` | ||
} | ||
|
||
// OpenAIProvider is the provider for OpenAI | ||
type OpenAIProvider struct { | ||
// APIKey is the API key for OpenAI | ||
APIKey string | ||
// Model is the model for OpenAI | ||
// eg. "gpt-3.5-turbo-1106", "gpt-4-turbo-preview", "gpt-4-vision-preview", "gpt-4" | ||
Model string | ||
} | ||
|
||
type connectedFn struct { | ||
connID uint64 | ||
tag uint32 | ||
tc ai.ToolCall | ||
} | ||
|
||
func init() { | ||
fns = sync.Map{} | ||
} | ||
|
||
// NewProvider creates a new OpenAIProvider | ||
func NewProvider(apiKey string, model string) *OpenAIProvider { | ||
if apiKey == "" { | ||
apiKey = os.Getenv("OPENAI_API_KEY") | ||
} | ||
if model == "" { | ||
model = os.Getenv("OPENAI_MODEL") | ||
} | ||
ylog.Debug("new openai provider", "api_endpoint", APIEndpoint, "api_key", apiKey, "model", model) | ||
return &OpenAIProvider{ | ||
APIKey: apiKey, | ||
Model: model, | ||
} | ||
} | ||
|
||
// Name returns the name of the provider | ||
func (p *OpenAIProvider) Name() string { | ||
return "openai" | ||
} | ||
|
||
// GetChatCompletions get chat completions for ai service | ||
func (p *OpenAIProvider) GetChatCompletions(userInstruction string) (*ai.InvokeResponse, error) { | ||
toolCalls, ok := hasToolCalls() | ||
if !ok { | ||
ylog.Error(ai.ErrNoFunctionCall.Error()) | ||
return &ai.InvokeResponse{Content: "no toolcalls"}, ai.ErrNoFunctionCall | ||
} | ||
|
||
// messages | ||
messages := []ChatCompletionMessage{ | ||
{Role: "system", Content: `You are a very helpful assistant. Your job is to choose the best possible action to solve the user question or task. Don't make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous. If you don't know the answer, stop the conversation by saying "no func call".`}, | ||
{Role: "user", Content: userInstruction}, | ||
} | ||
|
||
body := ReqBody{Model: p.Model, Messages: messages, Tools: toolCalls} | ||
ylog.Debug("request", "tools", len(toolCalls), "messages", messages) | ||
|
||
jsonBody, err := json.Marshal(body) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
req, err := http.NewRequest("POST", APIEndpoint, bytes.NewBuffer(jsonBody)) | ||
if err != nil { | ||
return nil, err | ||
} | ||
req.Header.Set("Content-Type", "application/json") | ||
// OpenAI authentication | ||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", p.APIKey)) | ||
|
||
client := &http.Client{} | ||
resp, err := client.Do(req) | ||
if err != nil { | ||
return nil, err | ||
} | ||
defer resp.Body.Close() | ||
|
||
respBody, err := io.ReadAll(resp.Body) | ||
if err != nil { | ||
return nil, err | ||
} | ||
ylog.Debug("response", "body", respBody) | ||
|
||
// ylog.Info("response body", "body", string(respBody)) | ||
if resp.StatusCode >= 400 { | ||
return nil, fmt.Errorf("ai response status code is %d", resp.StatusCode) | ||
} | ||
|
||
var respBodyStruct RespBody | ||
err = json.Unmarshal(respBody, &respBodyStruct) | ||
if err != nil { | ||
return nil, err | ||
} | ||
// TODO: record usage | ||
// usage := respBodyStruct.Usage | ||
// log.Printf("Token Usage: %+v\n", usage) | ||
|
||
choice := respBodyStruct.Choices[0] | ||
ylog.Debug(">>finish_reason", "reason", choice.FinishReason) | ||
|
||
calls := respBodyStruct.Choices[0].Message.ToolCalls | ||
content := respBodyStruct.Choices[0].Message.Content | ||
|
||
ylog.Debug("--response calls", "calls", calls) | ||
|
||
result := &ai.InvokeResponse{} | ||
if len(calls) == 0 { | ||
result.Content = content | ||
return result, ai.ErrNoFunctionCall | ||
} | ||
|
||
// functions may be more than one | ||
// slog.Info("tool calls", "calls", calls, "mapTools", mapTools) | ||
for _, call := range calls { | ||
fns.Range(func(_, value any) bool { | ||
fn := value.(*connectedFn) | ||
if fn.tc.Equal(&call) { | ||
// Use toolCalls because tool_id is required in the following llm request | ||
if result.ToolCalls == nil { | ||
result.ToolCalls = make(map[uint32][]*ai.ToolCall) | ||
} | ||
// Create a new variable to hold the current call | ||
currentCall := call | ||
result.ToolCalls[fn.tag] = append(result.ToolCalls[fn.tag], ¤tCall) | ||
} | ||
return true | ||
}) | ||
} | ||
|
||
// sfn maybe disconnected, so we need to check if there is any function call | ||
if len(result.ToolCalls) == 0 { | ||
return nil, ai.ErrNoFunctionCall | ||
} | ||
return result, nil | ||
} | ||
|
||
// RegisterFunction register function | ||
func (p *OpenAIProvider) RegisterFunction(tag uint32, functionDefinition *ai.FunctionDefinition, connID uint64) error { | ||
fns.Store(connID, &connectedFn{ | ||
connID: connID, | ||
tag: tag, | ||
tc: ai.ToolCall{ | ||
Type: "function", | ||
Function: functionDefinition, | ||
}, | ||
}) | ||
|
||
return nil | ||
} | ||
|
||
// UnregisterFunction unregister function | ||
// Be careful: a function can have multiple instances, remove the offline instance only. | ||
func (p *OpenAIProvider) UnregisterFunction(name string, connID uint64) error { | ||
fns.Delete(connID) | ||
return nil | ||
} | ||
|
||
// ListToolCalls list tool functions | ||
func (p *OpenAIProvider) ListToolCalls() (map[uint32]ai.ToolCall, error) { | ||
tmp := make(map[uint32]ai.ToolCall) | ||
fns.Range(func(_, value any) bool { | ||
fn := value.(*connectedFn) | ||
tmp[fn.tag] = fn.tc | ||
return true | ||
}) | ||
|
||
return tmp, nil | ||
} | ||
|
||
// GetOverview get overview for ai service | ||
func (p *OpenAIProvider) GetOverview() (*ai.OverviewResponse, error) { | ||
result := &ai.OverviewResponse{ | ||
Functions: make(map[uint32]*ai.FunctionDefinition), | ||
} | ||
_, ok := hasToolCalls() | ||
if !ok { | ||
return result, nil | ||
} | ||
|
||
fns.Range(func(_, value any) bool { | ||
fn := value.(*connectedFn) | ||
result.Functions[fn.tag] = fn.tc.Function | ||
return true | ||
}) | ||
|
||
return result, nil | ||
} | ||
|
||
// hasToolCalls check if there are tool calls | ||
func hasToolCalls() ([]ai.ToolCall, bool) { | ||
toolCalls := make([]ai.ToolCall, 0) | ||
fns.Range(func(_, value any) bool { | ||
fn := value.(*connectedFn) | ||
toolCalls = append(toolCalls, fn.tc) | ||
return true | ||
}) | ||
return toolCalls, len(toolCalls) > 0 | ||
} |