diff --git a/ai/function_call.go b/ai/function_call.go
index 482f32ed2..dafc85312 100644
--- a/ai/function_call.go
+++ b/ai/function_call.go
@@ -43,7 +43,7 @@ func (fco *FunctionCall) Bytes() ([]byte, error) {
 
 // FromBytes deserialize the FunctionCallObject from the given []byte
 func (fco *FunctionCall) FromBytes(b []byte) error {
-	var obj = &FunctionCall{}
+	obj := &FunctionCall{}
 	err := json.Unmarshal(b, &obj)
 	if err != nil {
 		return err
diff --git a/cli/serve.go b/cli/serve.go
index 463ef4048..effe3f8ff 100644
--- a/cli/serve.go
+++ b/cli/serve.go
@@ -29,6 +29,7 @@ import (
 
 	"github.com/yomorun/yomo/pkg/bridge/ai"
 	"github.com/yomorun/yomo/pkg/bridge/ai/provider/azopenai"
+	"github.com/yomorun/yomo/pkg/bridge/ai/provider/openai"
 )
 
 // serveCmd represents the serve command
@@ -133,6 +134,10 @@ func registerAIProvider(aiConfig *ai.Config) {
 			log.InfoStatusEvent(os.Stdout, "register [%s] AI provider", name)
 			// TODO: register other providers
 		}
+		// register the OpenAI provider
+		if name == "openai" {
+			ai.RegisterProvider(openai.NewProvider(provider["api_key"], provider["model"]))
+		}
 	}
 }
diff --git a/example/10-ai/zipper.yaml b/example/10-ai/zipper.yaml
index c7a1945db..ddbaef0b6 100644
--- a/example/10-ai/zipper.yaml
+++ b/example/10-ai/zipper.yaml
@@ -21,7 +21,7 @@ bridge:
 
     openai:
       api_key:
-      api_endpoint:
+      model:
 
     gemini:
       api_key:
diff --git a/pkg/bridge/ai/provider/openai/provider.go b/pkg/bridge/ai/provider/openai/provider.go
new file mode 100644
index 000000000..017d5a4af
--- /dev/null
+++ b/pkg/bridge/ai/provider/openai/provider.go
@@ -0,0 +1,276 @@
+package openai
+
+import (
+	"bytes"
+	"encoding/json"
+	"fmt"
+	"io"
+	"net/http"
+	"os"
+	"sync"
+
+	_ "github.com/joho/godotenv/autoload"
+	"github.com/yomorun/yomo/ai"
+	"github.com/yomorun/yomo/core/ylog"
+)
+
+// APIEndpoint is the endpoint of the OpenAI chat completions API
+const APIEndpoint = "https://api.openai.com/v1/chat/completions"
+
+// fns stores the functions registered by connected sfn instances
+var fns sync.Map
+
+// ChatCompletionMessage is a message in the chat completion request
+type ChatCompletionMessage struct {
+	Role    string `json:"role"`
+	Content string `json:"content"`
+	// - https://github.com/openai/openai-python/blob/main/chatml.md
+	// - https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
+	Name string `json:"name,omitempty"`
+	// MultiContent []ChatMessagePart
+	// For Role=assistant prompts this may be set to the tool calls generated by the model, such as function calls.
+	ToolCalls  []ai.ToolCall `json:"tool_calls,omitempty"`
+	ToolCallID string        `json:"tool_call_id,omitempty"`
+}
+
+// ReqBody is the request body
+type ReqBody struct {
+	Model    string                  `json:"model"`
+	Messages []ChatCompletionMessage `json:"messages"`
+	Tools    []ai.ToolCall           `json:"tools"` // chatCompletionTool
+	// ToolChoice string `json:"tool_choice"` // chatCompletionFunction
+}
+
+// RespBody is the response body
+type RespBody struct {
+	ID                string       `json:"id"`
+	Object            string       `json:"object"`
+	Created           int          `json:"created"`
+	Model             string       `json:"model"`
+	Choices           []RespChoice `json:"choices"`
+	Usage             RespUsage    `json:"usage"`
+	SystemFingerprint string       `json:"system_fingerprint"`
+}
+
+// RespMessage is the message in the response
+type RespMessage struct {
+	Role      string        `json:"role"`
+	Content   string        `json:"content"`
+	ToolCalls []ai.ToolCall `json:"tool_calls"`
+}
+
+// RespChoice indicates a choice in the response by `FinishReason`
+type RespChoice struct {
+	FinishReason string                `json:"finish_reason"`
+	Index        int                   `json:"index"`
+	Message      ChatCompletionMessage `json:"message"`
+}
+
+// RespUsage is the token usage in the response
+type RespUsage struct {
+	PromptTokens     int `json:"prompt_tokens"`
+	CompletionTokens int `json:"completion_tokens"`
+	TotalTokens     int `json:"total_tokens"`
+}
+
+// OpenAIProvider is the provider for OpenAI
+type OpenAIProvider struct {
+	// APIKey is the API key for OpenAI
+	APIKey string
+	// Model is the model for OpenAI
+	// eg. "gpt-3.5-turbo-1106", "gpt-4-turbo-preview", "gpt-4-vision-preview", "gpt-4"
+	Model string
+}
+
+// connectedFn records a function registered by a connected sfn instance
+type connectedFn struct {
+	connID uint64
+	tag    uint32
+	tc     ai.ToolCall
+}
+
+func init() {
+	fns = sync.Map{}
+}
+
+// NewProvider creates a new OpenAIProvider
+func NewProvider(apiKey string, model string) *OpenAIProvider {
+	if apiKey == "" {
+		apiKey = os.Getenv("OPENAI_API_KEY")
+	}
+	if model == "" {
+		model = os.Getenv("OPENAI_MODEL")
+	}
+	ylog.Debug("new openai provider", "api_endpoint", APIEndpoint, "api_key", apiKey, "model", model)
+	return &OpenAIProvider{
+		APIKey: apiKey,
+		Model:  model,
+	}
+}
+
+// Name returns the name of the provider
+func (p *OpenAIProvider) Name() string {
+	return "openai"
+}
+
+// GetChatCompletions gets the chat completions for the ai service
+func (p *OpenAIProvider) GetChatCompletions(userInstruction string) (*ai.InvokeResponse, error) {
+	toolCalls, ok := hasToolCalls()
+	if !ok {
+		ylog.Error(ai.ErrNoFunctionCall.Error())
+		return &ai.InvokeResponse{Content: "no toolcalls"}, ai.ErrNoFunctionCall
+	}
+
+	// messages
+	messages := []ChatCompletionMessage{
+		{Role: "system", Content: `You are a very helpful assistant. Your job is to choose the best possible action to solve the user question or task. Don't make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous.
+If you don't know the answer, stop the conversation by saying "no func call".`},
+		{Role: "user", Content: userInstruction},
+	}
+
+	body := ReqBody{Model: p.Model, Messages: messages, Tools: toolCalls}
+	ylog.Debug("request", "tools", len(toolCalls), "messages", messages)
+
+	jsonBody, err := json.Marshal(body)
+	if err != nil {
+		return nil, err
+	}
+
+	req, err := http.NewRequest("POST", APIEndpoint, bytes.NewBuffer(jsonBody))
+	if err != nil {
+		return nil, err
+	}
+	req.Header.Set("Content-Type", "application/json")
+	// OpenAI authentication
+	req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", p.APIKey))
+
+	client := &http.Client{}
+	resp, err := client.Do(req)
+	if err != nil {
+		return nil, err
+	}
+	defer resp.Body.Close()
+
+	respBody, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return nil, err
+	}
+	ylog.Debug("response", "body", string(respBody))
+
+	if resp.StatusCode >= 400 {
+		return nil, fmt.Errorf("ai response status code is %d", resp.StatusCode)
+	}
+
+	var respBodyStruct RespBody
+	err = json.Unmarshal(respBody, &respBodyStruct)
+	if err != nil {
+		return nil, err
+	}
+	// TODO: record usage
+	// usage := respBodyStruct.Usage
+	// log.Printf("Token Usage: %+v\n", usage)
+
+	// guard against an empty choices array to avoid an index-out-of-range panic
+	if len(respBodyStruct.Choices) == 0 {
+		return nil, ai.ErrNoFunctionCall
+	}
+
+	choice := respBodyStruct.Choices[0]
+	ylog.Debug("finish_reason", "reason", choice.FinishReason)
+
+	calls := choice.Message.ToolCalls
+	content := choice.Message.Content
+
+	ylog.Debug("--response calls", "calls", calls)
+
+	result := &ai.InvokeResponse{}
+	if len(calls) == 0 {
+		result.Content = content
+		return result, ai.ErrNoFunctionCall
+	}
+
+	// functions may be more than one
+	for _, call := range calls {
+		fns.Range(func(_, value any) bool {
+			fn := value.(*connectedFn)
+			if fn.tc.Equal(&call) {
+				// Use toolCalls because tool_id is required in the following llm request
+				if result.ToolCalls == nil {
+					result.ToolCalls = make(map[uint32][]*ai.ToolCall)
+				}
+				// Create a new variable to hold the current call
+				currentCall := call
+				result.ToolCalls[fn.tag] = append(result.ToolCalls[fn.tag], &currentCall)
+			}
+			return true
+		})
+	}
+
+	// sfn may be disconnected, so we need to check if there is any function call
+	if len(result.ToolCalls) == 0 {
+		return nil, ai.ErrNoFunctionCall
+	}
+	return result, nil
+}
+
+// RegisterFunction registers a function for the given connection
+func (p *OpenAIProvider) RegisterFunction(tag uint32, functionDefinition *ai.FunctionDefinition, connID uint64) error {
+	fns.Store(connID, &connectedFn{
+		connID: connID,
+		tag:    tag,
+		tc: ai.ToolCall{
+			Type:     "function",
+			Function: functionDefinition,
+		},
+	})
+
+	return nil
+}
+
+// UnregisterFunction unregisters the function for the given connection.
+// Be careful: a function can have multiple instances, remove the offline instance only.
+func (p *OpenAIProvider) UnregisterFunction(name string, connID uint64) error {
+	fns.Delete(connID)
+	return nil
+}
+
+// ListToolCalls lists the registered tool functions
+func (p *OpenAIProvider) ListToolCalls() (map[uint32]ai.ToolCall, error) {
+	tmp := make(map[uint32]ai.ToolCall)
+	fns.Range(func(_, value any) bool {
+		fn := value.(*connectedFn)
+		tmp[fn.tag] = fn.tc
+		return true
+	})
+
+	return tmp, nil
+}
+
+// GetOverview gets the overview of the registered functions for the ai service
+func (p *OpenAIProvider) GetOverview() (*ai.OverviewResponse, error) {
+	result := &ai.OverviewResponse{
+		Functions: make(map[uint32]*ai.FunctionDefinition),
+	}
+	_, ok := hasToolCalls()
+	if !ok {
+		return result, nil
+	}
+
+	fns.Range(func(_, value any) bool {
+		fn := value.(*connectedFn)
+		result.Functions[fn.tag] = fn.tc.Function
+		return true
+	})
+
+	return result, nil
+}
+
+// hasToolCalls reports whether any tool calls are registered, and returns them
+func hasToolCalls() ([]ai.ToolCall, bool) {
+	toolCalls := make([]ai.ToolCall, 0)
+	fns.Range(func(_, value any) bool {
+		fn := value.(*connectedFn)
+		toolCalls = append(toolCalls, fn.tc)
+		return true
+	})
+	return toolCalls, len(toolCalls) > 0
+}
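
Note for reviewers: with this change the `openai` block in `example/10-ai/zipper.yaml` takes a `model` instead of an `api_endpoint`, since the endpoint is now fixed by the `APIEndpoint` constant. A filled-in block might look like the sketch below; the values are placeholders, and per `NewProvider` both fields may be left empty to fall back to the environment variables:

```yaml
openai:
  api_key: <your-openai-api-key>  # falls back to OPENAI_API_KEY when empty
  model: gpt-3.5-turbo-1106       # falls back to OPENAI_MODEL when empty
```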
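
For a quick smoke test, here is a minimal sketch of how a caller might exercise the new provider outside the zipper. The tag, connID, and the `ai.FunctionDefinition` field values used below are illustrative assumptions, not part of this diff:

```go
package main

import (
	"fmt"
	"log"

	"github.com/yomorun/yomo/ai"
	"github.com/yomorun/yomo/pkg/bridge/ai/provider/openai"
)

func main() {
	// empty arguments fall back to OPENAI_API_KEY / OPENAI_MODEL
	provider := openai.NewProvider("", "gpt-3.5-turbo-1106")

	// register a function as if an sfn had connected; the tag (0x10) and
	// connID (1) are made up for illustration
	fd := &ai.FunctionDefinition{
		Name:        "get-weather",
		Description: "Get the current weather for a city",
	}
	if err := provider.RegisterFunction(0x10, fd, 1); err != nil {
		log.Fatal(err)
	}

	// the provider asks OpenAI to choose among the registered functions
	res, err := provider.GetChatCompletions("How is the weather in Paris today?")
	if err != nil {
		log.Fatal(err)
	}

	// tool calls come back grouped by the tag of the registering sfn
	for tag, calls := range res.ToolCalls {
		fmt.Printf("tag=%#x calls=%d\n", tag, len(calls))
	}
}
```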