diff --git a/tests/extproc/testupstream_test.go b/tests/extproc/testupstream_test.go index fca83816..52fcdb70 100644 --- a/tests/extproc/testupstream_test.go +++ b/tests/extproc/testupstream_test.go @@ -4,6 +4,7 @@ package extproc import ( "encoding/base64" + "fmt" "io" "net/http" "os" @@ -11,6 +12,7 @@ import ( "testing" "time" + "github.com/google/go-cmp/cmp" "github.com/stretchr/testify/require" "github.com/envoyproxy/ai-gateway/filterconfig" @@ -62,12 +64,17 @@ func TestWithTestUpstream(t *testing.T) { requestBody, // responseBody is the response body to return from the test upstream. responseBody, + // responseType is either empty, "sse" or "aws-event-stream" as implemented by the test upstream. + responseType, // expPath is the expected path to be sent to the test upstream. expPath string + // expRequestBody is the expected body to be sent to the test upstream. + // This can be used to test the request body translation. + expRequestBody string // expStatus is the expected status code from the gateway. expStatus int - // expBody is the expected body from the gateway. - expBody string + // expResponseBody is the expected body from the gateway to the client. + expResponseBody string }{ { name: "unknown path", @@ -80,25 +87,78 @@ func TestWithTestUpstream(t *testing.T) { expStatus: http.StatusInternalServerError, }, { - name: "aws - /v1/chat/completions", - backend: "aws-bedrock", - path: "/v1/chat/completions", - requestBody: `{"model":"something","messages":[{"role":"system","content":"You are a chatbot."}]}`, - expPath: "/model/something/converse", - responseBody: `{"output":{"message":{"content":[{"text":"response"},{"text":"from"},{"text":"assistant"}],"role":"assistant"}},"stopReason":null,"usage":{"inputTokens":10,"outputTokens":20,"totalTokens":30}}`, - expStatus: http.StatusOK, - expBody: `{"choices":[{"finish_reason":"stop","index":0,"logprobs":{},"message":{"content":"response","role":"assistant"}},{"finish_reason":"stop","index":1,"logprobs":{},"message":{"content":"from","role":"assistant"}},{"finish_reason":"stop","index":2,"logprobs":{},"message":{"content":"assistant","role":"assistant"}}],"object":"chat.completion","usage":{"completion_tokens":20,"prompt_tokens":10,"total_tokens":30}}`, + name: "aws - /v1/chat/completions", + backend: "aws-bedrock", + path: "/v1/chat/completions", + requestBody: `{"model":"something","messages":[{"role":"system","content":"You are a chatbot."}]}`, + expPath: "/model/something/converse", + responseBody: `{"output":{"message":{"content":[{"text":"response"},{"text":"from"},{"text":"assistant"}],"role":"assistant"}},"stopReason":null,"usage":{"inputTokens":10,"outputTokens":20,"totalTokens":30}}`, + expRequestBody: `{"inferenceConfig":{},"messages":[],"modelId":null,"system":[{"text":"You are a chatbot."}]}`, + expStatus: http.StatusOK, + expResponseBody: `{"choices":[{"finish_reason":"stop","index":0,"logprobs":{},"message":{"content":"response","role":"assistant"}},{"finish_reason":"stop","index":1,"logprobs":{},"message":{"content":"from","role":"assistant"}},{"finish_reason":"stop","index":2,"logprobs":{},"message":{"content":"assistant","role":"assistant"}}],"object":"chat.completion","usage":{"completion_tokens":20,"prompt_tokens":10,"total_tokens":30}}`, + }, + { + name: "openai - /v1/chat/completions", + backend: "openai", + path: "/v1/chat/completions", + method: http.MethodPost, + requestBody: `{"model":"something","messages":[{"role":"system","content":"You are a chatbot."}]}`, + expPath: "/v1/chat/completions", + responseBody: `{"choices":[{"message":{"content":"This is a test."}}]}`, + expStatus: http.StatusOK, + expResponseBody: `{"choices":[{"message":{"content":"This is a test."}}]}`, }, { - name: "openai - /v1/chat/completions", + name: "aws - /v1/chat/completions - streaming", + backend: "aws-bedrock", + path: "/v1/chat/completions", + responseType: "aws-event-stream", + method: http.MethodPost, + requestBody: `{"model":"something","messages":[{"role":"system","content":"You are a chatbot."}], "stream": true}`, + expRequestBody: `{"inferenceConfig":{},"messages":[],"modelId":null,"system":[{"text":"You are a chatbot."}]}`, + expPath: "/model/something/converse-stream", + responseBody: `{"role":"assistant"} +{"delta":{"text":"Don"}} +{"delta":{"text":"'t worry, I'm here to help. It"}} +{"delta":{"text":" seems like you're testing my ability to respond appropriately"}} +{"stopReason":"end_turn"} +{"usage":{"inputTokens":41, "outputTokens":36, "totalTokens":77}} +`, + expStatus: http.StatusOK, + expResponseBody: `data: {"choices":[{"delta":{"content":"","role":"assistant"}}],"object":"chat.completion.chunk"} + +data: {"choices":[{"delta":{"content":"Don"}}],"object":"chat.completion.chunk"} + +data: {"choices":[{"delta":{"content":"'t worry, I'm here to help. It"}}],"object":"chat.completion.chunk"} + +data: {"choices":[{"delta":{"content":" seems like you're testing my ability to respond appropriately"}}],"object":"chat.completion.chunk"} + +data: {"object":"chat.completion.chunk","usage":{"completion_tokens":36,"prompt_tokens":41,"total_tokens":77}} + +data: [DONE] +`, + }, + { + name: "openai - /v1/chat/completions - streaming", backend: "openai", path: "/v1/chat/completions", + responseType: "sse", method: http.MethodPost, - requestBody: `{"model":"something","messages":[{"role":"system","content":"You are a chatbot."}]}`, + requestBody: `{"model":"something","messages":[{"role":"system","content":"You are a chatbot."}], "stream": true}`, expPath: "/v1/chat/completions", - responseBody: `{"choices":[{"message":{"content":"This is a test."}}]}`, - expStatus: http.StatusOK, - expBody: `{"choices":[{"message":{"content":"This is a test."}}]}`, + responseBody: ` +{"id":"chatcmpl-foo","object":"chat.completion.chunk","created":1731618222,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0ba0d124f1","choices":[{"index":0,"delta":{"role":"assistant","content":"","refusal":null},"logprobs":null,"finish_reason":null}],"usage":null} +{"id":"chatcmpl-foo","object":"chat.completion.chunk","created":1731618222,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0ba0d124f1","choices":[],"usage":{"prompt_tokens":13,"completion_tokens":12,"total_tokens":25,"prompt_tokens_details":{"cached_tokens":0,"audio_tokens":0},"completion_tokens_details":{"reasoning_tokens":0,"audio_tokens":0,"accepted_prediction_tokens":0,"rejected_prediction_tokens":0}}} +[DONE] +`, + expStatus: http.StatusOK, + expResponseBody: `data: {"id":"chatcmpl-foo","object":"chat.completion.chunk","created":1731618222,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0ba0d124f1","choices":[{"index":0,"delta":{"role":"assistant","content":"","refusal":null},"logprobs":null,"finish_reason":null}],"usage":null} + +data: {"id":"chatcmpl-foo","object":"chat.completion.chunk","created":1731618222,"model":"gpt-4o-mini-2024-07-18","system_fingerprint":"fp_0ba0d124f1","choices":[],"usage":{"prompt_tokens":13,"completion_tokens":12,"total_tokens":25,"prompt_tokens_details":{"cached_tokens":0,"audio_tokens":0},"completion_tokens_details":{"reasoning_tokens":0,"audio_tokens":0,"accepted_prediction_tokens":0,"rejected_prediction_tokens":0}}} + +data: [DONE] + +`, }, } { t.Run(tc.name, func(t *testing.T) { @@ -108,6 +168,12 @@ func TestWithTestUpstream(t *testing.T) { req.Header.Set("x-test-backend", tc.backend) req.Header.Set("x-response-body", base64.StdEncoding.EncodeToString([]byte(tc.responseBody))) req.Header.Set("x-expected-path", base64.StdEncoding.EncodeToString([]byte(tc.expPath))) + if tc.responseType != "" { + req.Header.Set("x-response-type", tc.responseType) + } + if tc.expRequestBody != "" { + req.Header.Set("x-expected-request-body", base64.StdEncoding.EncodeToString([]byte(tc.expRequestBody))) + } resp, err := http.DefaultClient.Do(req) if err != nil { @@ -120,11 +186,11 @@ func TestWithTestUpstream(t *testing.T) { t.Logf("unexpected status code: %d", resp.StatusCode) return false } - if tc.expBody != "" { + if tc.expResponseBody != "" { body, err := io.ReadAll(resp.Body) require.NoError(t, err) - if string(body) != tc.expBody { - t.Logf("unexpected response:\ngot: %s\nexp: %s", body, tc.expBody) + if string(body) != tc.expResponseBody { + fmt.Printf("unexpected response:\n%s", cmp.Diff(string(body), tc.expResponseBody)) return false } } diff --git a/tests/testupstream/main.go b/tests/testupstream/main.go index 9143998d..b5ffcd1c 100644 --- a/tests/testupstream/main.go +++ b/tests/testupstream/main.go @@ -20,6 +20,14 @@ import ( var logger = log.New(os.Stdout, "[testupstream] ", 0) const ( + // responseTypeKey is the key for the response type in the request. + // This can be either empty, "sse", or "aws-event-stream". + // * If this is "sse", the response body is expected to be a Server-Sent Event stream. + // Each line in x-response-body is treated as a separate [data] payload. + // * If this is "aws-event-stream", the response body is expected to be an AWS Event Stream. + // Each line in x-response-body is treated as a separate event payload. + // * If this is empty, the response body is expected to be a regular JSON response. + responseTypeKey = "x-response-type" // expectedHeadersKey is the key for the expected headers in the request. // The value is a base64 encoded string of comma separated key-value pairs. // E.g. "key1:value1,key2:value2". @@ -68,7 +76,7 @@ func main() { doMain(l) } -var streamingInterval = time.Second +var streamingInterval = 200 * time.Millisecond func doMain(l net.Listener) { if raw := os.Getenv("STREAMING_INTERVAL"); raw != "" { @@ -79,52 +87,11 @@ func doMain(l net.Listener) { defer l.Close() http.HandleFunc("/health", func(writer http.ResponseWriter, request *http.Request) { writer.WriteHeader(http.StatusOK) }) http.HandleFunc("/", handler) - http.HandleFunc("/sse", sseHandler) - http.HandleFunc("/aws-event-stream", awsEventStreamHandler) if err := http.Serve(l, nil); err != nil { // nolint: gosec logger.Printf("failed to serve: %v", err) } } -func sseHandler(w http.ResponseWriter, r *http.Request) { - expResponseBody, err := base64.StdEncoding.DecodeString(r.Header.Get(responseBodyHeaderKey)) - if err != nil { - logger.Println("failed to decode the response body") - http.Error(w, "failed to decode the response body", http.StatusBadRequest) - return - } - - w.Header().Set("Content-Type", "text/event-stream") - w.Header().Set("Cache-Control", "no-cache") - w.Header().Set("Connection", "keep-alive") - w.Header().Set("testupstream-id", os.Getenv("TESTUPSTREAM_ID")) - - for _, line := range bytes.Split(expResponseBody, []byte("\n")) { - line := string(line) - time.Sleep(streamingInterval) - - if _, err = w.Write([]byte("event: some event in testupstream\n")); err != nil { - logger.Println("failed to write the response body") - return - } - - if _, err = w.Write([]byte(fmt.Sprintf("data: %s\n\n", line))); err != nil { - logger.Println("failed to write the response body") - return - } - - if f, ok := w.(http.Flusher); ok { - f.Flush() - } else { - panic("expected http.ResponseWriter to be an http.Flusher") - } - logger.Println("response line sent:", line) - } - - logger.Println("response sent") - r.Context().Done() -} - func handler(w http.ResponseWriter, r *http.Request) { for k, v := range r.Header { logger.Printf("header %q: %s\n", k, v) @@ -205,17 +172,19 @@ func handler(w http.ResponseWriter, r *http.Request) { logger.Println("no expected testupstream-id") } - expectedPath, err := base64.StdEncoding.DecodeString(r.Header.Get(expectedPathHeaderKey)) - if err != nil { - logger.Println("failed to decode the expected path") - http.Error(w, "failed to decode the expected path", http.StatusBadRequest) - return - } + if expectedPath := r.Header.Get(expectedPathHeaderKey); expectedPath != "" { + expectedPath, err := base64.StdEncoding.DecodeString(expectedPath) + if err != nil { + logger.Println("failed to decode the expected path") + http.Error(w, "failed to decode the expected path", http.StatusBadRequest) + return + } - if r.URL.Path != string(expectedPath) { - logger.Printf("unexpected path: got %q, expected %q\n", r.URL.Path, string(expectedPath)) - http.Error(w, "unexpected path: got "+r.URL.Path+", expected "+string(expectedPath), http.StatusBadRequest) - return + if r.URL.Path != string(expectedPath) { + logger.Printf("unexpected path: got %q, expected %q\n", r.URL.Path, string(expectedPath)) + http.Error(w, "unexpected path: got "+r.URL.Path+", expected "+string(expectedPath), http.StatusBadRequest) + return + } } requestBody, err := io.ReadAll(r.Body) @@ -225,8 +194,8 @@ func handler(w http.ResponseWriter, r *http.Request) { return } - if r.Header.Get(expectedRequestBodyHeaderKey) != "" { - expectedBody, err := base64.StdEncoding.DecodeString(r.Header.Get(expectedRequestBodyHeaderKey)) + if expectedReqBody := r.Header.Get(expectedRequestBodyHeaderKey); expectedReqBody != "" { + expectedBody, err := base64.StdEncoding.DecodeString(expectedReqBody) if err != nil { logger.Println("failed to decode the expected request body") http.Error(w, "failed to decode the expected request body", http.StatusBadRequest) @@ -273,7 +242,6 @@ func handler(w http.ResponseWriter, r *http.Request) { } else { logger.Println("no response headers") } - w.Header().Set("Content-Type", "application/json") w.Header().Set("testupstream-id", os.Getenv("TESTUPSTREAM_ID")) status := http.StatusOK if v := r.Header.Get(responseStatusKey); v != "" { @@ -284,46 +252,84 @@ func handler(w http.ResponseWriter, r *http.Request) { return } } - w.WriteHeader(status) - if _, err := w.Write(responseBody); err != nil { - logger.Println("failed to write the response body") - } - logger.Println("response sent:", string(responseBody)) -} -func awsEventStreamHandler(w http.ResponseWriter, r *http.Request) { - expResponseBody, err := base64.StdEncoding.DecodeString(r.Header.Get(responseBodyHeaderKey)) - if err != nil { - logger.Println("failed to decode the response body") - http.Error(w, "failed to decode the response body", http.StatusBadRequest) - return - } + switch r.Header.Get(responseTypeKey) { + case "sse": + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(status) - w.Header().Set("Content-Type", "application/vnd.amazon.eventstream") - w.Header().Set("Transfer-Encoding", "chunked") - w.Header().Set("testupstream-id", os.Getenv("TESTUPSTREAM_ID")) + expResponseBody, err := base64.StdEncoding.DecodeString(r.Header.Get(responseBodyHeaderKey)) + if err != nil { + logger.Println("failed to decode the response body") + http.Error(w, "failed to decode the response body", http.StatusBadRequest) + return + } + + for _, line := range bytes.Split(expResponseBody, []byte("\n")) { + line := string(line) + if line == "" { + continue + } + time.Sleep(streamingInterval) + + if _, err = w.Write([]byte(fmt.Sprintf("data: %s\n\n", line))); err != nil { + logger.Println("failed to write the response body") + return + } + + if f, ok := w.(http.Flusher); ok { + f.Flush() + } else { + panic("expected http.ResponseWriter to be an http.Flusher") + } + logger.Println("response line sent:", line) + } + logger.Println("response sent") + r.Context().Done() + case "aws-event-stream": + // w.Header().Set("Transfer-Encoding", "chunked") + w.Header().Set("Content-Type", "application/vnd.amazon.eventstream") + w.WriteHeader(status) + + expResponseBody, err := base64.StdEncoding.DecodeString(r.Header.Get(responseBodyHeaderKey)) + if err != nil { + logger.Println("failed to decode the response body") + http.Error(w, "failed to decode the response body", http.StatusBadRequest) + return + } + + e := eventstream.NewEncoder() + for _, line := range bytes.Split(expResponseBody, []byte("\n")) { + // Write each line as a chunk with AWS Event Stream format. + if len(line) == 0 { + continue + } + time.Sleep(streamingInterval) + if err := e.Encode(w, eventstream.Message{ + Headers: eventstream.Headers{{Name: "event-type", Value: eventstream.StringValue("content")}}, + Payload: line, + }); err != nil { + logger.Println("failed to encode the response body") + } + w.(http.Flusher).Flush() + logger.Println("response line sent:", string(line)) + } - e := eventstream.NewEncoder() - for _, line := range bytes.Split(expResponseBody, []byte("\n")) { - // Write each line as a chunk with AWS Event Stream format. - time.Sleep(streamingInterval) if err := e.Encode(w, eventstream.Message{ - Headers: eventstream.Headers{{Name: "event-type", Value: eventstream.StringValue("content")}}, - Payload: line, + Headers: eventstream.Headers{{Name: "event-type", Value: eventstream.StringValue("end")}}, + Payload: []byte("this-is-end"), }); err != nil { logger.Println("failed to encode the response body") } - w.(http.Flusher).Flush() - logger.Println("response line sent:", string(line)) - } - if err := e.Encode(w, eventstream.Message{ - Headers: eventstream.Headers{{Name: "event-type", Value: eventstream.StringValue("end")}}, - Payload: []byte("this-is-end"), - }); err != nil { - logger.Println("failed to encode the response body") + logger.Println("response sent") + r.Context().Done() + default: + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + if _, err := w.Write(responseBody); err != nil { + logger.Println("failed to write the response body") + } + logger.Println("response sent:", string(responseBody)) } - - logger.Println("response sent") - r.Context().Done() } diff --git a/tests/testupstream/main_test.go b/tests/testupstream/main_test.go index 3ed66a03..c995ca60 100644 --- a/tests/testupstream/main_test.go +++ b/tests/testupstream/main_test.go @@ -30,6 +30,7 @@ func Test_main(t *testing.T) { t.Parallel() request, err := http.NewRequest("GET", "http://"+l.Addr().String()+"/sse", nil) require.NoError(t, err) + request.Header.Set(responseTypeKey, "sse") request.Header.Set(responseBodyHeaderKey, base64.StdEncoding.EncodeToString([]byte(strings.Join([]string{"1", "2", "3", "4", "5"}, "\n")))) @@ -43,11 +44,6 @@ func Test_main(t *testing.T) { reader := bufio.NewReader(response.Body) for i := 0; i < 5; i++ { - eventLine, err := reader.ReadString('\n') - require.NoError(t, err) - require.NoError(t, err) - require.Equal(t, "event: some event in testupstream\n", eventLine) - dataLine, err := reader.ReadString('\n') require.NoError(t, err) require.Equal(t, fmt.Sprintf("data: %d\n", i+1), dataLine) @@ -182,8 +178,9 @@ func Test_main(t *testing.T) { t.Run("aws-event-stream", func(t *testing.T) { t.Parallel() - request, err := http.NewRequest("GET", "http://"+l.Addr().String()+"/aws-event-stream", nil) + request, err := http.NewRequest("GET", "http://"+l.Addr().String()+"/", nil) require.NoError(t, err) + request.Header.Set(responseTypeKey, "aws-event-stream") request.Header.Set(responseBodyHeaderKey, base64.StdEncoding.EncodeToString([]byte(strings.Join([]string{"1", "2", "3", "4", "5"}, "\n"))))