diff --git a/include/nccl_ofi_param.h b/include/nccl_ofi_param.h
index d2b13d215..106819b43 100644
--- a/include/nccl_ofi_param.h
+++ b/include/nccl_ofi_param.h
@@ -18,6 +18,54 @@ extern "C" {
 #include "nccl_ofi_log.h"
 #include "nccl_ofi_pthread.h"
 
+/*
+ * Define ofi_nccl_<name>(), an accessor for the unsigned integer parameter
+ * taken from the OFI_NCCL_<env> environment variable.
+ *
+ * The variable is parsed with strtoull() (base auto-detected) on first use,
+ * under a per-parameter mutex; later calls return the cached value. When the
+ * variable is unset, empty, malformed, out of range, or negative, the
+ * parameter keeps default_value. Negative input is rejected explicitly
+ * because strtoull() would otherwise wrap it to a large positive value.
+ */
+#define OFI_NCCL_PARAM_UINT(name, env, default_value) \
+	static pthread_mutex_t ofi_nccl_param_lock_##name = PTHREAD_MUTEX_INITIALIZER; \
+	static inline uint64_t ofi_nccl_##name(void) \
+	{ \
+		static bool initialized = false; \
+		static uint64_t value = (default_value); \
+		if (initialized) { \
+			return value; \
+		} \
+		nccl_net_ofi_mutex_lock(&ofi_nccl_param_lock_##name); \
+		if (!initialized) { \
+			const char *str = getenv("OFI_NCCL_" env); \
+			if (str && str[0] != '\0') { \
+				char *endptr = NULL; \
+				errno = 0; \
+				unsigned long long v = strtoull(str, &endptr, 0); \
+				/* strtoull() accepts a leading '-' and negates; treat it as invalid */ \
+				if (errno || str == endptr || *endptr != '\0' || strchr(str, '-')) { \
+					NCCL_OFI_INFO( \
+						NCCL_INIT | NCCL_NET, \
+						"Invalid value %s provided for %s environment variable, using default %llu", \
+						str, \
+						"OFI_NCCL_" env, \
+						(unsigned long long)value); \
+				} else { \
+					value = v; \
+					NCCL_OFI_INFO(NCCL_INIT | NCCL_NET, \
+						      "Setting %s environment variable to %llu", \
+						      "OFI_NCCL_" env, \
+						      (unsigned long long)value); \
+				} \
+			} \
+			initialized = true; \
+		} \
+		nccl_net_ofi_mutex_unlock(&ofi_nccl_param_lock_##name); \
+		return value; \
+	}
+
 #define OFI_NCCL_PARAM_INT(name, env, default_value) \
 static pthread_mutex_t ofi_nccl_param_lock_##name = PTHREAD_MUTEX_INITIALIZER; \
 static inline int64_t ofi_nccl_##name() { \
@@ -124,7 +172,7 @@ OFI_NCCL_PARAM_INT(cuda_flush_enable, "CUDA_FLUSH_ENABLE", 0);
  * Specify the memory registration key size in bytes when using a libfabric
  * provider that supports application-selected memory registration keys.
  */
-OFI_NCCL_PARAM_INT(mr_key_size, "MR_KEY_SIZE", 2);
+OFI_NCCL_PARAM_UINT(mr_key_size, "MR_KEY_SIZE", 2);
 
 /*
  * Disable the MR cache. The MR cache is used to keep track of registered
@@ -198,7 +246,7 @@ OFI_NCCL_PARAM_INT(disable_dmabuf, "DISABLE_DMABUF", 0);
 /*
  * Maximum size of a message in bytes before message is multiplexed
  */
-OFI_NCCL_PARAM_INT(round_robin_threshold, "ROUND_ROBIN_THRESHOLD", (256 * 1024));
+OFI_NCCL_PARAM_UINT(round_robin_threshold, "ROUND_ROBIN_THRESHOLD", (256 * 1024));
 
 /*
  * Minimum bounce buffers posted per endpoint. The plugin will attempt to post
@@ -229,11 +277,12 @@ OFI_NCCL_PARAM_INT(net_latency, "NET_LATENCY", -1);
  * tweak defaults from the platform file, but this fits our needs for
  * now.
  */
-OFI_NCCL_PARAM_INT(eager_max_size, "EAGER_MAX_SIZE",
+OFI_NCCL_PARAM_UINT(eager_max_size,
+		    "EAGER_MAX_SIZE",
 #if HAVE_NEURON
-		   0
+		    0
 #else
-		   8192
+		    8192
 #endif
 );