Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Discover MNNVL topology with single blocksize #9

Merged
merged 6 commits into from
Oct 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pkg/common/const.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ const (
ProviderOCI = "oci"
ProviderGCP = "gcp"
ProviderCW = "cw"
ProviderBM = "baremetal"
ProviderTest = "test"

EngineSLURM = "slurm"
Expand Down
1 change: 1 addition & 0 deletions pkg/common/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ type Vertex struct {
Name string
ID string
Vertices map[string]*Vertex
Metadata map[string]string
}

func (v *Vertex) String() string {
Expand Down
72 changes: 14 additions & 58 deletions pkg/protos/topology.pb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pkg/protos/topology_grpc.pb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

101 changes: 101 additions & 0 deletions pkg/providers/baremetal/mnnvl.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
package baremetal

import (
"bufio"
"context"
"fmt"
"github.com/NVIDIA/topograph/pkg/common"
"github.com/NVIDIA/topograph/pkg/utils"
"strconv"
"strings"
)

// domain contains map of each domainID(clusterUUID) -> list of nodeNames in that domain
// Each domain will be a separate NVL Domain
type domain struct {
nodeMap map[string]bool // nodeName: true
}

// getNodeList retrieves all the nodenames on the cluster
func getNodeList(cis []common.ComputeInstances) []string {
nodes := []string{}
for _, ci := range cis {
for _, node := range ci.Instances {
nodes = append(nodes, node)
}
}
return nodes
}

// Check if domainID exists in the map
func domainIDExists(id string, domainMap map[string]domain) bool {
if _, exists := domainMap[id]; exists {
return true
}
return false
}

// getClusterOutput reads output from nodeInfo and populates the structs
func getClusterOutput(ctx context.Context, domainMap map[string]domain, nodes []string, cmd string) error {
args := []string{"-R", "ssh", "-w", strings.Join(nodes, ","), cmd}
stdout, err := utils.Exec(ctx, "pdsh", args, nil)
if err != nil {
return fmt.Errorf("Exec error while pdsh\n")
}

scanner := bufio.NewScanner(stdout)
for scanner.Scan() {
nodeLine := scanner.Text()
arr := strings.Split(nodeLine, ":")
nodeName := arr[0]
clusterUUID := strings.TrimSpace(arr[2])
if !domainIDExists(clusterUUID, domainMap) {
domainMap[clusterUUID] = domain{
nodeMap: make(map[string]bool),
}
}
nodeMap := domainMap[clusterUUID].nodeMap
nodeMap[nodeName] = true
}
if err := scanner.Err(); err != nil {
return fmt.Errorf("Scanner error while reading pdsh output\n")
}
return nil
}
func toGraph(domainMap map[string]domain) *common.Vertex {
root := &common.Vertex{
Vertices: make(map[string]*common.Vertex),
Metadata: make(map[string]string),
}
blockSize := -1
for domainName, domain := range domainMap {
tree := &common.Vertex{
ID: domainName,
Vertices: make(map[string]*common.Vertex),
}
for node := range domain.nodeMap {
tree.Vertices[node] = &common.Vertex{Name: node, ID: node}
if blockSize == -1 {
blockSize = len(domain.nodeMap)
} else {
fmt.Printf("blockSize different between NVL domains")
}
}
root.Vertices[domainName] = tree
}
// add root metadata
root.Metadata["engine"] = "slurm"
root.Metadata["plugin"] = "topology/block"
root.Metadata["blocksize"] = strconv.Itoa(blockSize)
return root
}

func generateTopologyConfig(ctx context.Context, cis []common.ComputeInstances) (*common.Vertex, error) {
domainMap := make(map[string]domain) // domainID: domain
nodes := getNodeList(cis)
err := getClusterOutput(ctx, domainMap, nodes, "nvidia-smi -q | grep ClusterUUID")
if err != nil {
return nil, fmt.Errorf("getClusterOutput failed: %v\n", err)
}
return toGraph(domainMap), nil
}
49 changes: 49 additions & 0 deletions pkg/providers/baremetal/provider.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package baremetal

import (
"context"
"fmt"

"k8s.io/klog/v2"

"github.com/NVIDIA/topograph/pkg/common"
"github.com/NVIDIA/topograph/pkg/engines/slurm"
)

type Provider struct{}

func GetProvider() (*Provider, error) {
return &Provider{}, nil
}

func (p *Provider) GetCredentials(_ *common.Credentials) (interface{}, error) {
return nil, nil
}

func (p *Provider) GetComputeInstances(ctx context.Context, engine common.Engine) ([]common.ComputeInstances, error) {
klog.InfoS("Getting compute instances", "provider", common.ProviderBM, "engine", engine)

switch engine.(type) {
case *slurm.SlurmEngine:
nodes, err := slurm.GetNodeList(ctx)
if err != nil {
return nil, err
}
i2n := make(map[string]string)
for _, node := range nodes {
i2n[node] = node
}
return []common.ComputeInstances{{Instances: i2n}}, nil
default:
return nil, fmt.Errorf("unsupported engine %q", engine)
}
}

func (p *Provider) GenerateTopologyConfig(ctx context.Context, _ interface{}, _ int, instances []common.ComputeInstances) (*common.Vertex, error) {
if len(instances) > 1 {
return nil, fmt.Errorf("On-prem does not support multi-region topology requests")
}

//call mnnvl code from here
return generateTopologyConfig(ctx, instances)
}
2 changes: 1 addition & 1 deletion pkg/server/http_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ func parseQuery(vals url.Values) (string, string, map[string]string, error) {

func validate(tr *TopologyRequest) error {
switch tr.provider {
case common.ProviderAWS, common.ProviderOCI, common.ProviderGCP, common.ProviderCW, common.ProviderTest:
case common.ProviderAWS, common.ProviderOCI, common.ProviderGCP, common.ProviderCW, common.ProviderTest, common.ProviderBM:
//nop
default:
return fmt.Errorf("unsupported provider %s", tr.provider)
Expand Down
Loading