From f3eddb10aa3d0c9b838586a2b0f0501defc1d57e Mon Sep 17 00:00:00 2001 From: Dmitry Shmulevich Date: Fri, 18 Oct 2024 14:43:02 -0700 Subject: [PATCH] replace URL query with URL payload Signed-off-by: Dmitry Shmulevich --- pkg/common/const.go | 2 - pkg/common/types.go | 93 ++++++++++---------------- pkg/common/types_test.go | 88 ++++++++++++------------ pkg/config/config.go | 3 +- pkg/config/config_test.go | 23 ++++++- pkg/factory/provider.go | 2 +- pkg/node_observer/controller.go | 34 ++++++---- pkg/node_observer/node_informer.go | 9 ++- pkg/providers/aws/imds.go | 36 +++++----- pkg/providers/aws/instance_topology.go | 4 +- pkg/providers/aws/provider.go | 29 +++++--- pkg/providers/cw/provider.go | 2 +- pkg/providers/gcp/provider.go | 2 +- pkg/providers/oci/provider.go | 26 +++++-- pkg/server/engine.go | 22 +++--- pkg/server/grpc_client.go | 4 +- pkg/server/http_server.go | 63 ++++------------- pkg/utils/http.go | 14 ++-- 18 files changed, 229 insertions(+), 227 deletions(-) diff --git a/pkg/common/const.go b/pkg/common/const.go index 1ad3275..4e54367 100644 --- a/pkg/common/const.go +++ b/pkg/common/const.go @@ -28,8 +28,6 @@ const ( EngineTest = "test" KeyUID = "uid" - KeyProvider = "provider" - KeyEngine = "engine" KeyTopoConfigPath = "topology_config_path" KeyTopoConfigmapName = "topology_configmap_name" KeyTopoConfigmapNamespace = "topology_configmap_namespace" diff --git a/pkg/common/types.go b/pkg/common/types.go index 6684326..90739f1 100644 --- a/pkg/common/types.go +++ b/pkg/common/types.go @@ -58,7 +58,7 @@ func (e *HTTPError) Error() string { } type Provider interface { - GetCredentials(*Credentials) (interface{}, error) + GetCredentials(map[string]string) (interface{}, error) GetComputeInstances(context.Context, Engine) ([]ComputeInstances, error) GenerateTopologyConfig(context.Context, interface{}, int, []ComputeInstances) (*Vertex, error) } @@ -67,69 +67,58 @@ type Engine interface { GenerateOutput(context.Context, *Vertex, map[string]string) ([]byte, 
error) } -type Payload struct { - Nodes []ComputeInstances `json:"nodes"` - Creds *Credentials `json:"creds,omitempty"` // access credentials +type TopologyRequest struct { + Provider provider `json:"provider"` + Engine engine `json:"engine"` + Nodes []ComputeInstances `json:"nodes"` } -type ComputeInstances struct { - Region string `json:"region"` - Instances map[string]string `json:"instances"` // map of instance ID to node name +type provider struct { + Name string `json:"name"` + Creds map[string]string `json:"creds"` // access credentials } -type Credentials struct { - AWS *AWSCredentials `yaml:"aws,omitempty" json:"aws,omitempty"` // AWS credentials - OCI *OCICredentials `yaml:"oci,omitempty" json:"oci,omitempty"` // OCI credentials +type engine struct { + Name string `json:"name"` + Params map[string]string `json:"params"` // engine parameters } -type AWSCredentials struct { - AccessKeyId string `yaml:"access_key_id" json:"access_key_id"` - SecretAccessKey string `yaml:"secret_access_key" json:"secret_access_key"` - Token string `yaml:"token,omitempty" json:"token,omitempty"` // token is optional +type ComputeInstances struct { + Region string `json:"region"` + Instances map[string]string `json:"instances"` // map of instance ID to node name } -type OCICredentials struct { - TenancyID string `yaml:"tenancy_id" json:"tenancy_id"` - UserID string `yaml:"user_id" json:"user_id"` - Region string `yaml:"region" json:"region"` - Fingerprint string `yaml:"fingerprint" json:"fingerprint"` - PrivateKey string `yaml:"private_key" json:"private_key"` - Passphrase string `yaml:"passphrase,omitempty" json:"passphrase,omitempty"` // passphrase is optional +func NewTopologyRequest(prv string, creds map[string]string, eng string, params map[string]string) *TopologyRequest { + return &TopologyRequest{ + Provider: provider{ + Name: prv, + Creds: creds, + }, + Engine: engine{ + Name: eng, + Params: params, + }, + } } -func (p *Payload) String() string { +func (p *TopologyRequest) String() string { var sb strings.Builder - - 
sb.WriteString(fmt.Sprintf("Payload:\n Nodes: %v\n", p.Nodes)) - if p.Creds != nil { - sb.WriteString(" Credentials:\n") - if p.Creds.AWS != nil { - var accessKeyId, secretAccessKey, token string - if len(p.Creds.AWS.AccessKeyId) != 0 { - accessKeyId = "***" - } - if len(p.Creds.AWS.SecretAccessKey) != 0 { - secretAccessKey = "***" - } - if len(p.Creds.AWS.Token) != 0 { - token = "***" - } - sb.WriteString(fmt.Sprintf(" AWS: AccessKeyID=%s SecretAccessKey=%s SessionToken=%s\n", - accessKeyId, secretAccessKey, token)) - } - if p.Creds.OCI != nil { - sb.WriteString(" OCI:\n") - sb.WriteString(fmt.Sprintf(" UserID=%s\n", p.Creds.OCI.UserID)) - sb.WriteString(fmt.Sprintf(" TenancyID=%s\n", p.Creds.OCI.TenancyID)) - sb.WriteString(fmt.Sprintf(" Region=%s\n", p.Creds.OCI.Region)) - } + sb.WriteString("TopologyRequest:\n") + sb.WriteString(fmt.Sprintf(" Provider: %s\n", p.Provider.Name)) + sb.WriteString(" Credentials: ") + for key := range p.Provider.Creds { + sb.WriteString(fmt.Sprintf("%s=***,", key)) } + sb.WriteString("\n") + sb.WriteString(fmt.Sprintf(" Engine: %s\n", p.Engine.Name)) + sb.WriteString(fmt.Sprintf(" Parameters: %v\n", p.Engine.Params)) + sb.WriteString(fmt.Sprintf(" Nodes: %s\n", p.Nodes)) return sb.String() } -func GetPayload(body []byte) (*Payload, error) { - var payload Payload +func GetTopologyRequest(body []byte) (*TopologyRequest, error) { + var payload TopologyRequest if len(body) == 0 { return &payload, nil @@ -139,13 +128,5 @@ func GetPayload(body []byte) (*Payload, error) { return nil, fmt.Errorf("failed to parse payload: %v", err) } - if payload.Creds != nil { - if payload.Creds.AWS != nil { - if len(payload.Creds.AWS.AccessKeyId) == 0 || len(payload.Creds.AWS.SecretAccessKey) == 0 { - return nil, fmt.Errorf("invalid payload: must provide access_key_id and secret_access_key for AWS") - } - } - } - return &payload, nil } diff --git a/pkg/common/types_test.go b/pkg/common/types_test.go index fc34545..d858c3f 100644 --- 
a/pkg/common/types_test.go +++ b/pkg/common/types_test.go @@ -26,14 +26,18 @@ func TestPayload(t *testing.T) { testCases := []struct { name string input string - payload *Payload + payload *TopologyRequest print string err string }{ { name: "Case 1: no input", - payload: &Payload{}, - print: `Payload: + payload: &TopologyRequest{}, + print: `TopologyRequest: + Provider: + Credentials: + Engine: + Parameters: map[] Nodes: [] `, }, @@ -43,36 +47,26 @@ func TestPayload(t *testing.T) { "nodes": 5 } `, - err: "failed to parse payload: json: cannot unmarshal number into Go struct field Payload.nodes of type []common.ComputeInstances", + err: "failed to parse payload: json: cannot unmarshal number into Go struct field TopologyRequest.nodes of type []common.ComputeInstances", }, { - name: "Case 3: invalid creds", + name: "Case 3: valid input", input: ` { - "nodes": [ - { - "region": "region1", - "instances": { - "instance1": "node1", - "instance2": "node2", - "instance3": "node3" - } - } - ], - "creds": { - "aws": { + "provider": { + "name": "aws", + "creds": { "access_key_id": "id", - "token": "token" + "secret_access_key": "secret" } - } -} -`, - err: "invalid payload: must provide access_key_id and secret_access_key for AWS", - }, - { - name: "Case 4: valid input", - input: ` -{ + }, + "engine": { + "name": "slurm", + "params": { + "plugin": "topology/block", + "block_sizes": "30,120" + } + }, "nodes": [ { "region": "region1", @@ -90,16 +84,24 @@ func TestPayload(t *testing.T) { "instance6": "node6" } } - ], - "creds": { - "aws": { - "access_key_id": "id", - "secret_access_key": "secret" - } - } + ] } `, - payload: &Payload{ + payload: &TopologyRequest{ + Provider: provider{ + Name: "aws", + Creds: map[string]string{ + "access_key_id": "id", + "secret_access_key": "secret", + }, + }, + Engine: engine{ + Name: "slurm", + Params: map[string]string{ + "plugin": "topology/block", + "block_sizes": "30,120", + }, + }, Nodes: []ComputeInstances{ { Region: "region1", @@ -118,24 
+120,20 @@ func TestPayload(t *testing.T) { }, }, }, - Creds: &Credentials{ - AWS: &AWSCredentials{ - AccessKeyId: "id", - SecretAccessKey: "secret", - }, - }, }, - print: `Payload: + print: `TopologyRequest: + Provider: aws + Credentials: access_key_id=***,secret_access_key=***, + Engine: slurm + Parameters: map[block_sizes:30,120 plugin:topology/block] Nodes: [{region1 map[instance1:node1 instance2:node2 instance3:node3]} {region2 map[instance4:node4 instance5:node5 instance6:node6]}] - Credentials: - AWS: AccessKeyID=*** SecretAccessKey=*** SessionToken= `, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - payload, err := GetPayload([]byte(tc.input)) + payload, err := GetTopologyRequest([]byte(tc.input)) if len(tc.err) != 0 { require.EqualError(t, err, tc.err) } else { diff --git a/pkg/config/config.go b/pkg/config/config.go index a94d8dd..206334e 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -25,7 +25,6 @@ import ( "gopkg.in/yaml.v3" "k8s.io/klog/v2" - "github.com/NVIDIA/topograph/pkg/common" "github.com/NVIDIA/topograph/pkg/utils" ) @@ -39,7 +38,7 @@ type Config struct { Env map[string]string `yaml:"env"` // derived - Credentials common.Credentials + Credentials map[string]string } type Endpoint struct { diff --git a/pkg/config/config_test.go b/pkg/config/config_test.go index 1cfba53..b5b2322 100644 --- a/pkg/config/config_test.go +++ b/pkg/config/config_test.go @@ -25,7 +25,13 @@ import ( "github.com/stretchr/testify/require" ) -var configTemplate = ` +const ( + credentials = ` +access_key_id: id +secret_access_key: key +` + + configTemplate = ` http: port: 49021 ssl: true @@ -34,10 +40,12 @@ ssl: cert: %s key: %s ca_cert: %s +credentials_path: %s env: SLURM_CONF: /etc/slurm/config.yaml PATH: /a/b/c ` +) func TestConfig(t *testing.T) { file, err := os.CreateTemp("", "test-cfg-*.yml") @@ -60,7 +68,16 @@ func TestConfig(t *testing.T) { defer func() { _ = os.Remove(caCert.Name()) }() defer func() { _ = caCert.Close() }() 
- _, err = file.WriteString(fmt.Sprintf(configTemplate, cert.Name(), key.Name(), caCert.Name())) + creds, err := os.CreateTemp("", "test-creds-*.yml") + require.NoError(t, err) + defer func() { _ = os.Remove(creds.Name()) }() + defer func() { _ = creds.Close() }() + credsPath := creds.Name() + + _, err = creds.WriteString(credentials) + require.NoError(t, err) + + _, err = file.WriteString(fmt.Sprintf(configTemplate, cert.Name(), key.Name(), caCert.Name(), creds.Name())) require.NoError(t, err) cfg, err := NewFromFile(file.Name()) @@ -77,6 +94,8 @@ func TestConfig(t *testing.T) { Key: key.Name(), CaCert: caCert.Name(), }, + CredsPath: &credsPath, + Credentials: map[string]string{"access_key_id": "id", "secret_access_key": "key"}, Env: map[string]string{ "SLURM_CONF": "/etc/slurm/config.yaml", "PATH": "/a/b/c", diff --git a/pkg/factory/provider.go b/pkg/factory/provider.go index f38b7ea..f513849 100644 --- a/pkg/factory/provider.go +++ b/pkg/factory/provider.go @@ -69,7 +69,7 @@ func GetTestProvider() *testProvider { return p } -func (p *testProvider) GetCredentials(_ *common.Credentials) (interface{}, error) { +func (p *testProvider) GetCredentials(_ map[string]string) (interface{}, error) { return nil, nil } diff --git a/pkg/node_observer/controller.go b/pkg/node_observer/controller.go index a1dd5a8..2b25a87 100644 --- a/pkg/node_observer/controller.go +++ b/pkg/node_observer/controller.go @@ -17,7 +17,9 @@ package node_observer import ( + "bytes" "context" + "encoding/json" "fmt" "net/http" @@ -25,6 +27,7 @@ import ( "k8s.io/klog/v2" "github.com/NVIDIA/topograph/pkg/common" + "github.com/NVIDIA/topograph/pkg/utils" ) type Controller struct { @@ -35,24 +38,29 @@ type Controller struct { } func NewController(ctx context.Context, client kubernetes.Interface, cfg *Config) (*Controller, error) { - req, err := http.NewRequest("POST", cfg.TopologyGeneratorURL, nil) - if err != nil { - return nil, fmt.Errorf("failed to create HTTP request: %v", err) + var f 
utils.HttpRequestFunc = func() (*http.Request, error) { + params := map[string]string{ + common.KeyTopoConfigPath: cfg.TopologyConfigmap.Filename, + common.KeyTopoConfigmapName: cfg.TopologyConfigmap.Name, + common.KeyTopoConfigmapNamespace: cfg.TopologyConfigmap.Namespace, + } + payload := common.NewTopologyRequest(cfg.Provider, nil, cfg.Engine, params) + data, err := json.Marshal(payload) + if err != nil { + return nil, fmt.Errorf("failed to parse payload: %v", err) + } + req, err := http.NewRequest("POST", cfg.TopologyGeneratorURL, bytes.NewBuffer(data)) + if err != nil { + return nil, fmt.Errorf("failed to create HTTP request: %v", err) + } + req.Header.Set("Content-Type", "application/json") + return req, nil } - - q := req.URL.Query() - q.Add(common.KeyProvider, cfg.Provider) - q.Add(common.KeyEngine, cfg.Engine) - q.Add(common.KeyTopoConfigPath, cfg.TopologyConfigmap.Filename) - q.Add(common.KeyTopoConfigmapName, cfg.TopologyConfigmap.Name) - q.Add(common.KeyTopoConfigmapNamespace, cfg.TopologyConfigmap.Namespace) - req.URL.RawQuery = q.Encode() - return &Controller{ ctx: ctx, client: client, cfg: cfg, - nodeInformer: NewNodeInformer(ctx, client, cfg.NodeLabels, req), + nodeInformer: NewNodeInformer(ctx, client, cfg.NodeLabels, f), }, nil } diff --git a/pkg/node_observer/node_informer.go b/pkg/node_observer/node_informer.go index 00b7107..0e84f2c 100644 --- a/pkg/node_observer/node_informer.go +++ b/pkg/node_observer/node_informer.go @@ -18,7 +18,6 @@ package node_observer import ( "context" - "net/http" v1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" @@ -34,11 +33,11 @@ import ( type NodeInformer struct { ctx context.Context client kubernetes.Interface - req *http.Request + reqFunc utils.HttpRequestFunc factory informers.SharedInformerFactory } -func NewNodeInformer(ctx context.Context, client kubernetes.Interface, nodeLabels map[string]string, req *http.Request) *NodeInformer { +func NewNodeInformer(ctx context.Context, client 
kubernetes.Interface, nodeLabels map[string]string, reqFunc utils.HttpRequestFunc) *NodeInformer { klog.Infof("Configuring node informer with labels %v", nodeLabels) listOptionsFunc := func(options *metav1.ListOptions) { options.LabelSelector = labels.Set(nodeLabels).AsSelector().String() @@ -46,7 +45,7 @@ func NewNodeInformer(ctx context.Context, client kubernetes.Interface, nodeLabel return &NodeInformer{ ctx: ctx, client: client, - req: req, + reqFunc: reqFunc, factory: informers.NewSharedInformerFactoryWithOptions(client, 0, informers.WithTweakListOptions(listOptionsFunc)), } } @@ -87,7 +86,7 @@ func (n *NodeInformer) Stop(_ error) { } func (n *NodeInformer) SendRequest() { - _, _, err := utils.HttpRequestWithRetries(n.req) + _, _, err := utils.HttpRequestWithRetries(n.reqFunc) if err != nil { klog.Errorf("failed to send HTTP request: %v", err) } diff --git a/pkg/providers/aws/imds.go b/pkg/providers/aws/imds.go index 948771b..d076a34 100644 --- a/pkg/providers/aws/imds.go +++ b/pkg/providers/aws/imds.go @@ -47,14 +47,16 @@ type Creds struct { } func getToken() (string, error) { - req, err := http.NewRequest("PUT", IMDS_TOKEN_URL, nil) - if err != nil { - return "", fmt.Errorf("failed to create HTTP request: %v", err) - } - - req.Header.Add("X-aws-ec2-metadata-token-ttl-seconds", "21600") + var f utils.HttpRequestFunc = (func() (*http.Request, error) { + req, err := http.NewRequest("PUT", IMDS_TOKEN_URL, nil) + if err != nil { + return nil, fmt.Errorf("failed to create HTTP request: %v", err) + } + req.Header.Add("X-aws-ec2-metadata-token-ttl-seconds", "21600") + return req, nil + }) - _, data, err := utils.HttpRequest(req) + _, data, err := utils.HttpRequest(f) if err != nil { return "", fmt.Errorf("failed to send HTTP request: %v", err) } @@ -79,17 +81,19 @@ func getMetadata(path string) ([]byte, error) { url := fmt.Sprintf("%s/%s", IMDS_URL, path) klog.V(4).Infof("Requesting URL %s", url) - req, err := http.NewRequest("GET", url, nil) - if err != nil { - 
return nil, fmt.Errorf("failed to create HTTP request: %v", err) - } - - err = addToken(req) - if err != nil { - return nil, err + var f utils.HttpRequestFunc = func() (*http.Request, error) { + req, err := http.NewRequest("GET", url, nil) + if err != nil { + return nil, fmt.Errorf("failed to create HTTP request: %v", err) + } + err = addToken(req) + if err != nil { + return nil, err + } + return req, nil } - resp, data, err := utils.HttpRequest(req) + resp, data, err := utils.HttpRequest(f) if err != nil { return nil, fmt.Errorf("failed to send HTTP request: %v", err) } diff --git a/pkg/providers/aws/instance_topology.go b/pkg/providers/aws/instance_topology.go index d37c28d..d911da6 100644 --- a/pkg/providers/aws/instance_topology.go +++ b/pkg/providers/aws/instance_topology.go @@ -32,7 +32,7 @@ import ( var defaultPageSize int32 = 100 -func GenerateInstanceTopology(ctx context.Context, creds *common.AWSCredentials, pageSize int32, cis []common.ComputeInstances) ([]types.InstanceTopology, error) { +func GenerateInstanceTopology(ctx context.Context, creds *Credentials, pageSize int32, cis []common.ComputeInstances) ([]types.InstanceTopology, error) { var err error topology := []types.InstanceTopology{} for _, ci := range cis { @@ -44,7 +44,7 @@ func GenerateInstanceTopology(ctx context.Context, creds *common.AWSCredentials, return topology, nil } -func generateInstanceTopology(ctx context.Context, creds *common.AWSCredentials, pageSize int32, ci *common.ComputeInstances, topology []types.InstanceTopology) ([]types.InstanceTopology, error) { +func generateInstanceTopology(ctx context.Context, creds *Credentials, pageSize int32, ci *common.ComputeInstances, topology []types.InstanceTopology) ([]types.InstanceTopology, error) { if len(ci.Region) == 0 { return nil, fmt.Errorf("must specify region to query instance topology") } diff --git a/pkg/providers/aws/provider.go b/pkg/providers/aws/provider.go index e96a34a..488eb3b 100644 --- a/pkg/providers/aws/provider.go 
+++ b/pkg/providers/aws/provider.go @@ -32,17 +32,29 @@ import ( type Provider struct{} +type Credentials struct { + AccessKeyId string + SecretAccessKey string + Token string // token is optional +} + func GetProvider() (*Provider, error) { return &Provider{}, nil } -func (p *Provider) GetCredentials(creds *common.Credentials) (interface{}, error) { - if creds != nil && creds.AWS != nil { - return creds.AWS, nil - } - +func (p *Provider) GetCredentials(creds map[string]string) (interface{}, error) { var accessKeyID, secretAccessKey, sessionToken string - if len(os.Getenv("AWS_ACCESS_KEY_ID")) != 0 && len(os.Getenv("AWS_SECRET_ACCESS_KEY")) != 0 { + + if len(creds) != 0 { + klog.Infof("Using provided AWS credentials") + if accessKeyID = creds["access_key_id"]; len(accessKeyID) == 0 { + return nil, fmt.Errorf("credentials error: missing access_key_id") + } + if secretAccessKey = creds["secret_access_key"]; len(secretAccessKey) == 0 { + return nil, fmt.Errorf("credentials error: missing secret_access_key") + } + sessionToken = creds["token"] + } else if len(os.Getenv("AWS_ACCESS_KEY_ID")) != 0 && len(os.Getenv("AWS_SECRET_ACCESS_KEY")) != 0 { klog.Infof("Using shell AWS credentials") accessKeyID = os.Getenv("AWS_ACCESS_KEY_ID") secretAccessKey = os.Getenv("AWS_SECRET_ACCESS_KEY") @@ -57,7 +69,8 @@ func (p *Provider) GetCredentials(creds *common.Credentials) (interface{}, error secretAccessKey = nodeCreds.SecretAccessKey sessionToken = nodeCreds.Token } - return &common.AWSCredentials{ + + return &Credentials{ AccessKeyId: accessKeyID, SecretAccessKey: secretAccessKey, Token: sessionToken, @@ -96,7 +109,7 @@ func (p *Provider) GetComputeInstances(ctx context.Context, engine common.Engine } func (p *Provider) GenerateTopologyConfig(ctx context.Context, cr interface{}, pageSize int, instances []common.ComputeInstances) (*common.Vertex, error) { - creds := cr.(*common.AWSCredentials) + creds := cr.(*Credentials) topology, err := GenerateInstanceTopology(ctx, creds, 
int32(pageSize), instances) if err != nil { return nil, err diff --git a/pkg/providers/cw/provider.go b/pkg/providers/cw/provider.go index 3963657..dafa8ba 100644 --- a/pkg/providers/cw/provider.go +++ b/pkg/providers/cw/provider.go @@ -36,7 +36,7 @@ func GetProvider() (*Provider, error) { return &Provider{}, nil } -func (p *Provider) GetCredentials(_ *common.Credentials) (interface{}, error) { +func (p *Provider) GetCredentials(_ map[string]string) (interface{}, error) { return nil, nil } diff --git a/pkg/providers/gcp/provider.go b/pkg/providers/gcp/provider.go index 271d11a..f56d845 100644 --- a/pkg/providers/gcp/provider.go +++ b/pkg/providers/gcp/provider.go @@ -34,7 +34,7 @@ func GetProvider() (*Provider, error) { return &Provider{}, nil } -func (p *Provider) GetCredentials(_ *common.Credentials) (interface{}, error) { +func (p *Provider) GetCredentials(_ map[string]string) (interface{}, error) { return nil, nil } diff --git a/pkg/providers/oci/provider.go b/pkg/providers/oci/provider.go index cc4bd8c..80d82bf 100644 --- a/pkg/providers/oci/provider.go +++ b/pkg/providers/oci/provider.go @@ -36,14 +36,28 @@ func GetProvider() (*Provider, error) { return &Provider{}, nil } -func (p *Provider) GetCredentials(creds *common.Credentials) (interface{}, error) { - if creds != nil && creds.OCI != nil { +func (p *Provider) GetCredentials(creds map[string]string) (interface{}, error) { + if len(creds) != 0 { + var tenancyID, userID, region, fingerprint, privateKey, passphrase string klog.Info("Using provided credentials") - var passphrase *string - if creds.OCI.Passphrase != "" { - passphrase = &creds.OCI.Passphrase + if tenancyID = creds["tenancy_id"]; len(tenancyID) == 0 { + return nil, fmt.Errorf("credentials error: missing tenancy_id") } - return OCICommon.NewRawConfigurationProvider(creds.OCI.TenancyID, creds.OCI.UserID, creds.OCI.Region, creds.OCI.Fingerprint, creds.OCI.PrivateKey, passphrase), nil + if userID = creds["user_id"]; len(userID) == 0 { + return nil, 
fmt.Errorf("credentials error: missing user_id") + } + if region = creds["region"]; len(region) == 0 { + return nil, fmt.Errorf("credentials error: missing region") + } + if fingerprint = creds["fingerprint"]; len(fingerprint) == 0 { + return nil, fmt.Errorf("credentials error: missing fingerprint") + } + if privateKey = creds["private_key"]; len(privateKey) == 0 { + return nil, fmt.Errorf("credentials error: missing private_key") + } + passphrase = creds["passphrase"] + + return OCICommon.NewRawConfigurationProvider(tenancyID, userID, region, fingerprint, privateKey, &passphrase), nil } klog.Info("No credentials provided, trying default configuration provider") diff --git a/pkg/server/engine.go b/pkg/server/engine.go index df2652e..06f2983 100644 --- a/pkg/server/engine.go +++ b/pkg/server/engine.go @@ -34,7 +34,7 @@ type asyncController struct { } func processRequest(item interface{}) (interface{}, *common.HTTPError) { - tr := item.(*TopologyRequest) + tr := item.(*common.TopologyRequest) var code int start := time.Now() @@ -44,21 +44,21 @@ func processRequest(item interface{}) (interface{}, *common.HTTPError) { } else { code = http.StatusOK } - metrics.Add(tr.provider, tr.engine, code, time.Since(start)) + metrics.Add(tr.Provider.Name, tr.Engine.Name, code, time.Since(start)) return ret, err } -func processTopologyRequest(tr *TopologyRequest) ([]byte, *common.HTTPError) { - klog.InfoS("Creating topology config", "provider", tr.provider, "engine", tr.engine) +func processTopologyRequest(tr *common.TopologyRequest) ([]byte, *common.HTTPError) { + klog.InfoS("Creating topology config", "provider", tr.Provider.Name, "engine", tr.Engine.Name) - eng, httpErr := factory.GetEngine(tr.engine) + eng, httpErr := factory.GetEngine(tr.Engine.Name) if httpErr != nil { klog.Error(httpErr.Error()) return nil, httpErr } - prv, httpErr := factory.GetProvider(tr.provider) + prv, httpErr := factory.GetProvider(tr.Provider.Name) if httpErr != nil { klog.Error(httpErr.Error()) return 
nil, httpErr @@ -67,7 +67,7 @@ func processTopologyRequest(tr *TopologyRequest) ([]byte, *common.HTTPError) { ctx := context.TODO() // if the instance/node mapping is not provided in the payload, get the mapping from the provider - computeInstances := tr.payload.Nodes + computeInstances := tr.Nodes if len(computeInstances) == 0 { var err error computeInstances, err = prv.GetComputeInstances(ctx, eng) @@ -76,7 +76,7 @@ func processTopologyRequest(tr *TopologyRequest) ([]byte, *common.HTTPError) { } } - creds, err := prv.GetCredentials(checkCredentials(tr.payload.Creds, &srv.cfg.Credentials)) + creds, err := prv.GetCredentials(checkCredentials(tr.Provider.Creds, srv.cfg.Credentials)) if err != nil { klog.Error(err.Error()) return nil, common.NewHTTPError(http.StatusUnauthorized, err.Error()) @@ -94,7 +94,7 @@ func processTopologyRequest(tr *TopologyRequest) ([]byte, *common.HTTPError) { return nil, common.NewHTTPError(http.StatusInternalServerError, err.Error()) } - data, err := eng.GenerateOutput(ctx, root, tr.params) + data, err := eng.GenerateOutput(ctx, root, tr.Engine.Params) if err != nil { klog.Error(err.Error()) return nil, common.NewHTTPError(http.StatusInternalServerError, err.Error()) @@ -103,8 +103,8 @@ func processTopologyRequest(tr *TopologyRequest) ([]byte, *common.HTTPError) { return data, nil } -func checkCredentials(payloadCreds, cfgCreds *common.Credentials) *common.Credentials { - if payloadCreds != nil { +func checkCredentials(payloadCreds, cfgCreds map[string]string) map[string]string { + if len(payloadCreds) != 0 { return payloadCreds } return cfgCreds diff --git a/pkg/server/grpc_client.go b/pkg/server/grpc_client.go index 6fea126..409a9d7 100644 --- a/pkg/server/grpc_client.go +++ b/pkg/server/grpc_client.go @@ -28,7 +28,7 @@ import ( pb "github.com/NVIDIA/topograph/pkg/protos" ) -func forwardRequest(ctx context.Context, tr *TopologyRequest, url string, cis []common.ComputeInstances) (*common.Vertex, error) { +func forwardRequest(ctx 
context.Context, tr *common.TopologyRequest, url string, cis []common.ComputeInstances) (*common.Vertex, error) { klog.Infof("Forwarding request to %s", url) conn, err := grpc.NewClient(url, grpc.WithTransportCredentials(insecure.NewCredentials())) if err != nil { @@ -48,7 +48,7 @@ func forwardRequest(ctx context.Context, tr *TopologyRequest, url string, cis [] klog.Infof("Getting topology for instances %v", ids) response, err := client.DescribeTopology(ctx, &pb.TopologyRequest{ - Provider: tr.provider, + Provider: tr.Provider.Name, Region: "", InstanceIds: ids, }) diff --git a/pkg/server/http_server.go b/pkg/server/http_server.go index c33cd99..265564a 100644 --- a/pkg/server/http_server.go +++ b/pkg/server/http_server.go @@ -21,7 +21,6 @@ import ( "fmt" "io" "net/http" - "net/url" "github.com/prometheus/client_golang/prometheus/promhttp" "k8s.io/klog/v2" @@ -38,13 +37,6 @@ type HttpServer struct { async *asyncController } -type TopologyRequest struct { - provider string - engine string - params map[string]string - payload *common.Payload -} - var srv *HttpServer func InitHttpServer(ctx context.Context, cfg *config.Config) { @@ -106,7 +98,7 @@ func generate(w http.ResponseWriter, r *http.Request) { _, _ = w.Write([]byte(uid)) } -func readRequest(w http.ResponseWriter, r *http.Request) *TopologyRequest { +func readRequest(w http.ResponseWriter, r *http.Request) *common.TopologyRequest { if r.Method != http.MethodPost { http.Error(w, "Invalid request method", http.StatusMethodNotAllowed) return nil @@ -119,21 +111,13 @@ func readRequest(w http.ResponseWriter, r *http.Request) *TopologyRequest { } defer func() { _ = r.Body.Close() }() - tr := &TopologyRequest{} - - tr.payload, err = common.GetPayload(body) - if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return nil - } - - tr.provider, tr.engine, tr.params, err = parseQuery(r.URL.Query()) + tr, err := common.GetTopologyRequest(body) if err != nil { http.Error(w, err.Error(), 
http.StatusBadRequest) return nil } - klog.InfoS("Topology request", "provider", tr.provider, "engine", tr.engine, "params", tr.params, "payload", tr.payload.String()) + klog.Info(tr.String()) if err = validate(tr); err != nil { http.Error(w, err.Error(), http.StatusBadRequest) @@ -143,44 +127,23 @@ func readRequest(w http.ResponseWriter, r *http.Request) *TopologyRequest { return tr } -func parseQuery(vals url.Values) (string, string, map[string]string, error) { - params := make(map[string]string) - var provider, engine string - - for key, arr := range vals { - switch key { - case common.KeyProvider: - provider = arr[0] - case common.KeyEngine: - engine = arr[0] - default: - params[key] = arr[0] - } - } - - if len(provider) == 0 { - return "", "", nil, fmt.Errorf("missing provider URL query parameter") - } - if len(engine) == 0 { - return "", "", nil, fmt.Errorf("missing engine URL query parameter") - } - - return provider, engine, params, nil -} - -func validate(tr *TopologyRequest) error { - switch tr.provider { +func validate(tr *common.TopologyRequest) error { + switch tr.Provider.Name { case common.ProviderAWS, common.ProviderOCI, common.ProviderGCP, common.ProviderCW, common.ProviderTest: //nop + case "": + return fmt.Errorf("missing provider name") default: - return fmt.Errorf("unsupported provider %s", tr.provider) + return fmt.Errorf("unsupported provider %s", tr.Provider.Name) } - switch tr.engine { + switch tr.Engine.Name { + case "": + return fmt.Errorf("missing engine name") case common.EngineK8S: for _, key := range []string{common.KeyTopoConfigPath, common.KeyTopoConfigmapName, common.KeyTopoConfigmapNamespace} { - if _, ok := tr.params[key]; !ok { - return fmt.Errorf("missing %q URL query parameter", key) + if _, ok := tr.Engine.Params[key]; !ok { + return fmt.Errorf("missing %q parameter", key) } } } diff --git a/pkg/utils/http.go b/pkg/utils/http.go index 54bd072..185fa16 100644 --- a/pkg/utils/http.go +++ b/pkg/utils/http.go @@ -40,8 +40,14 @@ 
var ( } ) +type HttpRequestFunc func() (*http.Request, error) + // HttpRequest sends HTTP requests and returns HTTP response -func HttpRequest(req *http.Request) (*http.Response, []byte, error) { +func HttpRequest(f HttpRequestFunc) (*http.Response, []byte, error) { + req, err := f() + if err != nil { + return nil, nil, err + } klog.V(4).Infof("Sending HTTP request %s", req.URL.String()) client := &http.Client{} resp, err := client.Do(req) @@ -63,10 +69,10 @@ func HttpRequest(req *http.Request) (*http.Response, []byte, error) { } // HttpRequestWithRetries sends HTTP requests and returns HTTP response; retries if needed -func HttpRequestWithRetries(req *http.Request) (resp *http.Response, body []byte, err error) { - klog.V(4).Infof("Sending HTTP request %s with retries", req.URL.String()) +func HttpRequestWithRetries(f HttpRequestFunc) (resp *http.Response, body []byte, err error) { + klog.V(4).Infof("Sending HTTP request with retries") for r := 1; r <= retries; r++ { - resp, body, err = HttpRequest(req) + resp, body, err = HttpRequest(f) if err == nil || !retryHttpCodes[resp.StatusCode] { break }