feat: support option for continuous monitoring token usage in streaming response #111

Open · wants to merge 1 commit into base: main
7 changes: 7 additions & 0 deletions filterconfig/filterconfig.go
@@ -71,6 +71,13 @@ type Config struct {
// LLMRequestCost configures the cost of each LLM-related request. Optional. If this is provided, the filter will populate
// the "calculated" cost in the filter metadata at the end of the response body processing.
LLMRequestCosts []LLMRequestCost `json:"llmRequestCosts,omitempty"`
// MonitorContinuousUsageStats controls whether the external processor monitors every response-body chunk for usage stats.
// When true, it checks every response-body chunk received during a streaming request for token usage metadata,
// which is compatible with vLLM's 'continuous_usage_stats' option.
// When false, it stops monitoring once token usage metadata has been found for the first time,
// which is compatible with OpenAI's streaming responses (https://platform.openai.com/docs/api-reference/chat/streaming#chat/streaming-usage).
// This only affects requests in streaming mode.
MonitorContinuousUsageStats bool `json:"monitorContinuousUsageStats,omitempty"`
Comment on lines +74 to +80
@mathetake (Member) commented on Jan 16, 2025

Could you remove the change related to this? I think this is a separate issue, and the metadata is not cumulative, so it basically overrides the previous values if it's emitted in the middle.

A contributor commented:

I think this should be a property on the AIServiceBackend, as only certain backends support this, e.g. the vLLM service backend.

// InputSchema specifies the API schema of the input format of requests to the filter.
Schema VersionedAPISchema `json:"schema"`
// ModelNameHeaderKey is the header key to be populated with the model name by the filter.
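As a rough illustration of how this flag would be consumed, here is a minimal Go sketch of building a filter config with per-chunk usage monitoring enabled; the import path and surrounding field values are assumptions based on the repository layout and tests, not part of this PR:

```go
package main

import (
	"fmt"

	"github.com/envoyproxy/ai-gateway/filterconfig" // assumed import path
)

func main() {
	// Assumed example: enable per-chunk usage monitoring (vLLM-style
	// continuous_usage_stats). Leaving the flag false keeps the default
	// behavior of stopping after the first usage payload, which matches
	// OpenAI's streaming responses.
	cfg := &filterconfig.Config{
		Schema:                      filterconfig.VersionedAPISchema{Name: filterconfig.APISchemaOpenAI},
		ModelNameHeaderKey:          "x-ai-eg-model",
		MonitorContinuousUsageStats: true,
	}
	fmt.Println(cfg.MonitorContinuousUsageStats)
}
```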
1 change: 1 addition & 0 deletions filterconfig/filterconfig_test.go
@@ -74,6 +74,7 @@ rules:
require.Equal(t, "OpenAI", string(cfg.Schema.Name))
require.Equal(t, "x-ai-eg-selected-backend", cfg.SelectedBackendHeaderKey)
require.Equal(t, "x-ai-eg-model", cfg.ModelNameHeaderKey)

require.Len(t, cfg.Rules, 2)
require.Equal(t, "llama3.3333", cfg.Rules[0].Headers[0].Value)
require.Equal(t, "gpt4.4444", cfg.Rules[1].Headers[0].Value)
1 change: 1 addition & 0 deletions internal/extproc/processor.go
@@ -246,6 +246,7 @@ func (p *Processor) maybeBuildDynamicMetadata() (*structpb.Struct, error) {
if len(metadata) == 0 {
return nil, nil
}

return &structpb.Struct{
Fields: map[string]*structpb.Value{
p.config.metadataNamespace: {
2 changes: 1 addition & 1 deletion internal/extproc/server.go
@@ -50,7 +50,7 @@ func (s *Server[P]) LoadConfig(config *filterconfig.Config) error {
for _, r := range config.Rules {
for _, b := range r.Backends {
if _, ok := factories[b.Schema]; !ok {
factories[b.Schema], err = translator.NewFactory(config.Schema, b.Schema)
factories[b.Schema], err = translator.NewFactory(config.Schema, b.Schema, config.MonitorContinuousUsageStats)
if err != nil {
return fmt.Errorf("cannot create translator factory: %w", err)
}
1 change: 0 additions & 1 deletion internal/extproc/translator/openai_awsbedrock.go
@@ -555,7 +555,6 @@ func (o *openAIToAWSBedrockTranslatorV1ChatCompletion) ResponseBody(respHeaders
if err := json.NewDecoder(body).Decode(&bedrockResp); err != nil {
return nil, nil, tokenUsage, fmt.Errorf("failed to unmarshal body: %w", err)
}

openAIResp := openai.ChatCompletionResponse{
Object: "chat.completion",
Choices: make([]openai.ChatCompletionResponseChoice, 0, len(bedrockResp.Output.Message.Content)),
29 changes: 18 additions & 11 deletions internal/extproc/translator/openai_openai.go
@@ -14,21 +14,24 @@ import (
"github.com/envoyproxy/ai-gateway/internal/extproc/router"
)

// newOpenAIToOpenAITranslator implements [TranslatorFactory] for OpenAI to OpenAI translation.
func newOpenAIToOpenAITranslator(path string) (Translator, error) {
if path == "/v1/chat/completions" {
return &openAIToOpenAITranslatorV1ChatCompletion{}, nil
} else {
return nil, fmt.Errorf("unsupported path: %s", path)
// newOpenAIToOpenAITranslatorFactory returns a [Factory] for OpenAI to OpenAI translation.
func newOpenAIToOpenAITranslatorFactory(monitorContinuousUsageStats bool) Factory {
return func(path string) (Translator, error) {
if path == "/v1/chat/completions" {
return &openAIToOpenAITranslatorV1ChatCompletion{monitorContinuousUsageStats: monitorContinuousUsageStats}, nil
} else {
return nil, fmt.Errorf("unsupported path: %s", path)
}
}
}

// openAIToOpenAITranslatorV1ChatCompletion implements [Translator] for /v1/chat/completions.
type openAIToOpenAITranslatorV1ChatCompletion struct {
defaultTranslator
stream bool
buffered []byte
bufferingDone bool
stream bool
buffered []byte
bufferingDone bool
monitorContinuousUsageStats bool
}

// RequestBody implements [RequestBody].
@@ -96,8 +99,12 @@ func (o *openAIToOpenAITranslatorV1ChatCompletion) ResponseBody(respHeaders map[
}
}
if o.stream {
if !o.bufferingDone {
buf, err := io.ReadAll(body)
if !o.bufferingDone || o.monitorContinuousUsageStats {
// OpenAI's API sends usage info only in the final chunk (https://platform.openai.com/docs/api-reference/chat/streaming#chat/streaming-usage),
// whereas the vLLM model server supports including usage info in every returned chunk.
// To accommodate both behaviors, we check each chunk for usage info.
var buf []byte
buf, err = io.ReadAll(body)
if err != nil {
return nil, nil, tokenUsage, fmt.Errorf("failed to read body: %w", err)
}
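To make the chunk-scanning behavior concrete, here is a self-contained Go sketch of pulling token usage out of the SSE `data:` lines in a single chunk; the `usage` struct and `scanChunk` helper are illustrative, not the PR's actual implementation. With `monitorContinuousUsageStats` enabled the translator keeps performing this kind of check on every chunk rather than stopping after the first match:

```go
package main

import (
	"bufio"
	"bytes"
	"encoding/json"
	"fmt"
	"strings"
)

// usage mirrors the OpenAI/vLLM usage object carried in streaming chunks.
type usage struct {
	PromptTokens     uint32 `json:"prompt_tokens"`
	CompletionTokens uint32 `json:"completion_tokens"`
	TotalTokens      uint32 `json:"total_tokens"`
}

// scanChunk scans one response-body chunk of SSE events and returns the last
// usage object found, if any. Illustrative only.
func scanChunk(chunk []byte) (u usage, found bool) {
	sc := bufio.NewScanner(bytes.NewReader(chunk))
	for sc.Scan() {
		line := strings.TrimPrefix(sc.Text(), "data: ")
		if line == "" || line == "[DONE]" {
			continue
		}
		var event struct {
			Usage *usage `json:"usage"`
		}
		if err := json.Unmarshal([]byte(line), &event); err != nil || event.Usage == nil {
			continue // not a usage-bearing event; keep scanning
		}
		u, found = *event.Usage, true
	}
	return u, found
}

func main() {
	chunk := []byte("data: {\"usage\": {\"prompt_tokens\":11,\"completion_tokens\":22,\"total_tokens\":33}}\n")
	fmt.Println(scanChunk(chunk)) // {11 22 33} true
}
```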
30 changes: 20 additions & 10 deletions internal/extproc/translator/openai_openai_test.go
@@ -20,11 +20,11 @@ import (

func TestNewOpenAIToOpenAITranslator(t *testing.T) {
t.Run("unsupported path", func(t *testing.T) {
_, err := newOpenAIToOpenAITranslator("/v1/foo/bar")
_, err := newOpenAIToOpenAITranslatorFactory(false)("/v1/foo/bar")
require.Error(t, err)
})
t.Run("/v1/chat/completions", func(t *testing.T) {
translator, err := newOpenAIToOpenAITranslator("/v1/chat/completions")
translator, err := newOpenAIToOpenAITranslatorFactory(false)("/v1/chat/completions")
require.NoError(t, err)
require.NotNil(t, translator)
})
@@ -184,6 +184,12 @@ data: [DONE]
if tokenUsage.OutputTokens > 0 {
require.Equal(t, uint32(12), tokenUsage.OutputTokens)
}
if tokenUsage.InputTokens > 0 {
require.Equal(t, uint32(13), tokenUsage.InputTokens)
}
if tokenUsage.TotalTokens > 0 {
require.Equal(t, uint32(25), tokenUsage.TotalTokens)
}
}
})
t.Run("non-streaming", func(t *testing.T) {
@@ -194,32 +200,36 @@
})
t.Run("valid body", func(t *testing.T) {
var resp openai.ChatCompletionResponse
resp.Usage.TotalTokens = 42
resp.Usage = openai.ChatCompletionResponseUsage{
PromptTokens: 11,
CompletionTokens: 22,
TotalTokens: 33,
}
body, err := json.Marshal(resp)
require.NoError(t, err)
o := &openAIToOpenAITranslatorV1ChatCompletion{}
_, _, usedToken, err := o.ResponseBody(nil, bytes.NewBuffer(body), false)
require.NoError(t, err)
require.Equal(t, LLMTokenUsage{TotalTokens: 42}, usedToken)
require.Equal(t, LLMTokenUsage{InputTokens: 11, OutputTokens: 22, TotalTokens: 33}, usedToken)
})
})
}

func TestExtractUsageFromBufferEvent(t *testing.T) {
t.Run("valid usage data", func(t *testing.T) {
o := &openAIToOpenAITranslatorV1ChatCompletion{}
o.buffered = []byte("data: {\"usage\": {\"total_tokens\": 42}}\n")
o.buffered = []byte("data: {\"usage\": {\"completion_tokens\":22,\"prompt_tokens\":11,\"total_tokens\": 33}}\n")
usedToken := o.extractUsageFromBufferEvent()
require.Equal(t, LLMTokenUsage{TotalTokens: 42}, usedToken)
require.Equal(t, LLMTokenUsage{TotalTokens: 33, InputTokens: 11, OutputTokens: 22}, usedToken)
require.True(t, o.bufferingDone)
require.Nil(t, o.buffered)
})

t.Run("valid usage data after invalid", func(t *testing.T) {
o := &openAIToOpenAITranslatorV1ChatCompletion{}
o.buffered = []byte("data: invalid\ndata: {\"usage\": {\"total_tokens\": 42}}\n")
o.buffered = []byte("data: invalid\ndata: {\"usage\": {\"completion_tokens\":22,\"prompt_tokens\":11,\"total_tokens\": 33}}\n")
usedToken := o.extractUsageFromBufferEvent()
require.Equal(t, LLMTokenUsage{TotalTokens: 42}, usedToken)
require.Equal(t, LLMTokenUsage{TotalTokens: 33, InputTokens: 11, OutputTokens: 22}, usedToken)
require.True(t, o.bufferingDone)
require.Nil(t, o.buffered)
})
Expand All @@ -232,9 +242,9 @@ func TestExtractUsageFromBufferEvent(t *testing.T) {
require.False(t, o.bufferingDone)
require.NotNil(t, o.buffered)

o.buffered = append(o.buffered, []byte("{\"usage\": {\"total_tokens\": 42}}\n")...)
o.buffered = append(o.buffered, []byte("{\"usage\": {\"completion_tokens\":22,\"prompt_tokens\":11,\"total_tokens\": 33}}\n")...)
usedToken = o.extractUsageFromBufferEvent()
require.Equal(t, LLMTokenUsage{TotalTokens: 42}, usedToken)
require.Equal(t, LLMTokenUsage{TotalTokens: 33, InputTokens: 11, OutputTokens: 22}, usedToken)
require.True(t, o.bufferingDone)
require.Nil(t, o.buffered)
})
4 changes: 2 additions & 2 deletions internal/extproc/translator/translator.go
@@ -35,12 +35,12 @@ func isGoodStatusCode(code int) bool {
type Factory func(path string) (Translator, error)

// NewFactory returns a callback function that creates a translator for the given API schema combination.
func NewFactory(in, out filterconfig.VersionedAPISchema) (Factory, error) {
func NewFactory(in, out filterconfig.VersionedAPISchema, monitorContinuousUsageStats bool) (Factory, error) {
if in.Name == filterconfig.APISchemaOpenAI {
// TODO: currently, we ignore the LLMAPISchema."Version" field.
switch out.Name {
case filterconfig.APISchemaOpenAI:
return newOpenAIToOpenAITranslator, nil
return newOpenAIToOpenAITranslatorFactory(monitorContinuousUsageStats), nil
case filterconfig.APISchemaAWSBedrock:
return newOpenAIToAWSBedrockTranslator, nil
}
3 changes: 3 additions & 0 deletions internal/extproc/translator/translator_test.go
@@ -13,13 +13,15 @@ func TestNewFactory(t *testing.T) {
_, err := NewFactory(
filterconfig.VersionedAPISchema{Name: "Foo", Version: "v100"},
filterconfig.VersionedAPISchema{Name: "Bar", Version: "v123"},
false,
)
require.ErrorContains(t, err, "unsupported API schema combination: client={Foo v100}, backend={Bar v123}")
})
t.Run("openai to openai", func(t *testing.T) {
f, err := NewFactory(
filterconfig.VersionedAPISchema{Name: filterconfig.APISchemaOpenAI},
filterconfig.VersionedAPISchema{Name: filterconfig.APISchemaOpenAI},
false,
)
require.NoError(t, err)
require.NotNil(t, f)
@@ -34,6 +36,7 @@ func TestNewFactory(t *testing.T) {
f, err := NewFactory(
filterconfig.VersionedAPISchema{Name: filterconfig.APISchemaOpenAI},
filterconfig.VersionedAPISchema{Name: filterconfig.APISchemaAWSBedrock},
false,
)
require.NoError(t, err)
require.NotNil(t, f)