diff --git a/.gitignore b/.gitignore index af76f79..6d54be0 100644 --- a/.gitignore +++ b/.gitignore @@ -6,4 +6,3 @@ /deb/topograph/etc/topograph/ /deb/topograph/lib /rpmbuild -/pkg/protos diff --git a/Makefile b/Makefile index 5f8c7dd..8c441b8 100644 --- a/Makefile +++ b/Makefile @@ -23,7 +23,7 @@ GIT_REF =$(shell git rev-parse --abbrev-ref HEAD) IMAGE_TAG ?=$(GIT_REF) .PHONY: build -build: proto +build: @for target in $(TARGETS); do \ echo "Building $${target}"; \ CGO_ENABLED=0 go build -a -o $(OUTPUT_DIR)/$${target} \ diff --git a/go.mod b/go.mod index d9d32bd..b79abe7 100644 --- a/go.mod +++ b/go.mod @@ -14,7 +14,6 @@ require ( github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.18 github.com/aws/aws-sdk-go-v2/service/ec2 v1.187.0 github.com/go-playground/validator/v10 v10.22.1 - github.com/golang/protobuf v1.5.4 github.com/google/uuid v1.6.0 github.com/googleapis/gax-go/v2 v2.13.0 github.com/hashicorp/golang-lru v1.0.2 @@ -26,6 +25,7 @@ require ( golang.org/x/exp v0.0.0-20241009180824-f66d83c29e7c google.golang.org/api v0.204.0 google.golang.org/grpc v1.67.1 + google.golang.org/protobuf v1.35.1 gopkg.in/yaml.v3 v3.0.1 k8s.io/api v0.31.2 k8s.io/apimachinery v0.31.2 @@ -62,6 +62,7 @@ require ( github.com/gofrs/flock v0.12.1 // indirect github.com/gogo/protobuf v1.3.2 // indirect github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect + github.com/golang/protobuf v1.5.4 // indirect github.com/google/gnostic-models v0.6.8 // indirect github.com/google/go-cmp v0.6.0 // indirect github.com/google/gofuzz v1.2.0 // indirect @@ -96,7 +97,6 @@ require ( google.golang.org/genproto v0.0.0-20241104194629-dd2ea8efbc28 // indirect google.golang.org/genproto/googleapis/api v0.0.0-20241104194629-dd2ea8efbc28 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20241104194629-dd2ea8efbc28 // indirect - google.golang.org/protobuf v1.35.1 // indirect gopkg.in/inf.v0 v0.9.1 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect k8s.io/kube-openapi v0.0.0-20241009091222-67ed5848f094 // indirect diff --git a/pkg/providers/providers_sim_test.go b/pkg/providers/providers_sim_test.go index 223fd7c..6faab68 100644 --- a/pkg/providers/providers_sim_test.go +++ b/pkg/providers/providers_sim_test.go @@ -23,7 +23,6 @@ import ( ) func TestGetSimParams(t *testing.T) { - testCases := []struct { name string params map[string]any diff --git a/pkg/server/http_server.go b/pkg/server/http_server.go index a69ea22..adb129f 100644 --- a/pkg/server/http_server.go +++ b/pkg/server/http_server.go @@ -40,6 +40,10 @@ type HttpServer struct { var srv *HttpServer func InitHttpServer(ctx context.Context, cfg *config.Config) { + srv = initHttpServer(ctx, cfg) +} + +func initHttpServer(ctx context.Context, cfg *config.Config) *HttpServer { mux := http.NewServeMux() mux.HandleFunc("/v1/generate", generate) @@ -47,7 +51,7 @@ func InitHttpServer(ctx context.Context, cfg *config.Config) { mux.HandleFunc("/healthz", healthz) mux.Handle("/metrics", promhttp.Handler()) - srv = &HttpServer{ + return &HttpServer{ ctx: ctx, cfg: cfg, srv: &http.Server{ diff --git a/pkg/server/http_server_test.go b/pkg/server/http_server_test.go new file mode 100644 index 0000000..0e3d561 --- /dev/null +++ b/pkg/server/http_server_test.go @@ -0,0 +1,152 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package server + +import ( + "bytes" + "context" + "fmt" + "io" + "net" + "net/http" + "net/url" + "testing" + "time" + + "github.com/NVIDIA/topograph/pkg/config" + "github.com/stretchr/testify/require" +) + +func getAvailablePort() (int, error) { + listener, err := net.Listen("tcp", ":0") + if err != nil { + return 0, err + } + defer listener.Close() + + return listener.Addr().(*net.TCPAddr).Port, nil +} + +func TestServer(t *testing.T) { + port, err := getAvailablePort() + require.NoError(t, err) + + cfg := &config.Config{ + HTTP: config.Endpoint{ + Port: port, + }, + RequestAggregationDelay: time.Second, + } + baseURL := fmt.Sprintf("http://localhost:%d", port) + + srv = initHttpServer(context.TODO(), cfg) + defer srv.Stop(nil) + go srv.Start() + + testCases := []struct { + name string + endpoint string + payload string + expected string + }{ + { + name: "Case 1: test healthz endpoint", + endpoint: "healthz", + expected: "OK\n", + }, + { + name: "Case 2: mock AWS request", + endpoint: "generate", + payload: ` +{ + "provider": { + "name": "aws-sim", + "params": { + "model_path": "../../tests/models/medium.yaml" + } + }, + "engine": { + "name": "test" + }, + "nodes": [ + { + "region": "R1", + "instances": { + "n11-1": "n11-1", + "n11-2": "n11-2", + "n12-1": "n12-1", + "n12-2": "n12-2", + "n13-1": "n13-1", + "n13-2": "n13-2", + "n14-1": "n14-1", + "n14-2": "n14-2" + } + } + ] +} +`, + expected: `SwitchName=sw3 Switches=sw[21-22] +SwitchName=sw21 Switches=sw[11-12] +SwitchName=sw22 Switches=sw[13-14] +SwitchName=sw11 Nodes=n11-[1-2] +SwitchName=sw12 Nodes=n12-[1-2] +SwitchName=sw13 Nodes=n13-[1-2] +SwitchName=sw14 Nodes=n14-[1-2] +`, + }, + } + + for _, tc := range testCases { + var resp *http.Response + var body []byte + switch tc.endpoint { + case "healthz": + resp, err = http.Get(baseURL + "/healthz") + case "generate": + // send topology request + resp, err = http.Post(baseURL+"/v1/generate", "application/json", bytes.NewBuffer([]byte(tc.payload))) + require.NoError(t, err) + require.Equal(t, http.StatusAccepted, resp.StatusCode) + + body, err = io.ReadAll(resp.Body) + require.NoError(t, err) + out := string(body) + fmt.Println("response", out) + resp.Body.Close() + + // wait for topology config generation + time.Sleep(3 * time.Second) + + // retrieve topology config + params := url.Values{} + params.Add("uid", out) + + fullURL := fmt.Sprintf("%s?%s", baseURL+"/v1/topology", params.Encode()) + resp, err = http.Get(fullURL) + + default: + t.Errorf("unsupported endpoint %s", tc.endpoint) + } + + require.NoError(t, err) + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + require.Equal(t, tc.expected, string(body)) + } +}