Skip to content

Commit

Permalink
fix(cli): serverless ctx mock pkg move bug (#827)
Browse files Browse the repository at this point in the history
# Description

1. Move `MockContext` back to `serverless/mock` package.
2. Change the function signature from `ReadLLMFunctionCall(fnCall any)
error` to `LLMFunctionCall() (*ai.FunctionCall, error)`. Returning a typed
`*ai.FunctionCall` removes the runtime type assertion on an untyped
argument and so avoids type-mismatch errors.
  • Loading branch information
woorui authored May 30, 2024
1 parent 5064d07 commit 965b8d6
Show file tree
Hide file tree
Showing 12 changed files with 209 additions and 99 deletions.
3 changes: 0 additions & 3 deletions ai/function_call.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@ package ai

import (
"encoding/json"

"github.com/yomorun/yomo/serverless"
)

// ReducerTag is the observed tag of the reducer
Expand Down Expand Up @@ -32,7 +30,6 @@ type FunctionCall struct {
IsOK bool `json:"is_ok"`
// Error is the error message
Error string `json:"error,omitempty"`
ctx serverless.Context
}

// Bytes serialize the []byte of FunctionCallObject
Expand Down
65 changes: 3 additions & 62 deletions ai/function_call_test.go
Original file line number Diff line number Diff line change
@@ -1,24 +1,11 @@
package ai

import (
"fmt"
"testing"

"github.com/stretchr/testify/assert"
)

var jsonStr = "{\"req_id\":\"yYdzyl\",\"arguments\":\"{\\n \\\"sourceTimezone\\\": \\\"America/Los_Angeles\\\",\\n \\\"targetTimezone\\\": \\\"Asia/Singapore\\\",\\n \\\"timeString\\\": \\\"2024-03-25 07:00:00\\\"\\n}\",\"tool_call_id\":\"call_aZrtm5xcLs1qtP0SWo4CZi75\",\"function_name\":\"fn-timezone-converter\",\"is_ok\":false}"

var jsonStrWithResult = func(result string) string {
return fmt.Sprintf("{\"req_id\":\"yYdzyl\",\"result\":\"%s\",\"arguments\":\"{\\n \\\"sourceTimezone\\\": \\\"America/Los_Angeles\\\",\\n \\\"targetTimezone\\\": \\\"Asia/Singapore\\\",\\n \\\"timeString\\\": \\\"2024-03-25 07:00:00\\\"\\n}\",\"tool_call_id\":\"call_aZrtm5xcLs1qtP0SWo4CZi75\",\"function_name\":\"fn-timezone-converter\",\"is_ok\":true}", result)
}

var jsonStrWithError = func(err string) string {
return fmt.Sprintf("{\"req_id\":\"yYdzyl\",\"arguments\":\"{\\n \\\"sourceTimezone\\\": \\\"America/Los_Angeles\\\",\\n \\\"targetTimezone\\\": \\\"Asia/Singapore\\\",\\n \\\"timeString\\\": \\\"2024-03-25 07:00:00\\\"\\n}\",\"tool_call_id\":\"call_aZrtm5xcLs1qtP0SWo4CZi75\",\"function_name\":\"fn-timezone-converter\",\"is_ok\":true,\"error\":\"%s\"}", err)
}

var errJSONStr = "{a}"

var original = &FunctionCall{
ReqID: "yYdzyl",
Arguments: "{\n \"sourceTimezone\": \"America/Los_Angeles\",\n \"targetTimezone\": \"Asia/Singapore\",\n \"timeString\": \"2024-03-25 07:00:00\"\n}",
Expand All @@ -28,58 +15,12 @@ var original = &FunctionCall{
}

func TestFunctionCallBytes(t *testing.T) {
// Marshal the FunctionCall into bytes
bytes, err := original.Bytes()
// assert.NoError(t, err)

// // Unmarshal the bytes into a new FunctionCall
// target := &FunctionCall{}
// err = target.fromBytes(bytes)

assert.NoError(t, err)
assert.Equal(t, string(bytes), jsonStr, "Original and bytes should be equal")
}

func TestReadFunctionCall(t *testing.T) {
t.Run("ctx.Data is nil", func(t *testing.T) {
ctx := NewMockContext(nil, 0)
fnCall := &FunctionCall{}
err := ctx.ReadLLMFunctionCall(fnCall)
assert.Error(t, err)
})

t.Run("ctx.Data is invalid", func(t *testing.T) {
ctx := NewMockContext([]byte(errJSONStr), 0)
fnCall := &FunctionCall{}
err := ctx.ReadLLMFunctionCall(&fnCall)
assert.Error(t, err)
})
}

func TestReadLLMArguments(t *testing.T) {
ctx := NewMockContext([]byte(jsonStr), 0x10)
target := make(map[string]string)
err := ctx.ReadLLMArguments(&target)
actual := &FunctionCall{}
err = actual.FromBytes(bytes)

assert.NoError(t, err)
assert.Equal(t, "America/Los_Angeles", target["sourceTimezone"])
assert.Equal(t, "Asia/Singapore", target["targetTimezone"])
assert.Equal(t, "2024-03-25 07:00:00", target["timeString"])
}

func TestWriteLLMResult(t *testing.T) {
ctx := NewMockContext([]byte(jsonStr), 0x10)

// read
target := make(map[string]string)
err := ctx.ReadLLMArguments(&target)
assert.NoError(t, err)

// write
err = ctx.WriteLLMResult("test result")
assert.NoError(t, err)

res := ctx.RecordsWritten()
assert.Equal(t, ReducerTag, res[0].Tag)
assert.Equal(t, jsonStrWithResult("test result"), string(res[0].Data))
assert.Equal(t, original, actual)
}
9 changes: 3 additions & 6 deletions ai/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package ai
import (
"errors"
"fmt"
"strings"

"github.com/sashabaranov/go-openai"
"github.com/yomorun/yomo/core/ylog"
Expand All @@ -30,12 +29,11 @@ func ConvertToInvokeResponse(res *openai.ChatCompletionResponse, tcs map[uint32]
ylog.Debug("++ llm result", "token_usage", fmt.Sprintf("%v", result.TokenUsage), "finish_reason", result.FinishReason)

// if llm said no function call, we should return the result
// gemini provider will return the finish_reason is "STOP", otherwise "stop"
if strings.ToLower(result.FinishReason) == "stop" {
if result.FinishReason == string(openai.FinishReasonStop) {
return result, nil
}

if result.FinishReason == "tool_calls" || result.FinishReason == "gemini_tool_calls" {
if result.FinishReason == "tool_calls" {
// assistant message
result.AssistantMessage = responseMessage
}
Expand All @@ -48,8 +46,7 @@ func ConvertToInvokeResponse(res *openai.ChatCompletionResponse, tcs map[uint32]
for _, call := range calls {
for tag, tc := range tcs {
ylog.Debug(">> compare tool call", "tag", tag, "tc", tc.Function.Name, "call", call.Function.Name)
// WARN: gemini process tool calls, currently function name not equal to tool call name, eg. "get-weather" != "get_weather"
if (tc.Function.Name == call.Function.Name && tc.Type == call.Type) || result.FinishReason == "gemini_tool_calls" {
if tc.Function.Name == call.Function.Name && tc.Type == call.Type {
if result.ToolCalls == nil {
result.ToolCalls = make(map[uint32][]*openai.ToolCall)
}
Expand Down
80 changes: 80 additions & 0 deletions ai/openai_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
package ai

import (
"testing"

"github.com/sashabaranov/go-openai"
"github.com/stretchr/testify/assert"
)

// TestConvertToInvokeResponse checks that an openai.ChatCompletionResponse is
// mapped into an *InvokeResponse: the assistant message, content, finish
// reason, token usage, and the tool calls matched (by function name and type)
// against the registered tools keyed by tag.
func TestConvertToInvokeResponse(t *testing.T) {
	type args struct {
		res *openai.ChatCompletionResponse
		tcs map[uint32]openai.Tool // registered tools, keyed by observed tag
	}
	tests := []struct {
		name     string
		args     args
		expected *InvokeResponse
	}{
		{
			// The response finishes with "tool_calls"; the single tool call
			// ("get-weather") matches the registered tool under tag 9, so it
			// must appear in expected.ToolCalls under that same tag.
			name: "tool_calls",
			args: args{
				res: &openai.ChatCompletionResponse{
					Choices: []openai.ChatCompletionChoice{
						{
							Index: 0,
							Message: openai.ChatCompletionMessage{
								Role:    "user",
								Content: "How is the weather today?",
								ToolCalls: []openai.ToolCall{
									{
										Type:     openai.ToolTypeFunction,
										Function: openai.FunctionCall{Name: "get-weather"},
									},
								},
								ToolCallID: "9TWd1eA2K3rmmtC21oER2a9F0YZif",
							},
							FinishReason: openai.FinishReasonToolCalls,
						},
					},
					Usage: openai.Usage{
						PromptTokens:     50,
						CompletionTokens: 100,
						TotalTokens:      150,
					},
				},
				tcs: map[uint32]openai.Tool{
					9: {
						Type:     openai.ToolTypeFunction,
						Function: &openai.FunctionDefinition{Name: "get-weather"},
					},
				},
			},
			expected: &InvokeResponse{
				Content: "How is the weather today?",
				ToolCalls: map[uint32][]*openai.ToolCall{
					9: {
						{Type: openai.ToolTypeFunction, Function: openai.FunctionCall{Name: "get-weather"}},
					},
				},
				FinishReason: "tool_calls",
				// Note: TotalTokens is not carried over into TokenUsage.
				TokenUsage: TokenUsage{PromptTokens: 50, CompletionTokens: 100},
				// On a "tool_calls" finish, the whole choice message is kept
				// as the assistant message.
				AssistantMessage: openai.ChatCompletionMessage{
					Role:    "user",
					Content: "How is the weather today?",
					ToolCalls: []openai.ToolCall{
						{Type: openai.ToolTypeFunction, Function: openai.FunctionCall{Name: "get-weather"}},
					},
					ToolCallID: "9TWd1eA2K3rmmtC21oER2a9F0YZif",
				},
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// The error return is intentionally ignored; only the mapped
			// response is asserted here.
			actual, _ := ConvertToInvokeResponse(tt.args.res, tt.args.tcs)
			assert.Equal(t, tt.expected, actual)
		})
	}
}
16 changes: 9 additions & 7 deletions core/serverless/context_llm.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,16 @@ func (c *Context) WriteLLMResult(result string) error {
return c.Write(ai.ReducerTag, buf)
}

// ReadLLMFunctionCall reads LLM function call
func (c *Context) ReadLLMFunctionCall(fnCall any) error {
// LLMFunctionCall reads LLM function call
func (c *Context) LLMFunctionCall() (*ai.FunctionCall, error) {
if c.data == nil {
return errors.New("ctx.Data() is nil")
return nil, errors.New("ctx.Data() is nil")
}
fco, ok := fnCall.(*ai.FunctionCall)
if !ok {
return errors.New("given object is not *ai.FunctionCall")

fco := &ai.FunctionCall{}
if err := fco.FromBytes(c.data); err != nil {
return nil, errors.New("LLMFunctionCall: given object is not *ai.FunctionCall")
}
return fco.FromBytes(c.data)

return fco, nil
}
4 changes: 2 additions & 2 deletions example/10-ai/llm-sfn-get-weather/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,9 @@ func Handler(ctx serverless.Context) {
if err == nil {
slog.Info("[sfn] >> write", "tag", ai.ReducerTag, "msg", data)
fnCall := &ai.FunctionCall{}
err = ctx.ReadLLMFunctionCall(fnCall)
err = ctx.LLMFunctionCall(fnCall)
if err != nil {
slog.Error("[sfn] ReadLLMFunctionCall error", "err", err)
slog.Error("[sfn] LLMFunctionCall error", "err", err)
return
}
slog.Info("[sfn] >> write", "tag", ai.ReducerTag, "fnCall", fnCall)
Expand Down
3 changes: 0 additions & 3 deletions pkg/bridge/ai/ai.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,6 @@ func ConnMiddleware(next core.ConnHandler) core.ConnHandler {
}
}

// MetadataKey tells that the function is an ai function.
const MetadataKey = "ai"

// Config is the configuration of AI bridge.
// The configuration looks like:
//
Expand Down
3 changes: 1 addition & 2 deletions pkg/bridge/ai/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,8 +152,7 @@ func (s *Service) createReducer() (yomo.StreamFunction, error) {
sfn.SetHandler(func(ctx serverless.Context) {
buf := ctx.Data()
ylog.Debug("[sfn-reducer]", "tag", ai.ReducerTag, "data", string(buf))
invoke := &ai.FunctionCall{}
err := ctx.ReadLLMFunctionCall(invoke)
invoke, err := ctx.LLMFunctionCall()
if err != nil {
ylog.Error("[sfn-reducer] parse function calling invoke", "err", err.Error())
return
Expand Down
6 changes: 4 additions & 2 deletions serverless/context.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
// Package serverless defines serverless handler context
package serverless

import "github.com/yomorun/yomo/ai"

// Context sfn handler context
type Context interface {
// Data incoming data
Expand All @@ -19,8 +21,8 @@ type Context interface {
ReadLLMArguments(args any) error
// WriteLLMResult writes LLM function result
WriteLLMResult(result string) error
// ReadLLMFunctionCall reads LLM function call
ReadLLMFunctionCall(fnCall any) error
// LLMFunctionCall reads LLM function call
LLMFunctionCall() (*ai.FunctionCall, error)
}

// CronContext sfn cron handler context
Expand Down
4 changes: 3 additions & 1 deletion serverless/guest/context_llm.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package guest

import "github.com/yomorun/yomo/ai"

// ReadLLMArguments is not implemented for the guest context; calling it
// always panics.
func (c *GuestContext) ReadLLMArguments(args any) error {
	panic("not implemented")
}
Expand All @@ -8,6 +10,6 @@ func (c *GuestContext) WriteLLMResult(result string) error {
panic("not implemented")
}

func (c *GuestContext) ReadLLMFunctionCall(fnCall any) error {
func (c *GuestContext) LLMFunctionCall() (*ai.FunctionCall, error) {
panic("not implemented")
}
24 changes: 13 additions & 11 deletions ai/mock_context.go → serverless/mock/mock_context.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
package ai
package mock

import (
"encoding/json"
"errors"
"sync"

"github.com/yomorun/yomo/ai"
"github.com/yomorun/yomo/serverless"
"github.com/yomorun/yomo/serverless/guest"
)
Expand All @@ -22,7 +23,7 @@ type WriteRecord struct {
type MockContext struct {
data []byte
tag uint32
fnCall *FunctionCall
fnCall *ai.FunctionCall

mu sync.Mutex
wrSlice []WriteRecord
Expand Down Expand Up @@ -86,7 +87,7 @@ func (c *MockContext) WriteWithTarget(tag uint32, data []byte, target string) er

// ReadLLMArguments reads LLM function arguments.
func (c *MockContext) ReadLLMArguments(args any) error {
fnCall := &FunctionCall{}
fnCall := &ai.FunctionCall{}
err := fnCall.FromBytes(c.data)
if err != nil {
return err
Expand Down Expand Up @@ -117,21 +118,22 @@ func (c *MockContext) WriteLLMResult(result string) error {

c.wrSlice = append(c.wrSlice, WriteRecord{
Data: buf,
Tag: ReducerTag,
Tag: ai.ReducerTag,
})
return nil
}

// ReadLLMFunctionCall reads LLM function call.
func (c *MockContext) ReadLLMFunctionCall(fnCall any) error {
// LLMFunctionCall reads LLM function call.
func (c *MockContext) LLMFunctionCall() (*ai.FunctionCall, error) {
if c.data == nil {
return errors.New("ctx.Data() is nil")
return nil, errors.New("ctx.Data() is nil")
}
fco, ok := fnCall.(*FunctionCall)
if !ok {
return errors.New("given object is not *ai.FunctionCall")

fco := &ai.FunctionCall{}
if err := fco.FromBytes(c.data); err != nil {
return nil, errors.New("given object is not *ai.FunctionCall")
}
return fco.FromBytes(c.data)
return fco, nil
}

// RecordsWritten returns the data records be written with `ctx.Write`.
Expand Down
Loading

0 comments on commit 965b8d6

Please sign in to comment.