fix(tuner): accept context by argument
Add shims for v1 compat, add context arguments where needed. Also
refactor model code such that it can accept the context as an argument
instead of by global reference.

cr: https://code.amazon.com/reviews/CR-118885749
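As a concrete illustration of the calling convention described above, here is a minimal standalone sketch of the same pattern: a v2-style entry point that receives its context as an argument, plus a thin v1-compat shim that keeps a single static context for the old no-context signature. All names here (tuner_ctx, tuner_init_v2, and so on) are hypothetical and are not the plugin's real API; this is only a sketch of the pattern, not the commit's code.

/*
 * Sketch only: context passed by argument (v2 style) with a v1-compat shim
 * that stores the context in a file-scope static for the old signatures.
 */
#include <stdio.h>
#include <stdlib.h>

struct tuner_ctx {
	int num_ranks;
	int num_nodes;
};

/* v2 style: the caller owns the context and passes it explicitly. */
static int tuner_init_v2(int num_ranks, int num_nodes, void **context)
{
	struct tuner_ctx *ctx = calloc(1, sizeof(*ctx));
	if (ctx == NULL)
		return -1;
	ctx->num_ranks = num_ranks;
	ctx->num_nodes = num_nodes;
	*context = ctx;
	return 0;
}

static int tuner_query_v2(void *context)
{
	struct tuner_ctx *ctx = context;
	return ctx->num_ranks / ctx->num_nodes;
}

/* v1 compat shim: same behaviour, but the context lives in a static. */
static struct tuner_ctx *tuner_ctx_internal;

static int tuner_init_v1(int num_ranks, int num_nodes)
{
	void *context = NULL;
	int rc = tuner_init_v2(num_ranks, num_nodes, &context);
	if (rc == 0)
		tuner_ctx_internal = context;
	return rc;
}

static int tuner_query_v1(void)
{
	return tuner_query_v2(tuner_ctx_internal);
}

int main(void)
{
	void *ctx = NULL;
	tuner_init_v2(64, 8, &ctx);
	printf("v2: %d ranks per node\n", tuner_query_v2(ctx));
	free(ctx);

	tuner_init_v1(64, 8);
	printf("v1 shim: %d ranks per node\n", tuner_query_v1());
	free(tuner_ctx_internal);
	return 0;
}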
aws-nslick authored and rajachan committed Apr 3, 2024
1 parent 74a79e1 commit 0f66135
Showing 3 changed files with 98 additions and 61 deletions.
16 changes: 7 additions & 9 deletions include/nccl_ofi_tuner.h
@@ -83,24 +83,22 @@ struct nccl_ofi_tuner_model_params {
int num_rails;
};

struct nccl_ofi_tuner_context {
struct nccl_ofi_tuner_model_dims {
/* communicator size */
int num_ranks;
int num_nodes;
};

struct nccl_ofi_tuner_context {
struct nccl_ofi_tuner_model_dims dims;
struct nccl_ofi_tuner_model_params model_params;

float base_costs[NCCL_NUM_FUNCTIONS][NCCL_NUM_ALGORITHMS][NCCL_NUM_PROTOCOLS];
};

/*
* Global context, allocated at _init(). This is allocated and initialized once
* per process.
*/
extern struct nccl_ofi_tuner_context *nccl_ofi_tuner_ctx;

/* Modeling functions */
void nccl_ofi_tuner_model_costs();
float nccl_ofi_tuner_compute_cost(ncclFunc_t func, int algo, int proto, int pipe_ops, size_t size);
void nccl_ofi_tuner_model_costs(struct nccl_ofi_tuner_context *ctx);
float nccl_ofi_tuner_compute_cost(struct nccl_ofi_tuner_model_params *params, struct nccl_ofi_tuner_model_dims *dims,
ncclFunc_t func, int algo, int proto, int pipe_ops, size_t size);

#endif /* NCCL_OFI_TUNER_H_ */
18 changes: 9 additions & 9 deletions src/tuner/nccl_ofi_model.c
@@ -17,9 +17,9 @@ float nccl_ofi_tuner_compute_base_cost(ncclFunc_t func, int algo, int proto)
return nccl_base_lat[algo][proto];
}

float nccl_ofi_tuner_compute_cost(ncclFunc_t func, int algo, int proto, int pipe_ops, size_t size)
float nccl_ofi_tuner_compute_cost(struct nccl_ofi_tuner_model_params *params, struct nccl_ofi_tuner_model_dims *dims,
ncclFunc_t func, int algo, int proto, int pipe_ops, size_t size)
{
struct nccl_ofi_tuner_model_params *params = &nccl_ofi_tuner_ctx->model_params;
float cost = -1;
float latency = 0;
float bw = 0;
@@ -45,22 +45,22 @@ float nccl_ofi_tuner_compute_cost(ncclFunc_t func, int algo, int proto, int pipe
case ncclFuncAllReduce:
switch(algo) {
case NCCL_ALGO_RING:
num_steps = 2 * (nccl_ofi_tuner_ctx->num_ranks - 1);
num_internode_steps = 2 * nccl_ofi_tuner_ctx->num_nodes;
num_steps = 2 * (dims->num_ranks - 1);
num_internode_steps = 2 * dims->num_nodes;
latency = (num_internode_steps * net_lat)
+ (num_steps - num_internode_steps) * p2p_lat;
bw = params->internode_bw * params->num_rails * ofi_nccl_tuner_num_channels();
break;

case NCCL_ALGO_NVLS_TREE:
latency = 2 * (p2p_lat + (log2(nccl_ofi_tuner_ctx->num_nodes) * net_lat));
latency = 2 * (p2p_lat + (log2(dims->num_nodes) * net_lat));
bw = NCCL_OFI_MIN(params->intranode_bw, (params->internode_bw * params->num_rails) / 2)
* ofi_nccl_tuner_num_channels();
break;

case NCCL_ALGO_TREE:
latency = ((2 * ((nccl_ofi_tuner_ctx->num_ranks / nccl_ofi_tuner_ctx->num_nodes) - 1) * p2p_lat)
+ (2 * log2(nccl_ofi_tuner_ctx->num_nodes) * net_lat));
latency = ((2 * ((dims->num_ranks / dims->num_nodes) - 1) * p2p_lat)
+ (2 * log2(dims->num_nodes) * net_lat));
bw = (params->internode_bw * params->num_rails * ofi_nccl_tuner_num_channels()) / 2;
break;

@@ -99,14 +99,14 @@ float nccl_ofi_tuner_compute_cost(ncclFunc_t func, int algo, int proto, int pipe
* Compute the base costs for each of the algorithms at plugin initialization
* time using only the comm size.
*/
void nccl_ofi_tuner_model_costs()
void nccl_ofi_tuner_model_costs(struct nccl_ofi_tuner_context *ctx)
{
ncclFunc_t func;
int algo, proto = 0;
for (func = 0; func < NCCL_NUM_FUNCTIONS; func++) {
for (algo = 0; algo < NCCL_NUM_ALGORITHMS; algo++) {
for(proto = 0; proto < NCCL_NUM_PROTOCOLS; proto++) {
nccl_ofi_tuner_ctx->base_costs[func][algo][proto] =
ctx->base_costs[func][algo][proto] =
nccl_ofi_tuner_compute_base_cost(func, algo, proto);
}
}
125 changes: 82 additions & 43 deletions src/tuner/nccl_ofi_tuner.c
@@ -7,32 +7,15 @@
#include "nccl_ofi_tuner.h"
#include "nccl_ofi_log.h"

struct nccl_ofi_tuner_context *nccl_ofi_tuner_ctx;
pthread_mutex_t nccl_ofi_tuner_ctx_lock = PTHREAD_MUTEX_INITIALIZER;
ncclDebugLogger_t ofi_log_function = NULL;

ncclResult_t nccl_ofi_tuner_init(size_t nRanks, size_t nNodes, ncclDebugLogger_t logFunction)
ncclResult_t nccl_ofi_tuner_init(size_t nRanks, size_t nNodes, ncclDebugLogger_t logFunction, void **context)
{
ofi_log_function = logFunction;
struct nccl_ofi_tuner_context *nccl_ofi_tuner_ctx;

/*
* NCCL parses these variables and applies user filters inside its
* current tuner logic. Ideally, this should be done regardless of the
* use of NCCL's internal tuner or an external tuner plugin. For the
* time being, given the external tuner is an opt-in, detect if a user
* has set one of them and bail when an external tuner is loaded.
*/
if (getenv("NCCL_ALGO") || getenv("NCCL_PROTO")) {
NCCL_OFI_WARN("The tuner plugin can not be loaded when explicitly choosing an algorithm or protocol with NCCL_ALGO/NCCL_PROTO");
// FIXME: "ncclInvalidUsage should be returned when the error is
// most likely a user error" per nccl docs, which arguably makes
// it a better return code here than ncclInvalidArgument, but
// the former is currently not vended in ext-net headers, so
// we're returning ncclInvalidArgument instead.
return ncclInvalidArgument;
}

struct nccl_ofi_tuner_model_params params = {
const struct nccl_ofi_tuner_model_params params = {
.net_lat = ofi_nccl_tuner_net_latency(),
.internode_bw = NCCL_OFI_TUNER_INTERNODE_BW,
.intranode_bw = NCCL_OFI_TUNER_INTRANODE_BW,
@@ -44,38 +27,38 @@ ncclResult_t nccl_ofi_tuner_init(size_t nRanks, size_t nNodes, ncclDebugLogger_t
* initialization. For now, init a plugin-local context once.
*/
pthread_mutex_lock(&nccl_ofi_tuner_ctx_lock);
nccl_ofi_tuner_ctx = calloc(1, sizeof(struct nccl_ofi_tuner_context));
if (!nccl_ofi_tuner_ctx) {
nccl_ofi_tuner_ctx = calloc(1, sizeof(struct nccl_ofi_tuner_context));
if (!nccl_ofi_tuner_ctx) {
NCCL_OFI_WARN("Context allocation failed.");
return ncclInternalError;
}
NCCL_OFI_WARN("Context allocation failed.");
return ncclInternalError;
}

nccl_ofi_tuner_ctx->num_ranks = nRanks;
nccl_ofi_tuner_ctx->num_nodes = nNodes;
nccl_ofi_tuner_ctx->model_params = params;
nccl_ofi_tuner_ctx->dims.num_ranks = nRanks;
nccl_ofi_tuner_ctx->dims.num_nodes = nNodes;
nccl_ofi_tuner_ctx->model_params = params;

/*
* Build cost model to use from nccl_ofi_tuner_get_coll_info.
*/
nccl_ofi_tuner_model_costs();
}
/*
* Build cost model to use from nccl_ofi_tuner_get_coll_info.
*/
nccl_ofi_tuner_model_costs(nccl_ofi_tuner_ctx);
*context = (void*)nccl_ofi_tuner_ctx;
pthread_mutex_unlock(&nccl_ofi_tuner_ctx_lock);

NCCL_OFI_TRACE(NCCL_TUNING, "Tuner init: comm with %ld ranks and %ld nodes.", nRanks, nNodes);
return ncclSuccess;
}

ncclResult_t nccl_ofi_tuner_get_coll_info(ncclFunc_t collType, size_t nBytes,
ncclResult_t nccl_ofi_tuner_get_coll_info(void *context, ncclFunc_t collType, size_t nBytes,
int collNetSupport, int nvlsSupport, int numPipeOps,
int *algorithm, int *protocol, int* nChannels)
{
float cost = 0;
float lowest = FLT_MAX;
int algo, proto = 0;
struct nccl_ofi_tuner_context *nccl_ofi_tuner_ctx = (struct nccl_ofi_tuner_context *)context;

/* Skip runs smaller than 2 nodes and fallback to NCCL's internal tunings */
if (nccl_ofi_tuner_ctx->num_nodes <= 2)
if (nccl_ofi_tuner_ctx->dims.num_nodes <= 2)
return ncclSuccess;

/*
@@ -100,7 +83,8 @@ ncclResult_t nccl_ofi_tuner_get_coll_info(ncclFunc_t collType, size_t nBytes,
if (algo == NCCL_ALGO_NVLS_TREE && proto != NCCL_PROTO_SIMPLE)
continue;

cost = nccl_ofi_tuner_compute_cost(collType, algo, proto, numPipeOps, nBytes);
cost = nccl_ofi_tuner_compute_cost(&nccl_ofi_tuner_ctx->model_params, &nccl_ofi_tuner_ctx->dims,
collType, algo, proto, numPipeOps, nBytes);
if (cost < 0)
continue;

@@ -118,21 +102,76 @@ ncclResult_t nccl_ofi_tuner_get_coll_info(ncclFunc_t collType, size_t nBytes,
return ncclSuccess;
}

ncclResult_t nccl_ofi_tuner_destroy()
ncclResult_t nccl_ofi_tuner_destroy(void *context)
{
pthread_mutex_lock(&nccl_ofi_tuner_ctx_lock);
free(nccl_ofi_tuner_ctx);
/* Prevent other threads from freeing a dangling global ctx */
nccl_ofi_tuner_ctx = NULL;
if (context != NULL) {
free(context);
}
pthread_mutex_unlock(&nccl_ofi_tuner_ctx_lock);

return ncclSuccess;
}

const ncclTuner_v2_t ncclTunerPlugin_v2 = {
.name = "nccl_ofi_tuner",
.init = nccl_ofi_tuner_init,
.getCollInfo = nccl_ofi_tuner_get_coll_info,
.destroy = nccl_ofi_tuner_destroy
};

#if !defined(AWS_OFI_NCCL_MIN_TUNER_COMPAT) || (AWS_OFI_NCCL_MIN_TUNER_COMPAT <= 1)
static struct nccl_ofi_tuner_context *nccl_ofi_tuner_ctx_internal;

static ncclResult_t nccl_ofi_tuner_destroy_v1(void)
{
void *context = NULL;

pthread_mutex_lock(&nccl_ofi_tuner_ctx_lock);
if (nccl_ofi_tuner_ctx_internal != NULL) {
/* Prevent other threads from freeing a dangling global ctx */
context = (void*)nccl_ofi_tuner_ctx_internal;
nccl_ofi_tuner_ctx_internal = NULL;
}
pthread_mutex_unlock(&nccl_ofi_tuner_ctx_lock);

return nccl_ofi_tuner_destroy(context);
}

static ncclResult_t nccl_ofi_tuner_init_v1(size_t nRanks, size_t nNodes, ncclDebugLogger_t logFunction)
{
/*
* NCCL parses these variables and applies user filters inside its
* current tuner logic. Ideally, this should be done regardless of the
* use of NCCL's internal tuner or an external tuner plugin. For the
* time being, given the external tuner is an opt-in, detect if a user
* has set one of them and bail when an external tuner is loaded.
*/
if (getenv("NCCL_ALGO") || getenv("NCCL_PROTO")) {
NCCL_OFI_WARN("The tuner plugin can not be loaded when explicitly choosing an algorithm or protocol with NCCL_ALGO/NCCL_PROTO");
// FIXME: "ncclInvalidUsage should be returned when the error is
// most likely a user error" per nccl docs, which arguably makes
// it a better return code here than ncclInvalidArgument, but
// the former is currently not vended in ext-net headers, so
// we're returning ncclInvalidArgument instead.
return ncclInvalidArgument;
}
return nccl_ofi_tuner_init(nRanks, nNodes, logFunction, (void**)&nccl_ofi_tuner_ctx_internal);
}

static ncclResult_t nccl_ofi_tuner_get_coll_info_v1(ncclFunc_t collType, size_t nBytes, int collNetSupport,
int nvlsSupport, int numPipeOps, int *algorithm, int *protocol,
int *nChannels)
{
return nccl_ofi_tuner_get_coll_info(nccl_ofi_tuner_ctx_internal, collType, nBytes,
collNetSupport, nvlsSupport, numPipeOps, algorithm,
protocol, nChannels);
}

const ncclTuner_v1_t ncclTunerPlugin_v1 = {
.name = "nccl_ofi_tuner",
.init = nccl_ofi_tuner_init,
.getCollInfo = nccl_ofi_tuner_get_coll_info,
.destroy = nccl_ofi_tuner_destroy
.init = nccl_ofi_tuner_init_v1,
.getCollInfo = nccl_ofi_tuner_get_coll_info_v1,
.destroy = nccl_ofi_tuner_destroy_v1
};
#endif /* !defined(AWS_OFI_NCCL_MIN_TUNER_COMPAT) || (AWS_OFI_NCCL_MIN_TUNER_COMPAT <= 1) */
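The v1 shims above exist so that the plugin keeps exporting both ncclTunerPlugin_v1 and ncclTunerPlugin_v2. As a hedged illustration of why that matters, the sketch below shows how a loader could prefer the v2 symbol and fall back to v1. The dlsym-based lookup and the library name libnccl-ofi-tuner.so are assumptions made for illustration, not details taken from this commit or from NCCL itself.

/*
 * Illustrative only: resolve the newer tuner vtable first, then fall back to
 * the v1 vtable. Library name and symbol lookup are assumptions.
 */
#include <dlfcn.h>
#include <stdio.h>

int main(void)
{
	void *handle = dlopen("libnccl-ofi-tuner.so", RTLD_NOW | RTLD_LOCAL);
	if (handle == NULL) {
		fprintf(stderr, "dlopen failed: %s\n", dlerror());
		return 1;
	}

	/* Prefer the v2 vtable, whose init() hands the caller a context pointer. */
	void *tuner = dlsym(handle, "ncclTunerPlugin_v2");
	if (tuner == NULL) {
		/* Fall back to the v1 vtable, which relies on the static-context shims. */
		tuner = dlsym(handle, "ncclTunerPlugin_v1");
	}

	if (tuner != NULL)
		printf("resolved tuner vtable at %p\n", tuner);
	else
		fprintf(stderr, "no tuner entry point exported\n");

	dlclose(handle);
	return 0;
}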
