Skip to content

Commit

Permalink
refactor: new caller from interfaces & new service from its option (#883
Browse files Browse the repository at this point in the history
)

# Description

1. A new `Service` struct has been introduced, which is responsible for
handling the main logic. The `Service` struct has two main methods:
`GetChatCompletions()` and `GetInvoke()` that handle the main logic.

2. The `CallerProvider` struct has been removed.

3. The `Service` struct has an additional method called
`LoadOrCreateCaller()`, which is responsible for managing the creation
and caching of the `Caller` instances.

4. The `NewCaller` function, which is used to create a new `Caller`
instance, has a new parameter signature: `func (source yomo.Source,
reducer yomo.StreamFunction, md metadata.M, callTimeout time.Duration)
*Caller`. This means that when creating a `Service` instance, you need
to inject the way to create the source, reducer, and how to exchange
metadata.
To facilitate this, a new struct called `ServiceOption` has been
defined. The `service.LoadOrCreateCaller()` method will call the
`ServiceOption.SourceBuilder()`, `ServiceOption.ReducerBuilder()`,
`ServiceOption.MetadataExchanger() ` and then use the returned values as
the parameters for the `NewCaller` function to create the `Caller`
instance.

the `ServiceOption` struct:
```go
// ServiceOptions is the option for creating service
type ServiceOptions struct {
	// Logger is the logger for the service
	Logger *slog.Logger
	// Tracer is the tracer for the service
	Tracer trace.Tracer
	// CredentialFunc is the function for getting the credential from the request
	CredentialFunc func(r *http.Request) (string, error)
	// CallerCacheSize is the size of the caller's cache
	CallerCacheSize int
	// CallerCacheTTL is the time to live of the callers cache
	CallerCacheTTL time.Duration
	// CallerCallTimeout is the timeout for awaiting the function response.
	CallerCallTimeout time.Duration
	// SourceBuilder should builds an unconnected source.
	SourceBuilder func(zipperAddr, credential string) yomo.Source
	// ReducerBuilder should builds an unconnected reducer.
	ReducerBuilder func(zipperAddr, credential string) yomo.StreamFunction
	// MetadataExchanger exchanges metadata from the credential.
	MetadataExchanger func(credential string) (metadata.M, error)
}
```

5. Besides,`ServiceOptions` also allows you to modify the default
logger, tracer, and the method to get the credential, among other
configuration parameters.
  • Loading branch information
woorui authored Aug 16, 2024
1 parent f438306 commit c68e252
Show file tree
Hide file tree
Showing 8 changed files with 1,288 additions and 1,228 deletions.
105 changes: 48 additions & 57 deletions pkg/bridge/ai/api_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ type BasicAPIServer struct {
zipperAddr string
credential string
httpHandler http.Handler
logger *slog.Logger
}

// Serve starts the Basic API Server
Expand All @@ -44,19 +43,20 @@ func Serve(config *Config, zipperListenAddr string, credential string, logger *s
if err != nil {
return err
}
srv, err := NewBasicAPIServer(config, zipperListenAddr, provider, credential, logger)
srv, err := NewBasicAPIServer(config, zipperListenAddr, credential, provider, logger)
if err != nil {
return err
}

logger.Info("start AI Bridge service", "addr", config.Server.Addr, "provider", provider.Name())
return srv.ServeAddr(config.Server.Addr)
return http.ListenAndServe(config.Server.Addr, srv.httpHandler)
}

