diff --git a/ai/model.go b/ai/model.go
index a54d0fecb..a8db6d31f 100644
--- a/ai/model.go
+++ b/ai/model.go
@@ -3,6 +3,11 @@ package ai
 
 import "errors"
 
+// ErrorResponse is the response for an error
+type ErrorResponse struct {
+	Error string `json:"error"`
+}
+
 // OverviewResponse is the response for overview
 type OverviewResponse struct {
 	Functions map[uint32]*FunctionDefinition // key is the tag of yomo
@@ -10,18 +15,21 @@ type OverviewResponse struct {
 
 // InvokeRequest is the request from user to BasicAPIServer
 type InvokeRequest struct {
-	ReqID  string `json:"req_id"` // ReqID is the request id of the request
-	Prompt string `json:"prompt"` // Prompt is user input text for chat completion
+	ReqID            string `json:"req_id"`             // ReqID is the request id of the request
+	Prompt           string `json:"prompt"`             // Prompt is user input text for chat completion
+	IncludeCallStack bool   `json:"include_call_stack"` // IncludeCallStack indicates whether to include the call stack in the response
 }
 
 // InvokeResponse is the response for chat completions
 type InvokeResponse struct {
 	// Functions is the functions from llm api response, key is the tag of yomo
-	Functions map[uint32][]*FunctionDefinition
+	// Functions map[uint32][]*FunctionDefinition
 	// Content is the content from llm api response
 	Content string
 	// ToolCalls is the toolCalls from llm api response
 	ToolCalls map[uint32][]*ToolCall
+	// ToolMessages is the tool messages from llm api response
+	ToolMessages []ToolMessage
 	// FinishReason is the finish reason from llm api response
 	FinishReason string
 	// TokenUsage is the token usage from llm api response
diff --git a/cli/Taskfile.yml b/cli/Taskfile.yml
index aa5c38637..3fe62e28f 100644
--- a/cli/Taskfile.yml
+++ b/cli/Taskfile.yml
@@ -38,7 +38,7 @@ tasks:
     dir: ../cmd/yomo
     cmds:
       - echo "{{.Name}} install..."
-      - go install -ldflags "-s -w -X {{.Module}}.Version={{.Version}} -X {{.Module}}.Date={{.Date}}"
+      - go install -race -ldflags "-s -w -X {{.Module}}.Version={{.Version}} -X {{.Module}}.Date={{.Date}}"
       - echo "{{.Name}} {{.Version}}({{.Date}}) is installed."
     silent: true
diff --git a/cli/test.go b/cli/test.go
new file mode 100644
index 000000000..9bc823cd1
--- /dev/null
+++ b/cli/test.go
@@ -0,0 +1,226 @@
+/*
+Copyright © 2021 Allegro Networks
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+package cli
+
+import (
+	"bufio"
+	"bytes"
+	"encoding/json"
+	"fmt"
+	"net/http"
+	"os"
+	"os/exec"
+	"strings"
+	"syscall"
+	"time"
+
+	"github.com/spf13/cobra"
+	"github.com/yomorun/yomo/ai"
+	"github.com/yomorun/yomo/pkg/log"
+
+	// serverless registrations
+	_ "github.com/yomorun/yomo/cli/serverless/deno"
+	_ "github.com/yomorun/yomo/cli/serverless/golang"
+	_ "github.com/yomorun/yomo/cli/serverless/wasm"
+)
+
+var (
+	sfnDir       []string
+	userPrompt   string
+	systemPrompt string
+	aiServerAddr string
+)
+
+// testPromptCmd represents the test-prompt command for LLM functions;
+// the source code of the LLM function is in sfnDir
+var testPromptCmd = &cobra.Command{
+	Use:     "test-prompt",
+	Aliases: []string{"p"},
+	Short:   "Test LLM prompt",
+	Long:    "Test LLM prompt",
+	Run: func(cmd *cobra.Command, args []string) {
+		// sfn source directory
+		if len(sfnDir) == 0 {
+			sfnDir = append(sfnDir, ".")
+		}
+		for _, dir := range sfnDir {
+			// run sfn
+			log.InfoStatusEvent(os.Stdout, "--------------------------------------------------------")
+			log.InfoStatusEvent(os.Stdout, "Attaching LLM function in directory: %v", dir)
+			cmd := exec.Command("go", "run", ".")
+			cmd.Dir = dir
+			env := os.Environ()
+			env = append(env, "YOMO_LOG_LEVEL=info")
+			cmd.Env = env
+			// cmd.Stdout = io.Discard
+			// cmd.Stderr = io.Discard
+			cmd.SysProcAttr = &syscall.SysProcAttr{
+				Setpgid: true,
+			}
+			stdout, err := cmd.StdoutPipe()
+			if err != nil {
+				log.FailureStatusEvent(os.Stdout, "Failed to attach LLM function in directory: %v, error: %v", dir, err)
+				continue
+			}
+			defer stdout.Close()
+			outputReader := bufio.NewReader(stdout)
+			// read the command output line by line
+			output := make(chan string)
+			defer close(output)
+			go func(outputReader *bufio.Reader, output chan string) {
+				for {
+					line, err := outputReader.ReadString('\n')
+					if err != nil {
+						break
+					}
+					if len(line) > 0 {
+						output <- line
+					}
+				}
+			}(outputReader, output)
+			// start cmd
+			if err := cmd.Start(); err != nil {
+				log.FailureStatusEvent(os.Stdout, "Failed to run LLM function in directory: %v, error: %v", dir, err)
+				continue
+			} else {
+				defer func(cmd *exec.Cmd) {
+					pgid, err := syscall.Getpgid(cmd.Process.Pid)
+					if err == nil {
+						syscall.Kill(-pgid, syscall.SIGTERM)
+					} else {
+						cmd.Process.Kill()
+					}
+				}(cmd)
+			}
+			// wait for the sfn to be ready
+			for {
+				select {
+				case out := <-output:
+					// log.InfoStatusEvent(os.Stdout, "AI SFN Output: %s", out)
+					if len(out) > 0 && strings.Contains(out, "register ai function success") {
+						log.InfoStatusEvent(os.Stdout, "LLM function registered successfully")
+						goto REQUEST
+					}
+				case <-time.After(5 * time.Second):
+					log.FailureStatusEvent(os.Stdout, "Failed to connect to zipper, please check whether the zipper is running")
+					os.Exit(1)
+				}
+			}
+			// invoke llm api
+			// request
+		REQUEST:
+			apiEndpoint := fmt.Sprintf("%s/invoke", aiServerAddr)
+			log.InfoStatusEvent(os.Stdout, `Invoking LLM API "%s"`, apiEndpoint)
+			invokeReq := ai.InvokeRequest{
+				IncludeCallStack: true, // include call stack
+				Prompt:           userPrompt,
+			}
+			reqBuf, err := json.Marshal(invokeReq)
+			if err != nil {
+				log.FailureStatusEvent(os.Stdout, "Failed to marshal invoke request: %v", err)
+				continue
+			}
+			// invoke api endpoint
+			log.InfoStatusEvent(os.Stdout, ">> LLM API Request")
+			log.InfoStatusEvent(os.Stdout, "Messages:")
+			log.InfoStatusEvent(os.Stdout, "\tSystem: %s", systemPrompt)
+			log.InfoStatusEvent(os.Stdout, "\tUser: %s", userPrompt)
+			resp, err := http.Post(apiEndpoint, "application/json", bytes.NewBuffer(reqBuf))
+			if err != nil {
+				log.FailureStatusEvent(os.Stdout, "Failed to invoke LLM API: %v", err)
+				continue
+			}
+			defer resp.Body.Close()
+			// response
+			log.InfoStatusEvent(os.Stdout, "<< LLM API Response")
+			// failed to invoke LLM API
+			if resp.StatusCode != http.StatusOK {
+				var errorResp ai.ErrorResponse
+				err := json.NewDecoder(resp.Body).Decode(&errorResp)
+				if err != nil {
+					log.FailureStatusEvent(os.Stdout, "Failed to decode LLM API response: %v", err)
+					continue
+				}
+				log.FailureStatusEvent(os.Stdout, "Failed to invoke LLM API: %s", errorResp.Error)
+				continue
+			}
+			// successfully invoked LLM API
+			var invokeResp ai.InvokeResponse
+			if err := json.NewDecoder(resp.Body).Decode(&invokeResp); err != nil {
+				log.FailureStatusEvent(os.Stdout, "Failed to decode LLM API response: %v", err)
+				continue
+			}
+			// tool calls
+			for tag, tcs := range invokeResp.ToolCalls {
+				toolCallCount := len(tcs)
+				if toolCallCount > 0 {
+					log.InfoStatusEvent(os.Stdout, "Invoking functions[%d]:", toolCallCount)
+					for _, tc := range tcs {
+						if invokeResp.ToolMessages == nil {
+							log.InfoStatusEvent(os.Stdout,
+								"\t[%s] tag: %d, name: %s, arguments: %s",
+								tc.ID,
+								tag,
+								tc.Function.Name,
+								tc.Function.Arguments,
+							)
+						} else {
+							log.InfoStatusEvent(os.Stdout,
+								"\t[%s] tag: %d, name: %s, arguments: %s\n🌟 result: %s",
+								tc.ID,
+								tag,
+								tc.Function.Name,
+								tc.Function.Arguments,
+								getToolCallResult(tc, invokeResp.ToolMessages),
+							)
+						}
+					}
+				}
+			}
+			// finish reason
+			log.InfoStatusEvent(os.Stdout, "Finish Reason: %s", invokeResp.FinishReason)
+			log.InfoStatusEvent(os.Stdout, "Final Content: \n🤖 %s", invokeResp.Content)
+		}
+	},
+}
+
+func getToolCallResult(tc *ai.ToolCall, tms []ai.ToolMessage) string {
+	result := ""
+	for _, tm := range tms {
+		if tm.ToolCallId == tc.ID {
+			result = tm.Content
+		}
+	}
+	return result
+}
+
+func init() {
+	rootCmd.AddCommand(testPromptCmd)
+
+	testPromptCmd.Flags().StringSliceVarP(&sfnDir, "sfn", "", []string{}, "sfn source directory")
+	testPromptCmd.Flags().StringVarP(&userPrompt, "user-prompt", "u", "", "user prompt")
+	testPromptCmd.MarkFlagRequired("user-prompt")
+	testPromptCmd.Flags().StringVarP(
+		&systemPrompt,
+		"system-prompt",
+		"s",
+		`You are a very helpful assistant. Your job is to choose the best possible action to solve the user question or task. Don't make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous.`,
+		"system prompt",
+	)
+	testPromptCmd.Flags().StringVarP(&aiServerAddr, "ai-server", "a", "http://localhost:8000", "LLM API server address")
+
+	runViper = bindViper(testPromptCmd)
+}
diff --git a/example/10-ai/llm-sfn-timezone-calculator/main.go b/example/10-ai/llm-sfn-timezone-calculator/main.go
index 996e74800..3ca3bbb5a 100644
--- a/example/10-ai/llm-sfn-timezone-calculator/main.go
+++ b/example/10-ai/llm-sfn-timezone-calculator/main.go
@@ -4,6 +4,7 @@ import (
 	"fmt"
 	"log/slog"
 	"os"
+	"strings"
 	"time"
 
 	"github.com/yomorun/yomo"
@@ -18,7 +19,7 @@ type Parameter struct {
 }
 
 func Description() string {
-	return `if user asks timezone converter related questions, extract the source time and timezone information to "timeString" and "sourceTimezone", extract the target timezone information to "targetTimezone". the desired "timeString" format is "YYYY-MM-DD HH:MM:SS". the "sourceTimezone" and "targetTimezone" are in IANA Time Zone Database identifier format. The function will convert the time from the source timezone to the target timezone and return the converted time as a string in the format "YYYY-MM-DD HH:MM:SS". If you are not sure about the date value of "timeString", set date value to "1900-01-01"`
+	return `if user asks timezone converter related questions, extract the source time and timezone information to "timeString" and "sourceTimezone", extract the target timezone information to "targetTimezone". the desired "timeString" format is "YYYY-MM-DD HH:MM:SS". the "sourceTimezone" and "targetTimezone" are in IANA Time Zone Database identifier format. The function will convert the time from the source timezone to the target timezone and return the converted time as a string in the format "YYYY-MM-DD HH:MM:SS". If you are not sure about the date value of "timeString", assume the date is today.`
 }
 
 func InputSchema() any {
@@ -75,6 +76,11 @@ func handler(ctx serverless.Context) {
 		msg.TargetTimezone = "UTC"
 	}
 
+	// guarantee the date is not the literal "YYYY-MM-DD" placeholder
+	if strings.Contains(msg.TimeString, "YYYY-MM-DD") {
+		msg.TimeString = strings.ReplaceAll(msg.TimeString, "YYYY-MM-DD", time.Now().Format("2006-01-02"))
+	}
+
 	targetTime, err := ConvertTimezone(msg.TimeString, msg.SourceTimezone, msg.TargetTimezone)
 	if err != nil {
 		slog.Error("[sfn] ConvertTimezone error", "err", err)
@@ -84,7 +90,7 @@ func handler(ctx serverless.Context) {
 
 	slog.Info("[sfn] result", "result", targetTime)
 
-	val := fmt.Sprintf("This time in timezone %s is %s", msg.TargetTimezone, targetTime)
+	val := fmt.Sprintf("The time in timezone %s is %s when it is %s in %s", msg.TargetTimezone, targetTime, msg.TimeString, msg.SourceTimezone)
 
 	// fcCtx.SetRetrievalResult(val)
 	fcCtx.Write(val)
diff --git a/pkg/bridge/ai/api_server.go b/pkg/bridge/ai/api_server.go
index 2e7b5b084..723874a4f 100644
--- a/pkg/bridge/ai/api_server.go
+++ b/pkg/bridge/ai/api_server.go
@@ -97,7 +97,6 @@ func WithContextService(handler http.Handler, credential string, zipperAddr stri
 	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 		handler.ServeHTTP(w, r.WithContext(WithServiceContext(r.Context(), service)))
 	})
-
 }
 
 // HandleOverview is the handler for GET /overview
@@ -162,7 +161,7 @@ func HandleInvoke(w http.ResponseWriter, r *http.Request) {
 	go func() {
 		// call llm to infer the function and arguments to be invoked
 		ylog.Debug(">> ai request", "reqID", req.ReqID, "prompt", req.Prompt)
-		res, err := service.GetChatCompletions(req.Prompt, baseSystemMessage, reqID)
+		res, err := service.GetChatCompletions(req.Prompt, baseSystemMessage, reqID, req.IncludeCallStack)
 		if err != nil {
 			errCh <- err
 		} else {
diff --git a/pkg/bridge/ai/provider/gemini/model_converter.go b/pkg/bridge/ai/provider/gemini/model_converter.go
index f16a6fbec..727927a7a 100644
--- a/pkg/bridge/ai/provider/gemini/model_converter.go
+++ b/pkg/bridge/ai/provider/gemini/model_converter.go
@@ -116,16 +116,18 @@ func parseToolCallFromResponse(response *Response) []ai.ToolCall {
 	calls := make([]ai.ToolCall, 0)
 	for _, candidate := range response.Candidates {
 		fn := candidate.Content.Parts[0].FunctionCall
-		fd := &ai.FunctionDefinition{
-			Name:      fn.Name,
-			Arguments: generateJSONSchemaArguments(fn.Args),
+		if fn != nil {
+			fd := &ai.FunctionDefinition{
+				Name:      fn.Name,
+				Arguments: generateJSONSchemaArguments(fn.Args),
+			}
+			call := ai.ToolCall{
+				ID:       "cc-gemini-id",
+				Type:     "cc-function",
+				Function: fd,
+			}
+			calls = append(calls, call)
 		}
-		call := ai.ToolCall{
-			ID:       "cc-gemini-id",
-			Type:     "cc-function",
-			Function: fd,
-		}
-		calls = append(calls, call)
 	}
 	return calls
 }
diff --git a/pkg/bridge/ai/provider/gemini/model_response.go b/pkg/bridge/ai/provider/gemini/model_response.go
index ed7b207d0..bba877eda 100644
--- a/pkg/bridge/ai/provider/gemini/model_response.go
+++ b/pkg/bridge/ai/provider/gemini/model_response.go
@@ -23,6 +23,7 @@ type CandidateContent struct {
 
 // Part is the element of CandidateContent
 type Part struct {
+	Text         string        `json:"text,omitempty"`
 	FunctionCall *FunctionCall `json:"functionCall"`
 }
 
diff --git a/pkg/bridge/ai/provider/gemini/provider.go b/pkg/bridge/ai/provider/gemini/provider.go
index ce4562ec9..72ddcac2e 100644
--- a/pkg/bridge/ai/provider/gemini/provider.go
+++ b/pkg/bridge/ai/provider/gemini/provider.go
@@ -113,6 +113,9 @@ func (p *GeminiProvider) GetChatCompletions(userInstruction string, baseSystemMe
 	ylog.Debug("gemini api response", "calls", len(calls))
 
 	result := &ai.InvokeResponse{}
+	result.FinishReason = response.Candidates[0].FinishReason
+	result.Content = response.Candidates[0].Content.Parts[0].Text
+
 	if len(calls) == 0 {
 		return result, ai.ErrNoFunctionCall
 	}
@@ -123,6 +126,11 @@ func (p *GeminiProvider) GetChatCompletions(userInstruction string, baseSystemMe
 			if fd.Name == tc.Function.Name {
 				ylog.Debug("-----> add function", "name", fd.Name, "tag", tag)
 				currentCall := tc
+				fn := response.Candidates[0].Content.Parts[0].FunctionCall
+				if fn != nil {
+					args, _ := json.Marshal(fn.Args)
+					currentCall.Function.Arguments = string(args)
+				}
 				result.ToolCalls[tag] = append(result.ToolCalls[tag], &currentCall)
 			}
 		}
diff --git a/pkg/bridge/ai/service.go b/pkg/bridge/ai/service.go
index d034a5b76..d7329d5eb 100644
--- a/pkg/bridge/ai/service.go
+++ b/pkg/bridge/ai/service.go
@@ -198,7 +198,7 @@ func (s *Service) GetOverview() (*ai.OverviewResponse, error) {
 }
 
 // GetChatCompletions returns the llm api response
-func (s *Service) GetChatCompletions(userInstruction string, baseSystemMessage string, reqID string) (*ai.InvokeResponse, error) {
+func (s *Service) GetChatCompletions(userInstruction string, baseSystemMessage string, reqID string, includeCallStack bool) (*ai.InvokeResponse, error) {
 	chainMessage := ai.ChainMessage{}
 	// we do not support multi-turn invoke for Google Gemini
 	if s.LLMProvider.Name() == "gemini" {
@@ -226,6 +226,11 @@ func (s *Service) GetChatCompletions(userInstruction string, baseSystemMessage s
 	chainMessage.ToolMessages = llmCalls
 	// do not attach toolMessage to prompt in 2nd call
 	res2, err := s.LLMProvider.GetChatCompletions(userInstruction, baseSystemMessage, chainMessage, s.md, false)
+	// INFO: attach call stack information
+	if includeCallStack && res2 != nil {
+		res2.ToolCalls = res.ToolCalls
+		res2.ToolMessages = llmCalls
+	}
 	ylog.Debug("<<<< complete 2nd call", "res2", fmt.Sprintf("%+v", res2))
 
 	return res2, err
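
Reviewer note (not part of the patch): a minimal sketch of how a client could exercise the new include_call_stack flag end to end. It assumes the bridge's BasicAPIServer is reachable at http://localhost:8000 (the default address the test-prompt command uses); the request/response types come from ai/model.go in this diff, while the prompt text and program structure are purely illustrative.

package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"

	"github.com/yomorun/yomo/ai"
)

func main() {
	// ask the bridge to return the tool-call stack alongside the final answer
	req := ai.InvokeRequest{
		Prompt:           "What time is it in Singapore when it is 2024-03-04 10:00:00 in UTC?",
		IncludeCallStack: true,
	}
	buf, err := json.Marshal(req)
	if err != nil {
		panic(err)
	}

	resp, err := http.Post("http://localhost:8000/invoke", "application/json", bytes.NewBuffer(buf))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	// non-200 responses carry the new ai.ErrorResponse shape
	if resp.StatusCode != http.StatusOK {
		var errResp ai.ErrorResponse
		if err := json.NewDecoder(resp.Body).Decode(&errResp); err != nil {
			panic(err)
		}
		fmt.Println("error:", errResp.Error)
		return
	}

	var res ai.InvokeResponse
	if err := json.NewDecoder(resp.Body).Decode(&res); err != nil {
		panic(err)
	}
	// with IncludeCallStack set, ToolCalls and ToolMessages are copied from
	// the first-round completion (see the pkg/bridge/ai/service.go hunk above)
	for tag, calls := range res.ToolCalls {
		for _, tc := range calls {
			fmt.Printf("tag=%d fn=%s args=%s\n", tag, tc.Function.Name, tc.Function.Arguments)
		}
	}
	fmt.Println("finish reason:", res.FinishReason)
	fmt.Println("content:", res.Content)
}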