Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

aws: Improve platform lookup handling #757

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ tests/unit/show_tuner_costs
tests/unit/ep_addr_list
tests/unit/mr
tests/unit/region_based_tuner
tests/unit/show_tuner_decisions
tests/unit/aws_platform_mapper

# http://www.gnu.org/software/automake
.deps/
Expand Down
3 changes: 2 additions & 1 deletion include/Makefile.am
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,5 @@ noinst_HEADERS = \
tracing_impl/nvtx.h \
internal/tuner/nccl_defaults.h \
nccl_ofi_platform.h \
nccl_ofi_ep_addr_list.h
nccl_ofi_ep_addr_list.h \
platform-aws.h
59 changes: 59 additions & 0 deletions include/platform-aws.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
/*
* Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All rights reserved.
*
* Access helper functions from platform-aws specifically for unit
* tests. You do not want to include this file outside of
* platform-aws.c or a unit test, or you'll break linking on non-AWS
* platforms.
*/

#ifndef PLATFORM_AWS_H_
#define PLATFORM_AWS_H_

#include <stdbool.h>

#ifdef __cplusplus
extern "C" {
#endif


struct ec2_platform_data {
const char* name;
const char* regex;
const char* topology;
int default_dup_conns;
float latency;
bool gdr_required;
bool net_flush_required;
const char *default_protocol;
int domain_per_thread;
};


/*
* @brief Get the platform data map
*
* This function exists solely to test
* platform_aws_get_platform_entry() against the production data map.
*/
struct ec2_platform_data *platform_aws_get_platform_map(size_t *len);


/*
* @brief Returns platform data for current platform type, if found
*
* @input Platform type
*
* @return NULL, if no topology found
* platform data, if match found
*/
struct ec2_platform_data *platform_aws_get_platform_entry(const char *platform_type,
struct ec2_platform_data *platform_data_list,
size_t platform_data_len);


#ifdef __cplusplus
} // End extern "C"
#endif

#endif // End NCCL_OFI_H_
131 changes: 88 additions & 43 deletions src/platform-aws.c
Original file line number Diff line number Diff line change
Expand Up @@ -20,23 +20,14 @@

#include "nccl_ofi.h"
#include "nccl_ofi_platform.h"
#include "platform-aws.h"
#include "nccl_ofi_log.h"
#include "nccl_ofi_math.h"
#include "nccl_ofi_rdma.h"
#include "nccl_ofi_param.h"
#include "nccl_ofi_pthread.h"
#include "nccl_ofi_system.h"

struct ec2_platform_data {
const char* name;
const char* topology;
int default_dup_conns;
float latency;
bool gdr_required;
bool net_flush_required;
const char *default_protocol;
int domain_per_thread;
};

