diff --git a/pkg/common/const.go b/pkg/common/const.go
index 1ad3275..a5c9cdc 100644
--- a/pkg/common/const.go
+++ b/pkg/common/const.go
@@ -21,6 +21,7 @@ const (
 	ProviderOCI  = "oci"
 	ProviderGCP  = "gcp"
 	ProviderCW   = "cw"
+	ProviderBM   = "baremetal"
 	ProviderTest = "test"
 
 	EngineSLURM = "slurm"
diff --git a/pkg/common/types.go b/pkg/common/types.go
index 6684326..e1075b7 100644
--- a/pkg/common/types.go
+++ b/pkg/common/types.go
@@ -31,6 +31,7 @@ type Vertex struct {
 	Name     string
 	ID       string
 	Vertices map[string]*Vertex
+	Metadata map[string]string
 }
 
 func (v *Vertex) String() string {
diff --git a/pkg/protos/topology.pb.go b/pkg/protos/topology.pb.go
index 8c8029e..b487737 100644
--- a/pkg/protos/topology.pb.go
+++ b/pkg/protos/topology.pb.go
@@ -15,8 +15,8 @@
 
 // Code generated by protoc-gen-go. DO NOT EDIT.
 // versions:
-// 	protoc-gen-go v1.34.2
-// 	protoc        v5.27.0
+// 	protoc-gen-go v1.35.1
+// 	protoc        v3.12.4
 // source: topology.proto
 
 package protos
@@ -47,11 +47,9 @@ type TopologyRequest struct {
 
 func (x *TopologyRequest) Reset() {
 	*x = TopologyRequest{}
-	if protoimpl.UnsafeEnabled {
-		mi := &file_topology_proto_msgTypes[0]
-		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
-		ms.StoreMessageInfo(mi)
-	}
+	mi := &file_topology_proto_msgTypes[0]
+	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
+	ms.StoreMessageInfo(mi)
 }
 
 func (x *TopologyRequest) String() string {
@@ -62,7 +60,7 @@ func (*TopologyRequest) ProtoMessage() {}
 
 func (x *TopologyRequest) ProtoReflect() protoreflect.Message {
 	mi := &file_topology_proto_msgTypes[0]
-	if protoimpl.UnsafeEnabled && x != nil {
+	if x != nil {
 		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
 		if ms.LoadMessageInfo() == nil {
 			ms.StoreMessageInfo(mi)
@@ -108,11 +106,9 @@ type TopologyResponse struct {
 
 func (x *TopologyResponse) Reset() {
 	*x = TopologyResponse{}
-	if protoimpl.UnsafeEnabled {
-		mi := &file_topology_proto_msgTypes[1]
-		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
-		ms.StoreMessageInfo(mi)
-	}
+	mi := &file_topology_proto_msgTypes[1]
+	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
+	ms.StoreMessageInfo(mi)
 }
 
 func (x *TopologyResponse) String() string {
@@ -123,7 +119,7 @@ func (*TopologyResponse) ProtoMessage() {}
 
 func (x *TopologyResponse) ProtoReflect() protoreflect.Message {
 	mi := &file_topology_proto_msgTypes[1]
-	if protoimpl.UnsafeEnabled && x != nil {
+	if x != nil {
 		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
 		if ms.LoadMessageInfo() == nil {
 			ms.StoreMessageInfo(mi)
@@ -161,11 +157,9 @@ type Instance struct {
 
 func (x *Instance) Reset() {
 	*x = Instance{}
-	if protoimpl.UnsafeEnabled {
-		mi := &file_topology_proto_msgTypes[2]
-		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
-		ms.StoreMessageInfo(mi)
-	}
+	mi := &file_topology_proto_msgTypes[2]
+	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
+	ms.StoreMessageInfo(mi)
 }
 
 func (x *Instance) String() string {
@@ -176,7 +170,7 @@ func (*Instance) ProtoMessage() {}
 
 func (x *Instance) ProtoReflect() protoreflect.Message {
 	mi := &file_topology_proto_msgTypes[2]
-	if protoimpl.UnsafeEnabled && x != nil {
+	if x != nil {
 		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
 		if ms.LoadMessageInfo() == nil {
 			ms.StoreMessageInfo(mi)
@@ -314,44 +308,6 @@ func file_topology_proto_init() {
 	if File_topology_proto != nil {
 		return
 	}
-	if !protoimpl.UnsafeEnabled {
-		file_topology_proto_msgTypes[0].Exporter = func(v any, i int) any {
-			switch v := v.(*TopologyRequest); i {
-			case 0:
-				return &v.state
-			case 1:
-				return &v.sizeCache
-			case 2:
-				return &v.unknownFields
-			default:
-				return nil
-			}
-		}
-		file_topology_proto_msgTypes[1].Exporter = func(v any, i int) any {
-			switch v := v.(*TopologyResponse); i {
-			case 0:
-				return &v.state
-			case 1:
-				return &v.sizeCache
-			case 2:
-				return &v.unknownFields
-			default:
-				return nil
-			}
-		}
-		file_topology_proto_msgTypes[2].Exporter = func(v any, i int) any {
-			switch v := v.(*Instance); i {
-			case 0:
-				return &v.state
-			case 1:
-				return &v.sizeCache
-			case 2:
-				return &v.unknownFields
-			default:
-				return nil
-			}
-		}
-	}
 	type x struct{}
 	out := protoimpl.TypeBuilder{
 		File: protoimpl.DescBuilder{
diff --git a/pkg/protos/topology_grpc.pb.go b/pkg/protos/topology_grpc.pb.go
index 637747c..a5ec5ea 100644
--- a/pkg/protos/topology_grpc.pb.go
+++ b/pkg/protos/topology_grpc.pb.go
@@ -16,7 +16,7 @@
 // Code generated by protoc-gen-go-grpc. DO NOT EDIT.
 // versions:
 // - protoc-gen-go-grpc v1.5.1
-// - protoc             v5.27.0
+// - protoc             v3.12.4
 // source: topology.proto
 
 package protos
diff --git a/pkg/providers/baremetal/mnnvl.go b/pkg/providers/baremetal/mnnvl.go
new file mode 100644
index 0000000..14dc84a
--- /dev/null
+++ b/pkg/providers/baremetal/mnnvl.go
@@ -0,0 +1,140 @@
+package baremetal
+
+import (
+	"bufio"
+	"context"
+	"fmt"
+	"sort"
+	"strconv"
+	"strings"
+
+	"k8s.io/klog/v2"
+
+	"github.com/NVIDIA/topograph/pkg/common"
+	"github.com/NVIDIA/topograph/pkg/utils"
+)
+
+// domain holds the set of nodes that reported the same ClusterUUID.
+// Each domain is emitted as a separate NVL domain in the topology graph.
+type domain struct {
+	nodeMap map[string]bool // nodeName: true
+}
+
+// getNodeList collects the node names of every compute instance group into a
+// single slice. The order of the result follows map iteration and is therefore
+// unspecified.
+func getNodeList(cis []common.ComputeInstances) []string {
+	nodes := []string{}
+	for _, ci := range cis {
+		for _, node := range ci.Instances {
+			nodes = append(nodes, node)
+		}
+	}
+	return nodes
+}
+
+// domainIDExists reports whether id is already a key of domainMap.
+func domainIDExists(id string, domainMap map[string]domain) bool {
+	_, exists := domainMap[id]
+	return exists
+}
+
+// getClusterOutput runs cmd on all nodes through pdsh and records, for every
+// output line, the reporting node under its ClusterUUID in domainMap.
+//
+// Each non-blank line must contain at least three colon-separated fields, as
+// in "<node>: ClusterUUID : <uuid>"; the first field is the node name and the
+// third is the domain ID. A line with fewer fields yields an error rather
+// than an index-out-of-range panic.
+func getClusterOutput(ctx context.Context, domainMap map[string]domain, nodes []string, cmd string) error {
+	args := []string{"-R", "ssh", "-w", strings.Join(nodes, ","), cmd}
+	stdout, err := utils.Exec(ctx, "pdsh", args, nil)
+	if err != nil {
+		return fmt.Errorf("pdsh exec: %w", err)
+	}
+
+	scanner := bufio.NewScanner(stdout)
+	for scanner.Scan() {
+		nodeLine := scanner.Text()
+		if strings.TrimSpace(nodeLine) == "" {
+			continue
+		}
+		arr := strings.Split(nodeLine, ":")
+		if len(arr) < 3 {
+			return fmt.Errorf("unexpected pdsh output line %q", nodeLine)
+		}
+		nodeName := arr[0]
+		clusterUUID := strings.TrimSpace(arr[2])
+		if !domainIDExists(clusterUUID, domainMap) {
+			domainMap[clusterUUID] = domain{
+				nodeMap: make(map[string]bool),
+			}
+		}
+		domainMap[clusterUUID].nodeMap[nodeName] = true
+	}
+	if err := scanner.Err(); err != nil {
+		return fmt.Errorf("reading pdsh output: %w", err)
+	}
+	return nil
+}
+
+// toGraph converts domainMap into a two-level tree: the root holds one child
+// vertex per domain, and each domain vertex holds one leaf vertex per node.
+// Root metadata names the SLURM engine, the topology/block plugin and the
+// block size.
+//
+// Domains are visited in sorted ID order so the result is deterministic. The
+// block size is the node count of the first domain visited; a warning is
+// logged for each domain whose node count differs. For an empty domainMap the
+// block size stays at -1.
+func toGraph(domainMap map[string]domain) *common.Vertex {
+	root := &common.Vertex{
+		Vertices: make(map[string]*common.Vertex),
+		Metadata: make(map[string]string),
+	}
+
+	domainNames := make([]string, 0, len(domainMap))
+	for domainName := range domainMap {
+		domainNames = append(domainNames, domainName)
+	}
+	sort.Strings(domainNames)
+
+	blockSize := -1
+	for _, domainName := range domainNames {
+		nodeMap := domainMap[domainName].nodeMap
+		tree := &common.Vertex{
+			ID:       domainName,
+			Vertices: make(map[string]*common.Vertex),
+		}
+		for node := range nodeMap {
+			tree.Vertices[node] = &common.Vertex{Name: node, ID: node}
+		}
+		// Compare sizes once per domain; checking per node would flag
+		// every domain that has more than one node.
+		switch {
+		case blockSize == -1:
+			blockSize = len(nodeMap)
+		case blockSize != len(nodeMap):
+			klog.Warningf("blockSize different between NVL domains: domain %q has %d nodes, expected %d", domainName, len(nodeMap), blockSize)
+		}
+		root.Vertices[domainName] = tree
+	}
+	// add root metadata
+	root.Metadata["engine"] = common.EngineSLURM
+	root.Metadata["plugin"] = "topology/block"
+	root.Metadata["blocksize"] = strconv.Itoa(blockSize)
+	return root
+}
+
+// generateTopologyConfig builds the NVL-domain topology for the nodes in cis.
+// It queries every node for its ClusterUUID through pdsh, groups the nodes by
+// that ID and returns the resulting graph.
+func generateTopologyConfig(ctx context.Context, cis []common.ComputeInstances) (*common.Vertex, error) {
+	domainMap := make(map[string]domain) // domainID: domain
+	nodes := getNodeList(cis)
+	err := getClusterOutput(ctx, domainMap, nodes, "nvidia-smi -q | grep ClusterUUID")
+	if err != nil {
+		return nil, fmt.Errorf("getClusterOutput failed: %w", err)
+	}
+	return toGraph(domainMap), nil
+}
diff --git a/pkg/providers/baremetal/provider.go b/pkg/providers/baremetal/provider.go
new file mode 100644
index 0000000..9d248cd
--- /dev/null
+++ b/pkg/providers/baremetal/provider.go
@@ -0,0 +1,57 @@
+package baremetal
+
+import (
+	"context"
+	"fmt"
+
+	"k8s.io/klog/v2"
+
+	"github.com/NVIDIA/topograph/pkg/common"
+	"github.com/NVIDIA/topograph/pkg/engines/slurm"
+)
+
+// Provider is the topology provider for bare-metal clusters. It has no fields.
+type Provider struct{}
+
+// GetProvider returns a new bare-metal Provider; the error is always nil.
+func GetProvider() (*Provider, error) {
+	return &Provider{}, nil
+}
+
+// GetCredentials ignores its argument and returns nil credentials.
+func (p *Provider) GetCredentials(_ *common.Credentials) (interface{}, error) {
+	return nil, nil
+}
+
+// GetComputeInstances returns the cluster nodes as a single ComputeInstances
+// entry in which every node name maps to itself. Only the SLURM engine is
+// supported; any other engine yields an error.
+func (p *Provider) GetComputeInstances(ctx context.Context, engine common.Engine) ([]common.ComputeInstances, error) {
+	klog.InfoS("Getting compute instances", "provider", common.ProviderBM, "engine", engine)
+
+	switch engine.(type) {
+	case *slurm.SlurmEngine:
+		nodes, err := slurm.GetNodeList(ctx)
+		if err != nil {
+			return nil, err
+		}
+		i2n := make(map[string]string)
+		for _, node := range nodes {
+			i2n[node] = node
+		}
+		return []common.ComputeInstances{{Instances: i2n}}, nil
+	default:
+		return nil, fmt.Errorf("unsupported engine %q", engine)
+	}
+}
+
+// GenerateTopologyConfig builds the topology graph for instances. Requests
+// with more than one ComputeInstances entry are rejected.
+func (p *Provider) GenerateTopologyConfig(ctx context.Context, _ interface{}, _ int, instances []common.ComputeInstances) (*common.Vertex, error) {
+	if len(instances) > 1 {
+		return nil, fmt.Errorf("On-prem does not support multi-region topology requests")
+	}
+
+	// Delegate to the MNNVL (ClusterUUID-based) topology discovery.
+	return generateTopologyConfig(ctx, instances)
+}
diff --git a/pkg/server/http_server.go b/pkg/server/http_server.go
index c33cd99..7d7a975 100644
--- a/pkg/server/http_server.go
+++ b/pkg/server/http_server.go
@@ -170,7 +170,7 @@ func parseQuery(vals url.Values) (string, string, map[string]string, error) {
 
 func validate(tr *TopologyRequest) error {
 	switch tr.provider {
-	case common.ProviderAWS, common.ProviderOCI, common.ProviderGCP, common.ProviderCW, common.ProviderTest:
+	case common.ProviderAWS, common.ProviderOCI, common.ProviderGCP, common.ProviderCW, common.ProviderTest, common.ProviderBM:
 		//nop
 	default:
 		return fmt.Errorf("unsupported provider %s", tr.provider)