func BridgeHTTPHanlder(provider provider.LLMProvider, decorater func(http.Handler) http.Handler) http.Handler {
// NewServeMux creates a new http.ServeMux for the llm bridge server.
func NewServeMux(service *Service) *http.ServeMux {
var (
h = &Handler{service}
mux = http.NewServeMux()
h = NewHandler(provider)
)
// GET /overview
mux.HandleFunc("/overview", h.HandleOverview)
Expand All @@ -65,57 +65,59 @@ func BridgeHTTPHanlder(provider provider.LLMProvider, decorater func(http.Handle
// POST /v1/chat/completions (OpenAI compatible interface)
mux.HandleFunc("/v1/chat/completions", h.HandleChatCompletions)

return decorater(mux)
return mux
}

// DecorateHandler decorates the http.Handler.
func DecorateHandler(h http.Handler, decorates ...func(handler http.Handler) http.Handler) http.Handler {
// decorate the http.Handler
for i := len(decorates) - 1; i >= 0; i-- {
h = decorates[i](h)
}
return h
}

// NewBasicAPIServer creates a new restful service
func NewBasicAPIServer(config *Config, zipperAddr string, provider provider.LLMProvider, credential string, logger *slog.Logger) (*BasicAPIServer, error) {
func NewBasicAPIServer(config *Config, zipperAddr, credential string, provider provider.LLMProvider, logger *slog.Logger) (*BasicAPIServer, error) {
zipperAddr = parseZipperAddr(zipperAddr)

cp := NewCallerProvider(zipperAddr, DefaultExchangeMetadataFunc)
logger = logger.With("component", "bridge")

service := NewService(zipperAddr, provider, &ServiceOptions{
Logger: logger,
Tracer: otel.Tracer("yomo-llm-bridge"),
CredentialFunc: func(r *http.Request) (string, error) { return credential, nil },
})

mux := NewServeMux(service)

server := &BasicAPIServer{
zipperAddr: zipperAddr,
credential: credential,
httpHandler: BridgeHTTPHanlder(provider, decorateReqContext(cp, logger, credential)),
logger: logger.With("component", "bridge"),
httpHandler: DecorateHandler(mux, decorateReqContext(service, logger)),
}

return server, nil
}

// ServeAddr starts a http server that provides some endpoints to bridge up the http server and YoMo.
// User can chat to the http server and interact with the YoMo's stream function.
func (a *BasicAPIServer) ServeAddr(addr string) error {
return http.ListenAndServe(addr, a.httpHandler)
}

// decorateReqContext decorates the context of the request, it injects a transID and a caller into the context.
func decorateReqContext(cp CallerProvider, logger *slog.Logger, credential string) func(handler http.Handler) http.Handler {
tracer := otel.Tracer("yomo-llm-bridge")

caller, err := cp.Provide(credential)
if err != nil {
logger.Info("can't load caller", "err", err)

return func(handler http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
RespondWithError(w, http.StatusInternalServerError, err)
})
}
}

caller.SetTracer(tracer)

// decorateReqContext decorates the context of the request, it injects a transID into the request's context,
// log the request information and start tracing the request.
func decorateReqContext(service *Service, logger *slog.Logger) func(handler http.Handler) http.Handler {
host, _ := os.Hostname()

return func(handler http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()

caller, err := service.LoadOrCreateCaller(r)
if err != nil {
RespondWithError(w, http.StatusBadRequest, err)
return
}
ctx = WithCallerContext(ctx, caller)

// trace every request
ctx, span := tracer.Start(
ctx, span := service.option.Tracer.Start(
ctx,
r.URL.Path,
trace.WithSpanKind(trace.SpanKindServer),
Expand All @@ -125,7 +127,6 @@ func decorateReqContext(cp CallerProvider, logger *slog.Logger, credential strin

transID := id.New(32)
ctx = WithTransIDContext(ctx, transID)
ctx = WithCallerContext(ctx, caller)

logger.Info("request", "method", r.Method, "path", r.URL.Path, "transID", transID)

Expand All @@ -136,24 +137,16 @@ func decorateReqContext(cp CallerProvider, logger *slog.Logger, credential strin

// Handler handles the http request.
type Handler struct {
provider provider.LLMProvider
}

// NewHandler returns a new Handler.
func NewHandler(provider provider.LLMProvider) *Handler {
return &Handler{provider}
service *Service
}

// HandleOverview is the handler for GET /overview
func (h *Handler) HandleOverview(w http.ResponseWriter, r *http.Request) {
caller := FromCallerContext(r.Context())

w.Header().Set("Content-Type", "application/json")

tcs, err := register.ListToolCalls(caller.Metadata())
tcs, err := register.ListToolCalls(FromCallerContext(r.Context()).Metadata())
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
json.NewEncoder(w).Encode(map[string]string{"error": err.Error()})
RespondWithError(w, http.StatusInternalServerError, err)
return
}

Expand All @@ -172,7 +165,6 @@ var baseSystemMessage = `You are a very helpful assistant. Your job is to choose
func (h *Handler) HandleInvoke(w http.ResponseWriter, r *http.Request) {
var (
ctx = r.Context()
caller = FromCallerContext(ctx)
transID = FromTransIDContext(ctx)
)
defer r.Body.Close()
Expand All @@ -185,14 +177,14 @@ func (h *Handler) HandleInvoke(w http.ResponseWriter, r *http.Request) {
ctx, cancel := context.WithTimeout(r.Context(), RequestTimeout)
defer cancel()

res, err := GetInvoke(ctx, req.Prompt, baseSystemMessage, transID, req.IncludeCallStack, caller, h.provider)
w.Header().Set("Content-Type", "application/json")

res, err := h.service.GetInvoke(ctx, req.Prompt, baseSystemMessage, transID, FromCallerContext(ctx), req.IncludeCallStack)
if err != nil {
w.Header().Set("Content-Type", "application/json")
RespondWithError(w, http.StatusInternalServerError, err)
return
}

w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_ = json.NewEncoder(w).Encode(res)
}
Expand All @@ -201,7 +193,6 @@ func (h *Handler) HandleInvoke(w http.ResponseWriter, r *http.Request) {
func (h *Handler) HandleChatCompletions(w http.ResponseWriter, r *http.Request) {
var (
ctx = r.Context()
caller = FromCallerContext(ctx)
transID = FromTransIDContext(ctx)
)
defer r.Body.Close()
Expand All @@ -214,7 +205,7 @@ func (h *Handler) HandleChatCompletions(w http.ResponseWriter, r *http.Request)
ctx, cancel := context.WithTimeout(r.Context(), RequestTimeout)
defer cancel()

if err := GetChatCompletions(ctx, req, transID, h.provider, caller, w); err != nil {
if err := h.service.GetChatCompletions(ctx, req, transID, FromCallerContext(ctx), w); err != nil {
RespondWithError(w, http.StatusBadRequest, err)
return
}
Expand Down Expand Up @@ -258,17 +249,17 @@ func getLocalIP() (string, error) {
type callerContextKey struct{}

// WithCallerContext adds the caller to the request context
func WithCallerContext(ctx context.Context, caller Caller) context.Context {
func WithCallerContext(ctx context.Context, caller *Caller) context.Context {
return context.WithValue(ctx, callerContextKey{}, caller)
}

// FromCallerContext returns the caller from the request context
func FromCallerContext(ctx context.Context) Caller {
service, ok := ctx.Value(callerContextKey{}).(Caller)
func FromCallerContext(ctx context.Context) *Caller {
caller, ok := ctx.Value(callerContextKey{}).(*Caller)
if !ok {
return nil
}
return service
return caller
}

type transIDContextKey struct{}
Expand Down
18 changes: 14 additions & 4 deletions pkg/bridge/ai/api_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@ import (
"bytes"
"fmt"
"io"
"log/slog"
"net/http"
"net/http/httptest"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/yomorun/yomo"
"github.com/yomorun/yomo/ai"
"github.com/yomorun/yomo/core/metadata"
"github.com/yomorun/yomo/pkg/bridge/ai/provider"
"github.com/yomorun/yomo/pkg/bridge/ai/register"
)
Expand Down Expand Up @@ -38,11 +40,19 @@ func TestServer(t *testing.T) {
t.Fatal(err)
}

cp := newMockCallerProvider()
flow := newMockDataFlow(newHandler(2 * time.Hour).handle)

cp.provideFunc = mockCallerProvideFunc(map[uint32][]mockFunctionCall{})
newCaller := func(_ yomo.Source, _ yomo.StreamFunction, _ metadata.M, _ time.Duration) (*Caller, error) {
return mockCaller(nil), err
}

service := newService("fake_zipper_addr", pd, newCaller, &ServiceOptions{
SourceBuilder: func(_, _ string) yomo.Source { return flow },
ReducerBuilder: func(_, _ string) yomo.StreamFunction { return flow },
MetadataExchanger: func(_ string) (metadata.M, error) { return metadata.M{"hello": "llm bridge"}, nil },
})

handler := BridgeHTTPHanlder(pd, decorateReqContext(cp, slog.Default(), ""))
handler := DecorateHandler(NewServeMux(service), decorateReqContext(service, service.logger))

// create a test server
server := httptest.NewServer(handler)
Expand Down
38 changes: 0 additions & 38 deletions pkg/bridge/ai/call_syncer.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,7 @@ import (
"time"

openai "github.com/sashabaranov/go-openai"
"github.com/yomorun/yomo"
"github.com/yomorun/yomo/ai"
"github.com/yomorun/yomo/serverless"
)

// CallSyncer fires a bunch of function callings, and wait the result of these function callings.
Expand Down Expand Up @@ -223,39 +221,3 @@ func (f *callSyncer) background() {
}
}
}

// ToReducer converts a stream function to a reducer that can reduce the function calling result.
func ToReducer(sfn yomo.StreamFunction, logger *slog.Logger, ch chan ReduceMessage) {
// set observe data tags
sfn.SetObserveDataTags(ai.ReducerTag)
// set reduce handler
sfn.SetHandler(func(ctx serverless.Context) {
invoke, err := ctx.LLMFunctionCall()
if err != nil {
ch <- ReduceMessage{ReqID: ""}
logger.Error("parse function calling invoke", "err", err.Error())
return
}
logger.Debug("sfn-reducer", "req_id", invoke.ReqID, "tool_call_id", invoke.ToolCallID, "result", string(invoke.Result))

message := openai.ChatCompletionMessage{
Role: openai.ChatMessageRoleTool,
Content: invoke.Result,
ToolCallID: invoke.ToolCallID,
}

ch <- ReduceMessage{ReqID: invoke.ReqID, Message: message}
})
}

// ToSource convert a yomo source to the source that can send function calling body to the llm function.
func ToSource(source yomo.Source, logger *slog.Logger, ch chan TagFunctionCall) {
go func() {
for c := range ch {
buf, _ := c.FunctionCall.Bytes()
if err := source.Write(c.Tag, buf); err != nil {
logger.Error("send data to zipper", "err", err.Error())
}
}
}()
}
22 changes: 8 additions & 14 deletions pkg/bridge/ai/call_syncer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,10 @@ func TestTimeoutCallSyncer(t *testing.T) {
flow := newMockDataFlow(h.handle)
defer flow.Close()

reqs := make(chan TagFunctionCall)
ToSource(flow, slog.Default(), reqs)
req, _ := sourceWriteToChan(flow, slog.Default())
res, _ := reduceToChan(flow, slog.Default())

messages := make(chan ReduceMessage)
ToReducer(flow, slog.Default(), messages)

syncer := NewCallSyncer(slog.Default(), reqs, messages, time.Millisecond)
syncer := NewCallSyncer(slog.Default(), req, res, time.Millisecond)
go flow.run()

var (
Expand Down Expand Up @@ -61,13 +58,10 @@ func TestCallSyncer(t *testing.T) {
flow := newMockDataFlow(h.handle)
defer flow.Close()

reqs := make(chan TagFunctionCall)
ToSource(flow, slog.Default(), reqs)

messages := make(chan ReduceMessage)
ToReducer(flow, slog.Default(), messages)
req, _ := sourceWriteToChan(flow, slog.Default())
res, _ := reduceToChan(flow, slog.Default())

syncer := NewCallSyncer(slog.Default(), reqs, messages, 0)
syncer := NewCallSyncer(slog.Default(), req, res, 0)
go flow.run()

var (
Expand Down Expand Up @@ -118,7 +112,7 @@ func (h *handler) result() []openai.ChatCompletionMessage {
return want
}

// mockDataFlow mocks the data flow of ai bridge.
// mockDataFlow mocks the data flow of llm bridge.
// The data flow is: source -> hander -> reducer,
// It is `Write() -> handler() -> reducer()` in this mock implementation.
type mockDataFlow struct {
Expand Down Expand Up @@ -160,11 +154,11 @@ var _ yomo.StreamFunction = (*mockDataFlow)(nil)

// The test will not use blowing function in this mock implementation.
func (t *mockDataFlow) SetObserveDataTags(tag ...uint32) {}
func (t *mockDataFlow) Connect() error { return nil }
func (t *mockDataFlow) Init(fn func() error) error { panic("unimplemented") }
func (t *mockDataFlow) SetCronHandler(spec string, fn core.CronHandler) error { panic("unimplemented") }
func (t *mockDataFlow) SetPipeHandler(fn core.PipeHandler) error { panic("unimplemented") }
func (t *mockDataFlow) SetWantedTarget(string) { panic("unimplemented") }
func (t *mockDataFlow) Wait() { panic("unimplemented") }
func (t *mockDataFlow) Connect() error { panic("unimplemented") }
func (t *mockDataFlow) SetErrorHandler(fn func(err error)) { panic("unimplemented") }
func (t *mockDataFlow) WriteWithTarget(_ uint32, _ []byte, _ string) error { panic("unimplemented") }
Loading

0 comments on commit c68e252

Please sign in to comment.