Skip to content

Commit

Permalink
Add "gateway benchmark-stream"(PRIME-655) (#1138)
Browse files Browse the repository at this point in the history
* Add gateway benchmark-stream endpoint to let us benchmark real completion endpoints on CG and SG instances
* Add `--use-special-header` and optional request CSV output to the "benchmark gateway" command
  • Loading branch information
vdavid authored Dec 20, 2024
1 parent 5812ab7 commit b2e0f1d
Show file tree
Hide file tree
Showing 3 changed files with 407 additions and 41 deletions.
1 change: 1 addition & 0 deletions cmd/src/gateway.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ Usage:
The commands are:
benchmark runs benchmarks against Cody Gateway
benchmark-stream runs benchmarks against Cody Gateway code completion streaming endpoints
Use "src gateway [command] -h" for more information about a command.
Expand Down
164 changes: 123 additions & 41 deletions cmd/src/gateway_benchmark.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,14 @@ type Stats struct {
Total time.Duration
}

type requestResult struct {
duration time.Duration
traceID string // X-Trace header value
}

func init() {
usage := `
'src gateway benchmark' runs performance benchmarks against Cody Gateway endpoints.
'src gateway benchmark' runs performance benchmarks against Cody Gateway and Sourcegraph test endpoints.
Usage:
Expand All @@ -39,17 +44,20 @@ Examples:
$ src gateway benchmark --sgp <token>
$ src gateway benchmark --requests 50 --sgp <token>
$ src gateway benchmark --gateway http://localhost:9992 --sourcegraph http://localhost:3082 --sgp <token>
$ src gateway benchmark --requests 50 --csv results.csv --sgp <token>
$ src gateway benchmark --requests 50 --csv results.csv --request-csv requests.csv --sgp <token>
$ src gateway benchmark --gateway https://cody-gateway.sourcegraph.com --sourcegraph https://sourcegraph.com --sgp <token> --use-special-header
`

flagSet := flag.NewFlagSet("benchmark", flag.ExitOnError)

var (
requestCount = flagSet.Int("requests", 1000, "Number of requests to make per endpoint")
csvOutput = flagSet.String("csv", "", "Export results to CSV file (provide filename)")
gatewayEndpoint = flagSet.String("gateway", "https://cody-gateway.sourcegraph.com", "Cody Gateway endpoint")
sgEndpoint = flagSet.String("sourcegraph", "https://sourcegraph.com", "Sourcegraph endpoint")
sgpToken = flagSet.String("sgp", "", "Sourcegraph personal access token for the called instance")
requestCount = flagSet.Int("requests", 1000, "Number of requests to make per endpoint")
csvOutput = flagSet.String("csv", "", "Export results to CSV file (provide filename)")
requestLevelCsvOutput = flagSet.String("request-csv", "", "Export request results to CSV file (provide filename)")
gatewayEndpoint = flagSet.String("gateway", "", "Cody Gateway endpoint")
sgEndpoint = flagSet.String("sourcegraph", "", "Sourcegraph endpoint")
sgpToken = flagSet.String("sgp", "", "Sourcegraph personal access token for the called instance")
useSpecialHeader = flagSet.Bool("use-special-header", false, "Use special header to test the gateway")
)

handler := func(args []string) error {
Expand All @@ -61,15 +69,23 @@ Examples:
return cmderrors.Usage("additional arguments not allowed")
}

if *useSpecialHeader {
fmt.Println("Using special header 'cody-core-gc-test'")
}

var (
httpClient = &http.Client{}
endpoints = map[string]any{} // Values: URL `string`s or `*webSocketClient`s
)
if *gatewayEndpoint != "" {
fmt.Println("Benchmarking Cody Gateway instance:", *gatewayEndpoint)
headers := http.Header{
"X-Sourcegraph-Should-Trace": []string{"true"},
}
endpoints["ws(s): gateway"] = &webSocketClient{
conn: nil,
URL: strings.Replace(fmt.Sprint(*gatewayEndpoint, "/v2/websocket"), "http", "ws", 1),
conn: nil,
URL: strings.Replace(fmt.Sprint(*gatewayEndpoint, "/v2/websocket"), "http", "ws", 1),
reqHeaders: headers,
}
endpoints["http(s): gateway"] = fmt.Sprint(*gatewayEndpoint, "/v2/http")
} else {
Expand All @@ -80,12 +96,18 @@ Examples:
return cmderrors.Usage("must specify --sgp <Sourcegraph personal access token>")
}
fmt.Println("Benchmarking Sourcegraph instance:", *sgEndpoint)
headers := http.Header{
"Authorization": []string{"token " + *sgpToken},
"X-Sourcegraph-Should-Trace": []string{"true"},
}
if *useSpecialHeader {
headers.Set("cody-core-gc-test", "M2R{+6VI?1,M3n&<vpw1&AK>")
}

endpoints["ws(s): sourcegraph"] = &webSocketClient{
conn: nil,
URL: strings.Replace(fmt.Sprint(*sgEndpoint, "/.api/gateway/websocket"), "http", "ws", 1),
headers: http.Header{
"Authorization": []string{"token " + *sgpToken},
},
conn: nil,
URL: strings.Replace(fmt.Sprint(*sgEndpoint, "/.api/gateway/websocket"), "http", "ws", 1),
reqHeaders: headers,
}
endpoints["http(s): sourcegraph"] = fmt.Sprint(*sgEndpoint, "/.api/gateway/http")
endpoints["http(s): http-then-ws"] = fmt.Sprint(*sgEndpoint, "/.api/gateway/http-then-websocket")
Expand All @@ -95,29 +117,33 @@ Examples:

fmt.Printf("Starting benchmark with %d requests per endpoint...\n", *requestCount)

var results []endpointResult
var eResults []endpointResult
rResults := map[string][]requestResult{}
for name, clientOrURL := range endpoints {
durations := make([]time.Duration, 0, *requestCount)
rResults[name] = make([]requestResult, 0, *requestCount)
fmt.Printf("\nTesting %s...", name)

for i := 0; i < *requestCount; i++ {
if ws, ok := clientOrURL.(*webSocketClient); ok {
duration := benchmarkEndpointWebSocket(ws)
if duration > 0 {
durations = append(durations, duration)
result := benchmarkEndpointWebSocket(ws)
if result.duration > 0 {
durations = append(durations, result.duration)
rResults[name] = append(rResults[name], result)
}
} else if url, ok := clientOrURL.(string); ok {
duration := benchmarkEndpointHTTP(httpClient, url, *sgpToken)
if duration > 0 {
durations = append(durations, duration)
result := benchmarkEndpointHTTP(httpClient, url, *sgpToken, *useSpecialHeader)
if result.duration > 0 {
durations = append(durations, result.duration)
rResults[name] = append(rResults[name], result)
}
}
}
fmt.Println()

stats := calculateStats(durations)

results = append(results, endpointResult{
eResults = append(eResults, endpointResult{
name: name,
avg: stats.Avg,
median: stats.Median,
Expand All @@ -130,14 +156,20 @@ Examples:
})
}

printResults(results, requestCount)
printResults(eResults, requestCount)

if *csvOutput != "" {
if err := writeResultsToCSV(*csvOutput, results, requestCount); err != nil {
if err := writeResultsToCSV(*csvOutput, eResults, requestCount); err != nil {
return fmt.Errorf("failed to export CSV: %v", err)
}
fmt.Printf("\nResults exported to %s\n", *csvOutput)
}
if *requestLevelCsvOutput != "" {
if err := writeRequestResultsToCSV(*requestLevelCsvOutput, rResults); err != nil {
return fmt.Errorf("failed to export request-level CSV: %v", err)
}
fmt.Printf("\nRequest-level results exported to %s\n", *requestLevelCsvOutput)
}

return nil
}
Expand All @@ -158,9 +190,10 @@ Examples:
}

type webSocketClient struct {
conn *websocket.Conn
URL string
headers http.Header
conn *websocket.Conn
URL string
reqHeaders http.Header
respHeaders http.Header
}

func (c *webSocketClient) reconnect() error {
Expand All @@ -169,11 +202,13 @@ func (c *webSocketClient) reconnect() error {
}
fmt.Println("Connecting to WebSocket..", c.URL)
var err error
c.conn, _, err = websocket.DefaultDialer.Dial(c.URL, c.headers)
var resp *http.Response
c.conn, resp, err = websocket.DefaultDialer.Dial(c.URL, c.reqHeaders)
if err != nil {
c.conn = nil // retry again later
return fmt.Errorf("WebSocket dial(%s): %v", c.URL, err)
}
c.respHeaders = resp.Header
fmt.Println("Connected!")
return nil
}
Expand All @@ -190,19 +225,23 @@ type endpointResult struct {
successful int
}

func benchmarkEndpointHTTP(client *http.Client, url, accessToken string) time.Duration {
func benchmarkEndpointHTTP(client *http.Client, url, accessToken string, useSpecialHeader bool) requestResult {
start := time.Now()
req, err := http.NewRequest("POST", url, strings.NewReader("ping"))
if err != nil {
fmt.Printf("Error creating request: %v\n", err)
return 0
return requestResult{}
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "token "+accessToken)
req.Header.Set("X-Sourcegraph-Should-Trace", "true")
if useSpecialHeader {
req.Header.Set("cody-core-gc-test", "M2R{+6VI?1,M3n&<vpw1&AK>")
}
resp, err := client.Do(req)
if err != nil {
fmt.Printf("Error calling %s: %v\n", url, err)
return 0
return requestResult{}
}
defer func() {
err := resp.Body.Close()
Expand All @@ -212,27 +251,30 @@ func benchmarkEndpointHTTP(client *http.Client, url, accessToken string) time.Du
}()
if resp.StatusCode != http.StatusOK {
fmt.Printf("non-200 response: %v\n", resp.Status)
return 0
return requestResult{}
}
body, err := io.ReadAll(resp.Body)
if err != nil {
fmt.Printf("Error reading response body: %v\n", err)
return 0
return requestResult{}
}
if string(body) != "pong" {
fmt.Printf("Expected 'pong' response, got: %q\n", string(body))
return 0
return requestResult{}
}

return time.Since(start)
return requestResult{
duration: time.Since(start),
traceID: resp.Header.Get("X-Trace"),
}
}

func benchmarkEndpointWebSocket(client *webSocketClient) time.Duration {
func benchmarkEndpointWebSocket(client *webSocketClient) requestResult {
// Perform initial websocket connection, if needed.
if client.conn == nil {
if err := client.reconnect(); err != nil {
fmt.Printf("Error reconnecting: %v\n", err)
return 0
return requestResult{}
}
}

Expand All @@ -244,7 +286,7 @@ func benchmarkEndpointWebSocket(client *webSocketClient) time.Duration {
if err := client.reconnect(); err != nil {
fmt.Printf("Error reconnecting: %v\n", err)
}
return 0
return requestResult{}
}
_, message, err := client.conn.ReadMessage()

Expand All @@ -253,16 +295,19 @@ func benchmarkEndpointWebSocket(client *webSocketClient) time.Duration {
if err := client.reconnect(); err != nil {
fmt.Printf("Error reconnecting: %v\n", err)
}
return 0
return requestResult{}
}
if string(message) != "pong" {
fmt.Printf("Expected 'pong' response, got: %q\n", string(message))
if err := client.reconnect(); err != nil {
fmt.Printf("Error reconnecting: %v\n", err)
}
return 0
return requestResult{}
}
return requestResult{
duration: time.Since(start),
traceID: client.respHeaders.Get("Content-Type"),
}
return time.Since(start)
}

func calculateStats(durations []time.Duration) Stats {
Expand Down Expand Up @@ -438,3 +483,40 @@ func writeResultsToCSV(filename string, results []endpointResult, requestCount *

return nil
}

func writeRequestResultsToCSV(filename string, results map[string][]requestResult) error {
file, err := os.Create(filename)
if err != nil {
return fmt.Errorf("failed to create CSV file: %v", err)
}
defer func() {
err := file.Close()
if err != nil {
return
}
}()

writer := csv.NewWriter(file)
defer writer.Flush()

// Write header
header := []string{"Endpoint", "Duration (ms)", "Trace ID"}
if err := writer.Write(header); err != nil {
return fmt.Errorf("failed to write CSV header: %v", err)
}

for endpoint, requestResults := range results {
for _, result := range requestResults {
row := []string{
endpoint,
fmt.Sprintf("%.2f", float64(result.duration.Microseconds())/1000),
result.traceID,
}
if err := writer.Write(row); err != nil {
return fmt.Errorf("failed to write CSV row: %v", err)
}
}
}

return nil
}
Loading

0 comments on commit b2e0f1d

Please sign in to comment.