Skip to content

Commit

Permalink
refactor: handshake with ai function definition (#798)
Browse files Browse the repository at this point in the history
  • Loading branch information
woorui authored May 6, 2024
1 parent 3b5f553 commit db3bd5e
Show file tree
Hide file tree
Showing 13 changed files with 110 additions and 280 deletions.
3 changes: 3 additions & 0 deletions ai/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,3 +74,6 @@ type ChainMessage struct {
// ToolMessages is the tool messages aggragated from reducer-sfn by AI service
ToolMessages []ToolMessage
}

// FunctionDefinitionKey is the yomo metadata key for function definition
const FunctionDefinitionKey = "function-definition"
47 changes: 19 additions & 28 deletions core/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,11 @@ func (c *Client) connect(ctx context.Context, addr string) (frame.Conn, error) {
WantedTarget: c.wantedTarget,
}

err = c.handshakeWithDefinition(hf)
if err != nil {
return nil, err
}

if err := conn.WriteFrame(hf); err != nil {
return conn, err
}
Expand All @@ -210,10 +215,6 @@ func (c *Client) connect(ctx context.Context, addr string) (frame.Conn, error) {

switch received.Type() {
case frame.TypeHandshakeAckFrame:
// check function calling definition
if err := c.writeAIRegisterFunctionFrame(conn, received.(*frame.HandshakeAckFrame)); err != nil {
return nil, err
}
return conn, nil
case frame.TypeRejectedFrame:
err := &ErrRejected{Message: received.(*frame.RejectedFrame).Message}
Expand All @@ -233,29 +234,21 @@ func (c *Client) connect(ctx context.Context, addr string) (frame.Conn, error) {
return nil, err
}

func (c *Client) writeAIRegisterFunctionFrame(conn *yquic.FrameConn, _ *frame.HandshakeAckFrame) error {
// register ai function
if c.clientType == ClientTypeStreamFunction {
functionDefinition, err := parseAIFunctionDefinition(c.name, c.opts.aiFunctionDescription, c.opts.aiFunctionInputModel)
if err != nil {
c.Logger.Error("parse ai function definition error", "err", err)
return err
}
// not exist ai function definition
if functionDefinition == nil {
return nil
}
for _, tag := range c.opts.observeDataTags {
registerFunctionFrame := &frame.AIRegisterFunctionFrame{
Name: c.name,
Tag: tag,
Definition: functionDefinition,
}
if err := conn.WriteFrame(registerFunctionFrame); err != nil {
return err
}
}
func (c *Client) handshakeWithDefinition(hf *frame.HandshakeFrame) error {
if c.clientType != ClientTypeStreamFunction {
return nil
}
// register ai function definition
functionDefinition, err := parseAIFunctionDefinition(c.name, c.opts.aiFunctionDescription, c.opts.aiFunctionInputModel)
if err != nil {
c.Logger.Error("parse ai function definition error", "err", err)
return err
}
// ai function definition is not be found
if functionDefinition == nil {
return nil
}
hf.FunctionDefinition = functionDefinition
return nil
}

Expand Down Expand Up @@ -413,8 +406,6 @@ func (c *Client) handleFrame(f frame.Frame) {
_ = c.Close()
case *frame.DataFrame:
c.processor(ff)
case *frame.AIRegisterFunctionAckFrame:
c.Logger.Info("register ai function success", "name", ff.Name, "tag", ff.Tag)
default:
c.Logger.Warn("received unexpected frame", "frame_type", f.Type().String())
}
Expand Down
66 changes: 20 additions & 46 deletions core/frame/frame.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@ import (
// 4. RejectedFrame
// 5. GoawayFrame
// 6. ConnectToFrame
// 7. AIRegisterFunctionFrame
// 8. AIRegisterFunctionAckFrame
//
// Read frame comments to understand the role of the frame.
type Frame interface {
Expand Down Expand Up @@ -62,6 +60,8 @@ type HandshakeFrame struct {
AuthPayload string
// Version is used by the source/sfn to communicate their spec version to the server.
Version string
// FunctionDefinition is the definition of the AI function.
FunctionDefinition []byte
// WantedTarget represents the target that accepts the data frames that carrying the same target.
WantedTarget string
}
Expand Down Expand Up @@ -103,46 +103,22 @@ type ConnectToFrame struct {
// Type returns the type of ConnectToFrame.
func (f *ConnectToFrame) Type() Type { return TypeConnectToFrame }

// AIRegisterFunctionFrame is used to register AI function.
type AIRegisterFunctionFrame struct {
Name string // Name is the name of the AI function.
Tag uint32
Definition []byte // Definition is the definition of the AI function.
}

// Type returns the type of AIRegisterFunctionFrame.
func (f *AIRegisterFunctionFrame) Type() Type { return TypeAIRegisterFunctionFrame }

// AIRegisterFunctionAckFrame is used to ack AIRegisterFunctionFrame.
type AIRegisterFunctionAckFrame struct {
Name string
Tag uint32
}

// Type returns the type of AIRegisterFunctionAckFrame.
func (f *AIRegisterFunctionAckFrame) Type() Type { return TypeAIRegisterFunctionAckFrame }

const (
TypeDataFrame Type = 0x3F // TypeDataFrame is the type of DataFrame.
TypeHandshakeFrame Type = 0x31 // TypeHandshakeFrame is the type of HandshakeFrame.
TypeHandshakeAckFrame Type = 0x29 // TypeHandshakeAckFrame is the type of HandshakeAckFrame.
TypeRejectedFrame Type = 0x39 // TypeRejectedFrame is the type of RejectedFrame.
TypeGoawayFrame Type = 0x2E // TypeGoawayFrame is the type of GoawayFrame.
TypeConnectToFrame Type = 0x3E // TypeConnectToFrame is the type of ConnectToFrame.
TypeAIRegisterFunctionFrame Type = 0x10 // TypeAIRegisterFunctionFrame is the type of AIRegisterFunctionFrame.
TypeAIRegisterFunctionAckFrame Type = 0x11 // TypeAIRegisterFunctionAckFrame is the type of AIRegisterFunctionAckFrame.

TypeDataFrame Type = 0x3F // TypeDataFrame is the type of DataFrame.
TypeHandshakeFrame Type = 0x31 // TypeHandshakeFrame is the type of HandshakeFrame.
TypeHandshakeAckFrame Type = 0x29 // TypeHandshakeAckFrame is the type of HandshakeAckFrame.
TypeRejectedFrame Type = 0x39 // TypeRejectedFrame is the type of RejectedFrame.
TypeGoawayFrame Type = 0x2E // TypeGoawayFrame is the type of GoawayFrame.
TypeConnectToFrame Type = 0x3E // TypeConnectToFrame is the type of ConnectToFrame.
)

var frameTypeStringMap = map[Type]string{
TypeDataFrame: "DataFrame",
TypeHandshakeFrame: "HandshakeFrame",
TypeHandshakeAckFrame: "HandshakeAckFrame",
TypeRejectedFrame: "RejectedFrame",
TypeGoawayFrame: "GoawayFrame",
TypeConnectToFrame: "ConnectToFrame",
TypeAIRegisterFunctionFrame: "AIRegisterFunctionFrame",
TypeAIRegisterFunctionAckFrame: "AIRegisterFunctionAckFrame",
TypeDataFrame: "DataFrame",
TypeHandshakeFrame: "HandshakeFrame",
TypeHandshakeAckFrame: "HandshakeAckFrame",
TypeRejectedFrame: "RejectedFrame",
TypeGoawayFrame: "GoawayFrame",
TypeConnectToFrame: "ConnectToFrame",
}

// String returns a human-readable string which represents the frame type.
Expand All @@ -156,14 +132,12 @@ func (f Type) String() string {
}

var frameTypeNewFuncMap = map[Type]func() Frame{
TypeDataFrame: func() Frame { return new(DataFrame) },
TypeHandshakeFrame: func() Frame { return new(HandshakeFrame) },
TypeHandshakeAckFrame: func() Frame { return new(HandshakeAckFrame) },
TypeRejectedFrame: func() Frame { return new(RejectedFrame) },
TypeGoawayFrame: func() Frame { return new(GoawayFrame) },
TypeConnectToFrame: func() Frame { return new(ConnectToFrame) },
TypeAIRegisterFunctionFrame: func() Frame { return new(AIRegisterFunctionFrame) },
TypeAIRegisterFunctionAckFrame: func() Frame { return new(AIRegisterFunctionAckFrame) },
TypeDataFrame: func() Frame { return new(DataFrame) },
TypeHandshakeFrame: func() Frame { return new(HandshakeFrame) },
TypeHandshakeAckFrame: func() Frame { return new(HandshakeAckFrame) },
TypeRejectedFrame: func() Frame { return new(RejectedFrame) },
TypeGoawayFrame: func() Frame { return new(GoawayFrame) },
TypeConnectToFrame: func() Frame { return new(ConnectToFrame) },
}

// NewFrame creates a new frame from Type.
Expand Down
8 changes: 7 additions & 1 deletion core/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"sync"
"sync/atomic"

"github.com/yomorun/yomo/ai"
"github.com/yomorun/yomo/core/auth"
"github.com/yomorun/yomo/core/frame"
"github.com/yomorun/yomo/core/metadata"
Expand Down Expand Up @@ -230,7 +231,12 @@ func (s *Server) handshake(fconn frame.Conn) (*Connection, error) {
return nil, rejectHandshake(fconn, err)
}

// 4. add route rules
// 4. store function definition to metadata
if hf.FunctionDefinition != nil {
conn.Metadata().Set(ai.FunctionDefinitionKey, string(hf.FunctionDefinition))
}

// 5. add route rules
if err := s.addSfnRouteRule(conn.ID(), hf, conn.Metadata()); err != nil {
return nil, rejectHandshake(fconn, err)
}
Expand Down
2 changes: 1 addition & 1 deletion example/10-ai/llm-sfn-get-ip-and-latency/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ require (
github.com/onsi/ginkgo/v2 v2.9.5 // indirect
github.com/quic-go/quic-go v0.42.0 // indirect
github.com/robfig/cron/v3 v3.0.1 // indirect
github.com/sashabaranov/go-openai v1.21.0 // indirect
github.com/sashabaranov/go-openai v1.23.0 // indirect
github.com/vmihailenco/msgpack/v5 v5.4.1 // indirect
github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect
github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect
Expand Down
4 changes: 2 additions & 2 deletions example/10-ai/llm-sfn-get-ip-and-latency/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,8 @@ github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzG
github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ=
github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog=
github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR/rfWxYHBV53g=
github.com/sashabaranov/go-openai v1.21.0 h1:isAf3zPSD3VLd0pC2/2Q6ZyRK7jzPAaz+X3rjsviaYQ=
github.com/sashabaranov/go-openai v1.21.0/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
github.com/sashabaranov/go-openai v1.23.0 h1:KYW97r5yc35PI2MxeLZ3OofecB/6H+yxvSNqiT9u8is=
github.com/sashabaranov/go-openai v1.23.0/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
github.com/sergi/go-diff v1.0.0/go.mod h1:0CfEIISq7TuYL3j771MWULgwwjU+GofnZX9QAmXWZgo=
github.com/shurcooL/component v0.0.0-20170202220835-f88ec8f54cc4/go.mod h1:XhFIlyj5a1fBNx5aJTbKoIq0mNaPvOagO+HjB3EtxrY=
github.com/shurcooL/events v0.0.0-20181021180414-410e4ca65f48/go.mod h1:5u70Mqkb5O5cxEA8nxTsgrgLehJeAw6Oc4Ab1c/P1HM=
Expand Down
40 changes: 16 additions & 24 deletions pkg/bridge/ai/ai.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,10 @@ package ai
import (
"encoding/json"
"errors"
"fmt"
"net"

"github.com/yomorun/yomo/ai"
"github.com/yomorun/yomo/core"
"github.com/yomorun/yomo/core/frame"
"github.com/yomorun/yomo/core/ylog"
"github.com/yomorun/yomo/pkg/bridge/ai/register"
"gopkg.in/yaml.v3"
Expand All @@ -28,44 +26,38 @@ var (
func ConnMiddleware(next core.ConnHandler) core.ConnHandler {
return func(conn *core.Connection) {
connMd := conn.Metadata().Clone()
definition, ok := connMd.Get(ai.FunctionDefinitionKey)

defer func() {
// definition does not be transmitted in mesh network, It only works for handshake.
conn.Metadata().Set(ai.FunctionDefinitionKey, "")
next(conn)
register.UnregisterFunction(conn.ID(), connMd)
conn.Logger.Info("unregister ai function", "name", conn.Name(), "connID", conn.ID())
if ok {
register.UnregisterFunction(conn.ID(), connMd)
conn.Logger.Info("unregister ai function", "name", conn.Name(), "connID", conn.ID())
}
}()

// check sfn type and is ai function
if conn.ClientType() != core.ClientTypeStreamFunction {
if conn.ClientType() != core.ClientTypeStreamFunction || !ok {
return
}
f, err := conn.FrameConn().ReadFrame()
// unregister ai function on any error
if err != nil {
conn.Logger.Error("failed to read frame on ai middleware", "err", err, "type", fmt.Sprintf("%T", err))
conn.Logger.Info("error type", "type", fmt.Sprintf("%T", err))
return
}
if ff, ok := f.(*frame.AIRegisterFunctionFrame); ok {
err := conn.FrameConn().WriteFrame(&frame.AIRegisterFunctionAckFrame{Name: ff.Name, Tag: ff.Tag})
if err != nil {
conn.Logger.Error("failed to write ai RegisterFunctionAckFrame", "name", ff.Name, "tag", ff.Tag, "err", err)
return
}

for _, tag := range conn.ObserveDataTags() {
// register ai function
fd := ai.FunctionDefinition{}
err = json.Unmarshal(ff.Definition, &fd)
err := json.Unmarshal([]byte(definition), &fd)
if err != nil {
conn.Logger.Error("unmarshal function definition", "error", err)
return
}
err = register.RegisterFunction(ff.Tag, &fd, conn.ID(), connMd)
err = register.RegisterFunction(tag, &fd, conn.ID(), connMd)
if err != nil {
conn.Logger.Error("failed to register ai function", "name", ff.Name, "tag", ff.Tag, "err", err)
conn.Logger.Error("failed to register ai function", "name", conn.Name(), "tag", tag, "err", err)
return
}
conn.Metadata().Set(MetadataKey, "1")
conn.Logger.Info("register ai function success", "name", ff.Name, "tag", ff.Tag, "definition", string(ff.Definition))
conn.Logger.Info("register ai function success", "name", conn.Name(), "tag", tag, "definition", string(definition))
}

}
}

Expand Down
53 changes: 0 additions & 53 deletions pkg/frame-codec/y3codec/ai_register_function_ack_frame.go

This file was deleted.

Loading

0 comments on commit db3bd5e

Please sign in to comment.