/*
* platform_data_map is an ordered list of platform entries. The
Expand All @@ -46,7 +37,8 @@ struct ec2_platform_data {
*/
static struct ec2_platform_data platform_data_map[] = {
{
.name = "^p4d.24xlarge$",
.name = "p4d.24xlarge",
.regex = NULL,
.topology = "p4d-24xl-topo.xml",
.default_dup_conns = 0,
.latency = 75.0,
Expand All @@ -56,7 +48,8 @@ static struct ec2_platform_data platform_data_map[] = {
.domain_per_thread = 0,
},
{
.name = "^p4de.24xlarge$",
.name = "p4de.24xlarge",
.regex = NULL,
.topology = "p4de-24xl-topo.xml",
.default_dup_conns = 0,
.latency = 75.0,
Expand All @@ -66,7 +59,8 @@ static struct ec2_platform_data platform_data_map[] = {
.domain_per_thread = 0,
},
{
.name = "^p3dn.24xlarge$",
.name = "p3dn.24xlarge",
.regex = NULL,
.topology = NULL,
.default_dup_conns = 4,
.latency = 150.0,
Expand All @@ -76,7 +70,13 @@ static struct ec2_platform_data platform_data_map[] = {
.domain_per_thread = 0,
},
{
.name = "^p5.*",
.name = "p-series",
/*
* we only want to match P5 and later, as earlier
* platforms all either need to be ignored or special
* cased.
*/
.regex = "^p([5-9]|[0-9]{2,}).*",
aws-nslick marked this conversation as resolved.
Show resolved Hide resolved
.topology = NULL,
.default_dup_conns = 0,
.latency = 75.0,
Expand All @@ -86,7 +86,8 @@ static struct ec2_platform_data platform_data_map[] = {
.domain_per_thread = 0,
},
{
.name = "^g5.48xlarge$",
.name = "g5.48xlarge",
.regex = NULL,
.topology = "g5.48xl-topo.xml",
.default_dup_conns = 0,
.latency = 75.0,
Expand All @@ -96,7 +97,8 @@ static struct ec2_platform_data platform_data_map[] = {
.domain_per_thread = 0,
},
{
.name = "^trn1.*",
.name = "trn1",
.regex = "^trn1.*",
.topology = NULL,
.default_dup_conns = 0,
.latency = 75.0,
Expand All @@ -106,7 +108,8 @@ static struct ec2_platform_data platform_data_map[] = {
.domain_per_thread = 1,
},
{
.name = "^trn2.*",
.name = "trn2",
.regex = "^trn2.*",
.topology = NULL,
.default_dup_conns = 0,
.latency = 75.0,
Expand All @@ -118,23 +121,80 @@ static struct ec2_platform_data platform_data_map[] = {
};


struct ec2_platform_data *platform_aws_get_platform_map(size_t *len)
{
*len = sizeof(platform_data_map)/sizeof(platform_data_map[0]);
return platform_data_map;
}


/*
* internal function (exported for unit test purposes) for finding the
* correct platform data entry. You should use
* platform_Aws_get_platform_data() so that you get caching and all
* that niceness.
*/
struct ec2_platform_data *platform_aws_get_platform_entry(const char *platform_type,
struct ec2_platform_data *platform_data_list,
size_t platform_data_len)
{
struct ec2_platform_data *response = NULL;
regex_t regex;
int ret;

for (size_t idx = 0; idx < platform_data_len; idx++) {
if (platform_data_list[idx].regex == NULL) {
if (0 == strcmp(platform_type,
platform_data_list[idx].name)) {
response = &platform_data_list[idx];
break;
}
} else {
ret = regcomp(&regex, platform_data_list[idx].regex, REG_EXTENDED);
if (ret != 0) {
NCCL_OFI_WARN("Could not compile platform_type regex for %s",
platform_data_list[idx].regex);
goto done;
}

ret = regexec(&regex, platform_type, 0, NULL, 0);

regfree(&regex);

if (ret == 0) {
response = &platform_data_list[idx];
break;
} else if (ret != REG_NOMATCH) {
NCCL_OFI_WARN("Regex match failed");
goto done;
}
}
}

done:
NCCL_OFI_TRACE(NCCL_NET | NCCL_INIT, "Using platform block %s for instance type %s",
(response == NULL) ? "none" : response->name, platform_type);

return response;
}


/*
* @brief Returns platform data for current platform type, if found
*
* @input Platform type
* @input none
*
* @return NULL, if no topology found
* @return NULL, if no entry found
* platform data, if match found
*/
static struct ec2_platform_data *get_platform_data(void)
{
static bool init = false;
static pthread_mutex_t mutex = PTHREAD_MUTEX_INITIALIZER;
static struct ec2_platform_data *platform_data = NULL;
const size_t platform_n = sizeof(platform_data_map)/sizeof(platform_data_map[0]);
const char* platform_type = NULL;
regex_t regex;
int ret;
struct ec2_platform_data *platform_data_list;
size_t platform_data_len;

nccl_net_ofi_mutex_lock(&mutex);

Expand All @@ -148,29 +208,14 @@ static struct ec2_platform_data *get_platform_data(void)
goto done;
}

for (size_t idx = 0; idx < platform_n; idx++) {
ret = regcomp(&regex, platform_data_map[idx].name, 0);
if (ret != 0) {
NCCL_OFI_WARN("Could not compile platform_type regex for %s",
platform_data_map[idx].name);
goto done;
}

ret = regexec(&regex, platform_type, 0, NULL, 0);

regfree(&regex);

if (ret == 0) {
platform_data = &platform_data_map[idx];
break;
} else if (ret != REG_NOMATCH) {
NCCL_OFI_WARN("Regex match failed");
goto done;
}
platform_data_list = platform_aws_get_platform_map(&platform_data_len);
if (platform_data_list == NULL) {
goto done;
}

NCCL_OFI_TRACE(NCCL_NET | NCCL_INIT, "Using platform block %s for instance type %s",
(platform_data == NULL) ? "none" : platform_data->name, platform_type);
platform_data = platform_aws_get_platform_entry(platform_type, platform_data_list,
platform_data_len);

done:
nccl_net_ofi_mutex_unlock(&mutex);

Expand Down
5 changes: 5 additions & 0 deletions tests/unit/Makefile.am
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ noinst_PROGRAMS = \
ep_addr_list \
mr

if WANT_PLATFORM_AWS
noinst_PROGRAMS += aws_platform_mapper
endif

if !ENABLE_NEURON
if WANT_PLATFORM_AWS
AM_LDFLAGS = $(CUDA_LDFLAGS)
Expand All @@ -38,6 +42,7 @@ msgbuff_SOURCES = msgbuff.c
scheduler_SOURCES = scheduler.c
ep_addr_list_SOURCES = ep_addr_list.c
mr_SOURCES = mr.c
aws_platform_mapper_SOURCES = aws_platform_mapper.c

TESTS = $(noinst_PROGRAMS)
endif
Loading
Loading