refactor(llm-bridge): improve log and trace (#938)
1. Enhance llm bridge logging by adding `service`, `namespace`, `traceId`,
`requestId`, and `duration` fields (see the sample log entry sketched below).
2. The bridge's HTTP handler supports tracer injection.
3. YoMo exports the global trace client.
4. Add a new `ResponseWriter` struct that supports recording errors and
TTFT (time to first token) timing.
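
For illustration, a minimal sketch of a request log entry carrying the new fields, emitted through log/slog; the handler setup and the concrete values are assumptions, only the field names and the "service" / "llm bridge request" wording come from this change:

package main

import (
	"log/slog"
	"os"
	"time"
)

func main() {
	// The bridge attaches "service" once; the per-request fields are added on each log call.
	logger := slog.New(slog.NewJSONHandler(os.Stdout, nil)).With("service", "llm-bridge")

	logger.Info("llm bridge request",
		"namespace", "POST /v1/chat/completions", // method and path of the handled route
		"stream", true,
		"host", "example-host",
		"requestId", "hypothetical-trans-id",
		"traceId", "4bf92f3577b34da6a3ce929d0e0e4736",
		"duration", 420*time.Millisecond, // TTFT-based when the response streams
	)
}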
woorui authored Nov 23, 2024
1 parent 2221a74 commit 40a85f6
Showing 11 changed files with 288 additions and 108 deletions.
2 changes: 1 addition & 1 deletion core/server.go
@@ -68,7 +68,7 @@ func NewServer(name string, opts ...ServerOption) *Server {
o(options)
}

logger := options.logger.With("component", "zipper", "zipper_name", name)
logger := options.logger.With("service", "zipper", "zipper_name", name)

ctx, ctxCancel := context.WithCancel(context.Background())

2 changes: 1 addition & 1 deletion pkg/bridge/ai/ai.go
@@ -50,7 +50,7 @@ func registerFunction(r register.Register) core.ConnMiddleware {
fd := ai.FunctionDefinition{}
err := json.Unmarshal([]byte(definition), &fd)
if err != nil {
conn.Logger.Error("unmarshal function definition", "error", err)
conn.Logger.Error("unmarshal function definition", "err", err)
return
}
err = r.RegisterFunction(tag, &fd, conn.ID(), connMd)
98 changes: 83 additions & 15 deletions pkg/bridge/ai/api_server.go
@@ -20,6 +20,7 @@ import (
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"
"go.opentelemetry.io/otel/trace/noop"
)

const (
@@ -82,11 +83,10 @@ func DecorateHandler(h http.Handler, decorates ...func(handler http.Handler) htt
func NewBasicAPIServer(config *Config, zipperAddr, credential string, provider provider.LLMProvider, logger *slog.Logger) (*BasicAPIServer, error) {
zipperAddr = parseZipperAddr(zipperAddr)

logger = logger.With("component", "bridge")
logger = logger.With("service", "llm-bridge")

service := NewService(zipperAddr, provider, &ServiceOptions{
Logger: logger,
Tracer: otel.Tracer("yomo-llm-bridge"),
CredentialFunc: func(r *http.Request) (string, error) { return credential, nil },
})

@@ -104,11 +104,15 @@ func NewBasicAPIServer(config *Config, zipperAddr, credential string, provider p
// decorateReqContext decorates the request context: it injects a transID and a tracer into the context,
// logs the request information, and starts tracing the request.
func decorateReqContext(service *Service, logger *slog.Logger) func(handler http.Handler) http.Handler {
host, _ := os.Hostname()
hostname, _ := os.Hostname()
tracer := otel.Tracer("yomo-llm-bridge")

return func(handler http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
ctx = WithTracerContext(ctx, tracer)

start := time.Now()

caller, err := service.LoadOrCreateCaller(r)
if err != nil {
@@ -118,20 +122,41 @@ func decorateReqContext(service *Service, logger *slog.Logger) func(handler http
ctx = WithCallerContext(ctx, caller)

// trace every request
ctx, span := service.option.Tracer.Start(
ctx, span := tracer.Start(
ctx,
r.URL.Path,
trace.WithSpanKind(trace.SpanKindServer),
trace.WithAttributes(attribute.String("host", host)),
trace.WithAttributes(attribute.String("host", hostname)),
)
defer span.End()

transID := id.New(32)
ctx = WithTransIDContext(ctx, transID)

logger.Info("request", "method", r.Method, "path", r.URL.Path, "transID", transID)
ww := NewResponseWriter(w)

handler.ServeHTTP(ww, r.WithContext(ctx))

handler.ServeHTTP(w, r.WithContext(ctx))
duration := time.Since(start)
if !ww.TTFT.IsZero() {
duration = ww.TTFT.Sub(start)
}

logContent := []any{
"namespace", fmt.Sprintf("%s %s", r.Method, r.URL.Path),
"stream", ww.IsStream,
"host", hostname,
"requestId", transID,
"duration", duration,
}
if traceID := span.SpanContext().TraceID(); traceID.IsValid() {
logContent = append(logContent, "traceId", traceID.String())
}
if ww.Err != nil {
logger.Error("llm birdge request", append(logContent, "err", ww.Err)...)
} else {
logger.Info("llm birdge request", logContent...)
}
})
}
}
@@ -156,7 +181,6 @@ func (h *Handler) HandleOverview(w http.ResponseWriter, r *http.Request) {
functions[tag] = tc.Function
}

w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(&ai.OverviewResponse{Functions: functions})
}

@@ -167,26 +191,33 @@ func (h *Handler) HandleInvoke(w http.ResponseWriter, r *http.Request) {
var (
ctx = r.Context()
transID = FromTransIDContext(ctx)
ww = w.(*ResponseWriter)
)
defer r.Body.Close()

req, err := DecodeRequest[ai.InvokeRequest](r, w, h.service.logger)
if err != nil {
ww.Err = errors.New("bad request")
return
}

ctx, cancel := context.WithTimeout(r.Context(), RequestTimeout)
defer cancel()

var (
caller = FromCallerContext(ctx)
tracer = FromTracerContext(ctx)
)

w.Header().Set("Content-Type", "application/json")

res, err := h.service.GetInvoke(ctx, req.Prompt, baseSystemMessage, transID, FromCallerContext(ctx), req.IncludeCallStack)
res, err := h.service.GetInvoke(ctx, req.Prompt, baseSystemMessage, transID, caller, req.IncludeCallStack, tracer)
if err != nil {
ww.Err = err
RespondWithError(w, http.StatusInternalServerError, err, h.service.logger)
return
}

w.WriteHeader(http.StatusOK)
_ = json.NewEncoder(w).Encode(res)
}

@@ -195,18 +226,34 @@ func (h *Handler) HandleChatCompletions(w http.ResponseWriter, r *http.Request)
var (
ctx = r.Context()
transID = FromTransIDContext(ctx)
ww = w.(*ResponseWriter)
)
defer r.Body.Close()

req, err := DecodeRequest[openai.ChatCompletionRequest](r, w, h.service.logger)
if err != nil {
ww.Err = err
return
}

ctx, cancel := context.WithTimeout(r.Context(), RequestTimeout)
defer cancel()

if err := h.service.GetChatCompletions(ctx, req, transID, FromCallerContext(ctx), w); err != nil {
var (
caller = FromCallerContext(ctx)
tracer = FromTracerContext(ctx)
)

if err := h.service.GetChatCompletions(ctx, req, transID, caller, ww, tracer); err != nil {
ww.Err = err
if err == context.Canceled {
return
}
if ww.IsStream {
h.service.logger.Error("bridge server error", "err", err.Error(), "err_type", reflect.TypeOf(err).String())
w.Write([]byte(fmt.Sprintf(`{"error":{"message":"%s"}}`, err.Error())))
return
}
RespondWithError(w, http.StatusBadRequest, err, h.service.logger)
return
}
@@ -227,6 +274,14 @@ func DecodeRequest[T any](r *http.Request, w http.ResponseWriter, logger *slog.L

// RespondWithError writes an error to response according to the OpenAI API spec.
func RespondWithError(w http.ResponseWriter, code int, err error, logger *slog.Logger) {
code, errString := parseCodeError(code, err)
logger.Error("bridge server error", "err", errString, "err_type", reflect.TypeOf(err).String())

w.WriteHeader(code)
w.Write([]byte(fmt.Sprintf(`{"error":{"code":"%d","message":"%s"}}`, code, errString)))
}

func parseCodeError(code int, err error) (int, string) {
errString := err.Error()

switch e := err.(type) {
@@ -238,10 +293,7 @@ func RespondWithError(w http.ResponseWriter, code int, err error, logger *slog.L
errString = e.Error()
}

logger.Error("bridge server error", "err", errString, "err_type", reflect.TypeOf(err).String())

w.WriteHeader(code)
w.Write([]byte(fmt.Sprintf(`{"error":{"code":"%d","message":"%s"}}`, code, errString)))
return code, errString
}

func getLocalIP() (string, error) {
@@ -291,3 +343,19 @@ func FromTransIDContext(ctx context.Context) string {
}
return val
}

type tracerContextKey struct{}

// WithTracerContext adds the tracer to the request context
func WithTracerContext(ctx context.Context, tracer trace.Tracer) context.Context {
return context.WithValue(ctx, tracerContextKey{}, tracer)
}

// FromTracerContext returns the tracer from the request context
func FromTracerContext(ctx context.Context) trace.Tracer {
val, ok := ctx.Value(tracerContextKey{}).(trace.Tracer)
if !ok {
return new(noop.Tracer)
}
return val
}
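
A minimal usage sketch of the injected tracer, assumed to live in the same package as the diff above; the handler name and span name are hypothetical, only WithTracerContext/FromTracerContext and the no-op fallback come from this commit:

// exampleProviderCall is a hypothetical downstream handler step: it reuses the
// tracer injected by decorateReqContext and opens a child span around the
// provider call. FromTracerContext falls back to a no-op tracer when nothing was injected.
func exampleProviderCall(r *http.Request) {
	ctx := r.Context()
	tracer := FromTracerContext(ctx)

	ctx, span := tracer.Start(ctx, "provider.chat_completions")
	defer span.End()

	// ... call the LLM provider with ctx so the work is attributed to the child span ...
	_ = ctx
}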
12 changes: 8 additions & 4 deletions pkg/bridge/ai/provider/xai/provider_test.go
@@ -3,6 +3,7 @@ package xai
import (
"context"
"testing"
"time"

"github.com/sashabaranov/go-openai"
"github.com/stretchr/testify/assert"
@@ -32,23 +33,26 @@ func TestXAIProvider_GetChatCompletions(t *testing.T) {
Model: "groq-beta",
}

_, err := provider.GetChatCompletions(context.TODO(), req, nil)
ctx, cancel := context.WithTimeout(context.TODO(), time.Second)
defer cancel()

_, err := provider.GetChatCompletions(ctx, req, nil)
assert.Error(t, err)
t.Log(err)

_, err = provider.GetChatCompletionsStream(context.TODO(), req, nil)
_, err = provider.GetChatCompletionsStream(ctx, req, nil)
assert.Error(t, err)
t.Log(err)

req = openai.ChatCompletionRequest{
Messages: msgs,
}

_, err = provider.GetChatCompletions(context.TODO(), req, nil)
_, err = provider.GetChatCompletions(ctx, req, nil)
assert.Error(t, err)
t.Log(err)

_, err = provider.GetChatCompletionsStream(context.TODO(), req, nil)
_, err = provider.GetChatCompletionsStream(ctx, req, nil)
assert.Error(t, err)
t.Log(err)
}
86 changes: 86 additions & 0 deletions pkg/bridge/ai/response_writer.go
@@ -0,0 +1,86 @@
package ai

import (
"encoding/json"
"io"
"net/http"
"time"
)

// ResponseWriter is a wrapper for http.ResponseWriter.
// It records whether the response is a stream, any error that occurred, and the time to first token (TTFT).
type ResponseWriter struct {
IsStream bool
Err error
TTFT time.Time
underlying http.ResponseWriter
}

// NewResponseWriter returns a new ResponseWriter.
func NewResponseWriter(w http.ResponseWriter) *ResponseWriter {
return &ResponseWriter{
underlying: w,
}
}

// Header returns the headers of the underlying ResponseWriter.
func (w *ResponseWriter) Header() http.Header {
return w.underlying.Header()
}

// Write writes the data to the underlying ResponseWriter.
func (w *ResponseWriter) Write(b []byte) (int, error) {
return w.underlying.Write(b)
}

// WriteHeader writes the header to the underlying ResponseWriter.
func (w *ResponseWriter) WriteHeader(code int) {
w.underlying.WriteHeader(code)
}

// WriteStreamEvent writes the event to the underlying ResponseWriter.
func (w *ResponseWriter) WriteStreamEvent(event any) error {
if _, err := io.WriteString(w, "data: "); err != nil {
return err
}
if err := json.NewEncoder(w).Encode(event); err != nil {
return err
}
if _, err := io.WriteString(w, "\n"); err != nil {
return err
}
flusher, ok := w.underlying.(http.Flusher)
if ok {
flusher.Flush()
}
return nil
}

// WriteStreamDone writes the done event to the underlying ResponseWriter.
func (w *ResponseWriter) WriteStreamDone() error {
_, err := io.WriteString(w, "data: [DONE]")

flusher, ok := w.underlying.(http.Flusher)
if ok {
flusher.Flush()
}

return err
}

// SetStreamHeader sets the stream headers of the underlying ResponseWriter.
func (w *ResponseWriter) SetStreamHeader() http.Header {
h := w.Header()
h.Set("Content-Type", "text/event-stream")
h.Set("Cache-Control", "no-cache, must-revalidate")
h.Set("x-content-type-options", "nosniff")
return h
}

// Flush flushes the underlying ResponseWriter.
func (w *ResponseWriter) Flush() {
flusher, ok := w.underlying.(http.Flusher)
if ok {
flusher.Flush()
}
}
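
A minimal streaming sketch showing how the writer's pieces fit together; where TTFT is assigned and the chunk type are assumptions (the test below uses go-openai types), only the ResponseWriter fields and methods come from this file:

// writeStream is a hypothetical streaming loop, not the service's actual code:
// it marks the response as a stream, records TTFT at the first chunk, forwards
// every chunk as an SSE event, and finishes with the [DONE] sentinel.
func writeStream(w *ResponseWriter, chunks []openai.ChatCompletionStreamResponse) error {
	w.IsStream = true
	w.SetStreamHeader()

	for i, chunk := range chunks {
		if i == 0 {
			w.TTFT = time.Now() // time to first token, later used as the logged duration
		}
		if err := w.WriteStreamEvent(chunk); err != nil {
			return err
		}
	}
	return w.WriteStreamDone()
}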
34 changes: 34 additions & 0 deletions pkg/bridge/ai/response_writer_test.go
@@ -0,0 +1,34 @@
package ai

import (
"net/http/httptest"
"testing"

"github.com/sashabaranov/go-openai"
"github.com/stretchr/testify/assert"
)

func TestResponseWriter(t *testing.T) {
recorder := httptest.NewRecorder()

w := NewResponseWriter(recorder)

h := w.SetStreamHeader()

err := w.WriteStreamEvent(openai.ChatCompletionResponse{
ID: "chatcmpl-123",
})
assert.NoError(t, err)

err = w.WriteStreamDone()
assert.NoError(t, err)

got := recorder.Body.String()

want := `data: {"id":"chatcmpl-123","object":"","created":0,"model":"","choices":null,"usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0,"prompt_tokens_details":null,"completion_tokens_details":null},"system_fingerprint":""}
data: [DONE]`

assert.Equal(t, want, got)
assert.Equal(t, recorder.Header(), h)
}