Skip to content

Commit

Permalink
feat: cli test prompt (#760)
Browse files Browse the repository at this point in the history
# Description

Used to test an AI serverless function (sfn) by running it locally and sending a prompt to the LLM API server.

## Usage
```sh
$ yomo test-prompt -h

Test LLM prompt

Usage:
  yomo test-prompt [flags]

Aliases:
  test-prompt, p

Flags:
  -a, --ai-server string       LLM API server address (default "http://localhost:8000")
  -h, --help                   help for test-prompt
      --sfn strings            sfn source directory
  -s, --system-prompt string   system prompt (default "You are a very helpful assistant. Your job is to choose the best possible action to solve the user question or task. Don't make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous. If you don't know the answer, stop the conversation by saying \"no func call\"")
  -u, --user-prompt string     user prompt
```

Test
[llm-sfn-get-weather](https://github.com/yomorun/yomo/tree/master/example/10-ai/llm-sfn-get-weather/main.go)
```sh
$ yomo p --sfn ./llm-sfn-get-weather -u "What's the difference between the weather in Beijing and New York?"
ℹ️   --------------------------------------------------------
ℹ️   Run AI SFN on directory: .
ℹ️   Register AI function success
ℹ️   Invoke LLM API "http://localhost:8000/invoke"
ℹ️   >> LLM API Request
ℹ️   Messages:
ℹ️       System: You are a very helpful assistant. Your job is to choose the best possible action to solve the user question or task. Don't make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous.
ℹ️       User: What's the difference between the weather in Beijing and New York?
ℹ️   << LLM API Response
ℹ️   Invoke functions[2]:
ℹ️       [call_LUD1TeeQhnh7DgOpEWBPEvub] tag: 17, name: get-weather, arguments: {"city_name": "Beijing"}, result: [Beijing] temperature: 25°C
ℹ️       [call_Ml5GMJNoJflFvloAIVFfn9eo] tag: 17, name: get-weather, arguments: {"city_name": "New York"}, result: [New York] temperature: 30°C
ℹ️   Finish Reason: stop
ℹ️   Content: The current temperature in Beijing is 25°C, while the temperature in New York is 30°C. This means that New York is currently 5°C warmer than Beijing.

```

---------

Co-authored-by: C.C <[email protected]>
  • Loading branch information
venjiang and fanweixiao authored Mar 21, 2024
1 parent 1d7d54e commit 8b24008
Show file tree
Hide file tree
Showing 9 changed files with 273 additions and 18 deletions.
14 changes: 11 additions & 3 deletions ai/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,33 @@ package ai

import "errors"

// ErrorResponse is the response for error
// It is decoded from non-200 replies of the LLM API server (see the
// test-prompt command, which reads it when /invoke fails).
type ErrorResponse struct {
Error string `json:"error"` // Error is the error message reported by the server
}

// OverviewResponse is the response for overview
type OverviewResponse struct {
// Functions maps a yomo tag to its registered function definition.
// NOTE(review): the field has no json tag, so it encodes as "Functions" — confirm intended.
Functions map[uint32]*FunctionDefinition // key is the tag of yomo
}

// InvokeRequest is the request from user to BasicAPIServer.
//
// The scraped diff showed both the old and new field lines (duplicate
// ReqID/Prompt); this is the post-commit struct with the new
// IncludeCallStack flag.
type InvokeRequest struct {
	ReqID            string `json:"req_id"`             // ReqID is the request id of the request
	Prompt           string `json:"prompt"`             // Prompt is user input text for chat completion
	IncludeCallStack bool   `json:"include_call_stack"` // IncludeCallStack is the flag to include call stack in response
}

// InvokeResponse is the response for chat completions
type InvokeResponse struct {
// Functions is the functions from llm api response, key is the tag of yomo
Functions map[uint32][]*FunctionDefinition
// Functions map[uint32][]*FunctionDefinition
// Content is the content from llm api response
Content string
// ToolCalls is the toolCalls from llm api response
ToolCalls map[uint32][]*ToolCall
// ToolMessages is the tool messages from llm api response
ToolMessages []ToolMessage
// FinishReason is the finish reason from llm api response
FinishReason string
// TokenUsage is the token usage from llm api response
Expand Down
2 changes: 1 addition & 1 deletion cli/Taskfile.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ tasks:
dir: ../cmd/yomo
cmds:
- echo "{{.Name}} install..."
- go install -ldflags "-s -w -X {{.Module}}.Version={{.Version}} -X {{.Module}}.Date={{.Date}}"
- go install -race -ldflags "-s -w -X {{.Module}}.Version={{.Version}} -X {{.Module}}.Date={{.Date}}"
- echo "{{.Name}} {{.Version}}({{.Date}}) is installed."
silent: true

Expand Down
226 changes: 226 additions & 0 deletions cli/test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,226 @@
/*
Copyright © 2021 Allegro Networks
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package cli

import (
"bufio"
"bytes"
"encoding/json"
"fmt"
"net/http"
"os"
"os/exec"
"strings"
"syscall"
"time"

"github.com/spf13/cobra"
"github.com/yomorun/yomo/ai"
"github.com/yomorun/yomo/pkg/log"

// serverless registrations
_ "github.com/yomorun/yomo/cli/serverless/deno"
_ "github.com/yomorun/yomo/cli/serverless/golang"
_ "github.com/yomorun/yomo/cli/serverless/wasm"
)

var (
sfnDir []string
userPrompt string
systemPrompt string
aiServerAddr string
)

// testPromptCmd represents the test prompt command for LLM function
// the source code of the LLM function is in the sfnDir
var testPromptCmd = &cobra.Command{
Use: "test-prompt",
Aliases: []string{"p"},
Short: "Test LLM prompt",
Long: "Test LLM prompt",
Run: func(cmd *cobra.Command, args []string) {
// sfn source directory
if len(sfnDir) == 0 {
sfnDir = append(sfnDir, ".")
}
for _, dir := range sfnDir {
// run sfn
log.InfoStatusEvent(os.Stdout, "--------------------------------------------------------")
log.InfoStatusEvent(os.Stdout, "Attaching LLM function in directory: %v", dir)
cmd := exec.Command("go", "run", ".")
cmd.Dir = dir
env := os.Environ()
env = append(env, "YOMO_LOG_LEVEL=info")
cmd.Env = env
// cmd.Stdout = io.Discard
// cmd.Stderr = io.Discard
cmd.SysProcAttr = &syscall.SysProcAttr{
Setpgid: true,

Check failure on line 71 in cli/test.go

View workflow job for this annotation

GitHub Actions / Build and release

unknown field Setpgid in struct literal of type "syscall".SysProcAttr
}
stdout, err := cmd.StdoutPipe()
if err != nil {
log.FailureStatusEvent(os.Stdout, "Failed to attach LLM function in directory: %v, error: %v", dir, err)
continue
}
defer stdout.Close()
outputReader := bufio.NewReader(stdout)
// read outputReader
output := make(chan string)
defer close(output)
go func(outputReader *bufio.Reader, output chan string) {
for {
line, err := outputReader.ReadString('\n')
if err != nil {
break
}
if len(line) > 0 {
output <- line
}
}
}(outputReader, output)
// start cmd
if err := cmd.Start(); err != nil {
log.FailureStatusEvent(os.Stdout, "Failed to run LLM function in directory: %v, error: %v", dir, err)
continue
} else {
defer func(cmd *exec.Cmd) {
pgid, err := syscall.Getpgid(cmd.Process.Pid)

Check failure on line 100 in cli/test.go

View workflow job for this annotation

GitHub Actions / Build and release

undefined: syscall.Getpgid
if err == nil {
syscall.Kill(-pgid, syscall.SIGTERM)

Check failure on line 102 in cli/test.go

View workflow job for this annotation

GitHub Actions / Build and release

undefined: syscall.Kill
} else {
cmd.Process.Kill()
}
}(cmd)
}
// wait for the sfn to be ready
for {
select {
case out := <-output:
// log.InfoStatusEvent(os.Stdout, "AI SFN Output: %s", out)
if len(out) > 0 && strings.Contains(out, "register ai function success") {
log.InfoStatusEvent(os.Stdout, "Register LLM function success")
goto REQUEST
}
case <-time.After(5 * time.Second):
log.FailureStatusEvent(os.Stdout, "Connect to zipper failed, please check the zipper is running or not")
os.Exit(1)
}
}
// invoke llm api
// request
REQUEST:
apiEndpoint := fmt.Sprintf("%s/invoke", aiServerAddr)
log.InfoStatusEvent(os.Stdout, `Invoking LLM API "%s"`, apiEndpoint)
invokeReq := ai.InvokeRequest{
IncludeCallStack: true, // include call stack
Prompt: userPrompt,
}
reqBuf, err := json.Marshal(invokeReq)
if err != nil {
log.FailureStatusEvent(os.Stdout, "Failed to marshal invoke request: %v", err)
continue
}
// invoke api endpoint
log.InfoStatusEvent(os.Stdout, ">> LLM API Request")
log.InfoStatusEvent(os.Stdout, "Messages:")
log.InfoStatusEvent(os.Stdout, "\tSystem: %s", systemPrompt)
log.InfoStatusEvent(os.Stdout, "\tUser: %s", userPrompt)
resp, err := http.Post(apiEndpoint, "application/json", bytes.NewBuffer(reqBuf))
if err != nil {
log.FailureStatusEvent(os.Stdout, "Failed to invoke llm api: %v", err)
continue
}
defer resp.Body.Close()
// response
// failed to invoke llm api
log.InfoStatusEvent(os.Stdout, "<< LLM API Response")
if resp.StatusCode != http.StatusOK {
var errorResp ai.ErrorResponse
err := json.NewDecoder(resp.Body).Decode(&errorResp)
if err != nil {
log.FailureStatusEvent(os.Stdout, "Failed to decode LLM API response: %v", err)
continue
}
log.FailureStatusEvent(os.Stdout, "Failed to invoke LLM API response: %s", errorResp.Error)
continue
}
// success to invoke LLM API
var invokeResp ai.InvokeResponse
if err := json.NewDecoder(resp.Body).Decode(&invokeResp); err != nil {
log.FailureStatusEvent(os.Stdout, "Failed to decode LLM API response: %v", err)
continue
}
// tool calls
for tag, tcs := range invokeResp.ToolCalls {
toolCallCount := len(tcs)
if toolCallCount > 0 {
log.InfoStatusEvent(os.Stdout, "Invoking functions[%d]:", toolCallCount)
for _, tc := range tcs {
if invokeResp.ToolMessages == nil {
log.InfoStatusEvent(os.Stdout,
"\t[%s] tag: %d, name: %s, arguments: %s",
tc.ID,
tag,
tc.Function.Name,
tc.Function.Arguments,
)
} else {
log.InfoStatusEvent(os.Stdout,
"\t[%s] tag: %d, name: %s, arguments: %s\n🌟 result: %s",
tc.ID,
tag,
tc.Function.Name,
tc.Function.Arguments,
getToolCallResult(tc, invokeResp.ToolMessages),
)
}
}
}
}
// finish reason
log.InfoStatusEvent(os.Stdout, "Finish Reason: %s", invokeResp.FinishReason)
log.InfoStatusEvent(os.Stdout, "Final Content: \n🤖 %s", invokeResp.Content)
}
},
}

// getToolCallResult returns the content of the tool message whose
// ToolCallId matches the given tool call's ID, or "" when none matches.
// When several messages carry the same ToolCallId, the last one wins
// (scanning from the end preserves the original last-match behavior).
func getToolCallResult(tc *ai.ToolCall, tms []ai.ToolMessage) string {
	for i := len(tms) - 1; i >= 0; i-- {
		if tms[i].ToolCallId == tc.ID {
			return tms[i].Content
		}
	}
	return ""
}

// init registers the test-prompt command on the root command and declares
// its flags: --sfn, --user-prompt (required), --system-prompt, --ai-server.
func init() {
	rootCmd.AddCommand(testPromptCmd)

	flags := testPromptCmd.Flags()
	flags.StringSliceVarP(&sfnDir, "sfn", "", []string{}, "sfn source directory")
	flags.StringVarP(&userPrompt, "user-prompt", "u", "", "user prompt")
	flags.StringVarP(
		&systemPrompt,
		"system-prompt",
		"s",
		`You are a very helpful assistant. Your job is to choose the best possible action to solve the user question or task. Don't make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous.`,
		"system prompt",
	)
	flags.StringVarP(&aiServerAddr, "ai-server", "a", "http://localhost:8000", "LLM API server address")
	testPromptCmd.MarkFlagRequired("user-prompt")

	runViper = bindViper(testPromptCmd)
}
10 changes: 8 additions & 2 deletions example/10-ai/llm-sfn-timezone-calculator/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"fmt"
"log/slog"
"os"
"strings"
"time"

"github.com/yomorun/yomo"
Expand All @@ -18,7 +19,7 @@ type Parameter struct {
}

// Description returns the prompt text that tells the LLM when to call this
// timezone-converter function and how to fill its arguments.
//
// The scraped diff showed both the old and new return lines; this keeps the
// post-commit version ("you pretend date as today").
func Description() string {
	return `if user asks timezone converter related questions, extract the source time and timezone information to "timeString" and "sourceTimezone", extract the target timezone information to "targetTimezone". the desired "timeString" format is "YYYY-MM-DD HH:MM:SS". the "sourceTimezone" and "targetTimezone" are in IANA Time Zone Database identifier format. The function will convert the time from the source timezone to the target timezone and return the converted time as a string in the format "YYYY-MM-DD HH:MM:SS". If you are not sure about the date value of "timeString", you pretend date as today.`
}

func InputSchema() any {
Expand Down Expand Up @@ -75,6 +76,11 @@ func handler(ctx serverless.Context) {
msg.TargetTimezone = "UTC"
}

// should guarantee the date will not be the literal "YYYY-MM-DD" placeholder
if strings.Contains(msg.TimeString, "YYYY-MM-DD") {
msg.TimeString = strings.ReplaceAll(msg.TimeString, "YYYY-MM-DD", time.Now().Format("2006-01-02"))
}

targetTime, err := ConvertTimezone(msg.TimeString, msg.SourceTimezone, msg.TargetTimezone)
if err != nil {
slog.Error("[sfn] ConvertTimezone error", "err", err)
Expand All @@ -84,7 +90,7 @@ func handler(ctx serverless.Context) {

slog.Info("[sfn] result", "result", targetTime)

val := fmt.Sprintf("This time in timezone %s is %s", msg.TargetTimezone, targetTime)
val := fmt.Sprintf("This time in timezone %s is %s when %s in %s", msg.TargetTimezone, targetTime, msg.TimeString, msg.SourceTimezone)

// fcCtx.SetRetrievalResult(val)
fcCtx.Write(val)
Expand Down
3 changes: 1 addition & 2 deletions pkg/bridge/ai/api_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,6 @@ func WithContextService(handler http.Handler, credential string, zipperAddr stri
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
handler.ServeHTTP(w, r.WithContext(WithServiceContext(r.Context(), service)))
})

}

// HandleOverview is the handler for GET /overview
Expand Down Expand Up @@ -162,7 +161,7 @@ func HandleInvoke(w http.ResponseWriter, r *http.Request) {
go func() {
// call llm to infer the function and arguments to be invoked
ylog.Debug(">> ai request", "reqID", req.ReqID, "prompt", req.Prompt)
res, err := service.GetChatCompletions(req.Prompt, baseSystemMessage, reqID)
res, err := service.GetChatCompletions(req.Prompt, baseSystemMessage, reqID, req.IncludeCallStack)
if err != nil {
errCh <- err
} else {
Expand Down
20 changes: 11 additions & 9 deletions pkg/bridge/ai/provider/gemini/model_converter.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,16 +116,18 @@ func parseToolCallFromResponse(response *Response) []ai.ToolCall {
calls := make([]ai.ToolCall, 0)
for _, candidate := range response.Candidates {
fn := candidate.Content.Parts[0].FunctionCall
fd := &ai.FunctionDefinition{
Name: fn.Name,
Arguments: generateJSONSchemaArguments(fn.Args),
if fn != nil {
fd := &ai.FunctionDefinition{
Name: fn.Name,
Arguments: generateJSONSchemaArguments(fn.Args),
}
call := ai.ToolCall{
ID: "cc-gemini-id",
Type: "cc-function",
Function: fd,
}
calls = append(calls, call)
}
call := ai.ToolCall{
ID: "cc-gemini-id",
Type: "cc-function",
Function: fd,
}
calls = append(calls, call)
}
return calls
}
1 change: 1 addition & 0 deletions pkg/bridge/ai/provider/gemini/model_response.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ type CandidateContent struct {

// Part is the element of CandidateContent
type Part struct {
// Text is the plain text of the part; omitted from JSON when empty.
Text string `json:"text,omitempty"`
// FunctionCall is the function call requested by the model; nil when the
// part is text-only (callers nil-check it before use).
FunctionCall *FunctionCall `json:"functionCall"`
}

Expand Down
8 changes: 8 additions & 0 deletions pkg/bridge/ai/provider/gemini/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,9 @@ func (p *GeminiProvider) GetChatCompletions(userInstruction string, baseSystemMe
ylog.Debug("gemini api response", "calls", len(calls))

result := &ai.InvokeResponse{}
result.FinishReason = response.Candidates[0].FinishReason
result.Content = response.Candidates[0].Content.Parts[0].Text

if len(calls) == 0 {
return result, ai.ErrNoFunctionCall
}
Expand All @@ -123,6 +126,11 @@ func (p *GeminiProvider) GetChatCompletions(userInstruction string, baseSystemMe
if fd.Name == tc.Function.Name {
ylog.Debug("-----> add function", "name", fd.Name, "tag", tag)
currentCall := tc
fn := response.Candidates[0].Content.Parts[0].FunctionCall
if fn != nil {
args, _ := json.Marshal(fn.Args)
currentCall.Function.Arguments = string(args)
}
result.ToolCalls[tag] = append(result.ToolCalls[tag], &currentCall)
}
}
Expand Down
Loading

0 comments on commit 8b24008

Please sign in to comment.