Skip to content

Commit

Permalink
feat(ai-bridge): handle llm-sfn return nothing or multiple results re…
Browse files Browse the repository at this point in the history
…turned (#833)
  • Loading branch information
woorui authored Jun 20, 2024
1 parent cf78fea commit d255f49
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 6 deletions.
7 changes: 5 additions & 2 deletions core/serverless/context_llm.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@ import (

// ReadLLMArguments reads LLM function arguments
func (c *Context) ReadLLMArguments(args any) error {
fnCall := &ai.FunctionCall{}
err := fnCall.FromBytes(c.data)
fnCall, err := c.LLMFunctionCall()
if err != nil {
return err
}
Expand All @@ -27,13 +26,17 @@ func (c *Context) WriteLLMResult(result string) error {
if c.fnCall == nil {
return errors.New("no function call, can't write result")
}
if c.fnCall.IsOK && c.fnCall.Result != "" {
return errors.New("LLM function can only be called once")
}
// function call
c.fnCall.IsOK = true
c.fnCall.Result = result
buf, err := c.fnCall.Bytes()
if err != nil {
return err
}
c.data = buf
return c.Write(ai.ReducerTag, buf)
}

Expand Down
2 changes: 1 addition & 1 deletion pkg/bridge/ai/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -427,7 +427,7 @@ func (s *Service) GetChatCompletions(ctx context.Context, req openai.ChatComplet
}
promptUsage = resp.Usage.PromptTokens
completionUsage = resp.Usage.CompletionTokens
totalUsage = resp.Usage.CompletionTokens
totalUsage = resp.Usage.TotalTokens

ylog.Debug(" #1 first call", "response", fmt.Sprintf("%+v", resp))
// it is a function call
Expand Down
10 changes: 7 additions & 3 deletions serverless/mock/mock_context.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (

"github.com/yomorun/yomo/ai"
"github.com/yomorun/yomo/serverless"
"github.com/yomorun/yomo/serverless/guest"
)

var _ serverless.Context = (*MockContext)(nil)
Expand Down Expand Up @@ -55,7 +54,7 @@ func (c *MockContext) Metadata(_ string) (string, bool) {

// HTTP returns the HTTP interface.
func (m *MockContext) HTTP() serverless.HTTP {
return &guest.GuestHTTP{}
panic("not implemented, to use `net/http` package")
}

// Write writes the data with the given tag.
Expand Down Expand Up @@ -106,8 +105,13 @@ func (c *MockContext) WriteLLMResult(result string) error {
defer c.mu.Unlock()

if c.fnCall == nil {
return errors.New("no function call, can't write result")
fnCall, err := c.LLMFunctionCall()
if err != nil {
return err
}
c.fnCall = fnCall
}

// function call
c.fnCall.IsOK = true
c.fnCall.Result = result
Expand Down
11 changes: 11 additions & 0 deletions serverless/mock/mock_context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,3 +89,14 @@ func TestWriteLLMResult(t *testing.T) {
assert.Equal(t, ai.ReducerTag, res[0].Tag)
assert.Equal(t, jsonStrWithResult("test result"), string(res[0].Data))
}

// TestWriteLLMResultWithoutRead verifies that WriteLLMResult succeeds even
// when ReadLLMArguments was never called first: the function call is parsed
// lazily from the raw payload, and the result record is written to ReducerTag.
func TestWriteLLMResultWithoutRead(t *testing.T) {
	ctx := NewMockContext([]byte(jsonStr), 0x10)

	// Write the result without a prior ReadLLMArguments call.
	assert.NoError(t, ctx.WriteLLMResult("test result"))

	records := ctx.RecordsWritten()
	assert.Equal(t, ai.ReducerTag, records[0].Tag)
	assert.Equal(t, jsonStrWithResult("test result"), string(records[0].Data))
}
14 changes: 14 additions & 0 deletions sfn.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package yomo
import (
"context"
"errors"
"log/slog"

"github.com/robfig/cron/v3"

Expand All @@ -12,6 +13,7 @@ import (
"github.com/yomorun/yomo/core/serverless"
"github.com/yomorun/yomo/pkg/id"
"github.com/yomorun/yomo/pkg/trace"
yserverless "github.com/yomorun/yomo/serverless"
"go.opentelemetry.io/otel/attribute"
)

Expand Down Expand Up @@ -251,6 +253,7 @@ func (s *streamFunction) onDataFrame(dataFrame *frame.DataFrame) {

serverlessCtx := serverless.NewContext(s.client, dataFrame.Tag, md, dataFrame.Payload)
s.fn(serverlessCtx)
checkLLMFunctionCall(s.client.Logger, serverlessCtx)
}(dataFrame)
} else if s.pfn != nil {
data := dataFrame.Payload
Expand All @@ -270,3 +273,14 @@ func (s *streamFunction) SetErrorHandler(fn func(err error)) {
func (s *streamFunction) Init(fn func() error) error {
return fn()
}

// checkLLMFunctionCall inspects the serverless context after the user handler
// has run. If the context carries an LLM function call whose result was never
// written (IsOK is false), it logs a warning so the developer knows
// `WriteLLMResult()` was not called.
func checkLLMFunctionCall(logger *slog.Logger, serverlessCtx yserverless.Context) {
	fnCall, err := serverlessCtx.LLMFunctionCall()
	if err != nil {
		// Not an LLM function-call context; nothing to check.
		return
	}
	if fnCall.IsOK {
		return
	}
	logger.Warn("The function is not returning anything, please check if `WriteLLMResult()` has been called")
}

0 comments on commit d255f49

Please sign in to comment.