diff --git a/src/xccl/ProcessGroupXCCL.cpp b/src/xccl/ProcessGroupXCCL.cpp
index 54db563c1..58107ff76 100644
--- a/src/xccl/ProcessGroupXCCL.cpp
+++ b/src/xccl/ProcessGroupXCCL.cpp
@@ -31,22 +31,69 @@ const std::map<at::ScalarType, ccl::datatype> xcclDatatypes = {
     {at::kFloat8_e5m2fnuz, ccl::datatype::uint8},
 };
 
-void checkXPUTensor(at::Tensor& tensor) {
+bool checkSameSize(const std::vector<at::Tensor>& input_tensors) {
+  for (const auto& input_tensor : input_tensors) {
+    if (!input_tensors[0].is_same_size(input_tensor)) {
+      return false;
+    }
+  }
+  return true;
+}
+
+void checkSingleTensor(
+    const at::Tensor& tensor,
+    const bool p2p = false // whether operation is a P2P operation
+) {
   if (!tensor.is_xpu() || tensor.is_sparse() || tensor.is_complex()) {
     C10_THROW_ERROR(
         ValueError, "Tensors must be XPU and dense and non-complex");
+
+    // Skip the following requirements for P2P operations
     if (!tensor.is_contiguous(tensor.suggest_memory_format())) {
-      C10_THROW_ERROR(ValueError, "Tensors must be contiguous");
+      if (p2p) {
+        TORCH_WARN_ONCE(
+            "Detected non-contiguous tensor in P2P operations. It is the user's "
+            "responsibility to guarantee that source and destination tensors have "
+            "the same contiguity format.");
+      } else {
+        C10_THROW_ERROR(ValueError, "Tensors must be contiguous");
+      }
     }
   }
 }
 
+int64_t checkTensorOnSameDevice(const std::vector<at::Tensor>& tensors) {
+  TORCH_CHECK_WITH(
+      ValueError, tensors.size() != 0, "Tensor list must be nonempty");
+
+  const auto& first = tensors.front();
+
+  int64_t total_numel = 0;
+  for (const auto& t : tensors) {
+    if (!t.is_xpu() || t.is_sparse() || t.is_complex()) {
+      C10_THROW_ERROR(
+          ValueError, "Tensors must be XPU and dense and non-complex");
+    }
+    if (t.scalar_type() != first.scalar_type()) {
+      C10_THROW_ERROR(TypeError, "Tensors must have identical type");
+    }
+    TORCH_CHECK_WITH(
+        ValueError,
+        t.get_device() == tensors[0].get_device(),
+        "Expected list of tensors on the same device");
+    total_numel += t.numel();
+  }
+
+  return total_numel;
+}
+
 ccl::datatype getXcclDataType(
     at::ScalarType type,
     bool is_reduction_op = false) {
-  TORCH_CHECK(
-      !isFloat8Type(type) && is_reduction_op,
-      "Float8 dtypes are not currenlty supported for XCCL reductions");
+  if (is_reduction_op)
+    TORCH_CHECK(
+        !isFloat8Type(type),
+        "Float8 dtypes are not currently supported for XCCL reductions");
   auto it = xcclDatatypes.find(type);
   TORCH_CHECK_WITH(
       TypeError,
@@ -62,6 +109,11 @@ ccl::reduction getXcclReduceOp(const ReduceOp& reduceOp, at::Tensor& input) {
     // Map sum to max for bool tensors to avoid overflow issues with sum.
     return ccl::reduction::max;
   }
+  // Use SUM to emulate AVG, since oneCCL does not support AVG.
+  // oneCCL is expected to support AVG in the Basekit 2025.2 release.
+ if (reduceOp == ReduceOp::AVG) { + return ccl::reduction::sum; + } return xcclOps.at(reduceOp); } catch (const std::out_of_range&) { C10_THROW_ERROR( @@ -77,9 +129,11 @@ void syncStream( xcclEvent.record(at::xpu::getCurrentXPUStream(device.index())); xcclEvent.block(xcclStream); } + } // namespace constexpr int64_t kSynchronizeBusyWaitMillis = 10; +thread_local uint64_t ProcessGroupXCCL::xcclActiveGroupCounter_ = 0; ProcessGroupXCCL::WorkXCCL::WorkXCCL( at::Device& device, @@ -134,6 +188,10 @@ void ProcessGroupXCCL::WorkXCCL::synchronizeInternal( std::chrono::milliseconds(kSynchronizeBusyWaitMillis)); } } + if (barrierTensor_.defined()) { + auto currentStream = at::xpu::getCurrentXPUStream(device_.index()); + currentStream.synchronize(); + } } bool ProcessGroupXCCL::WorkXCCL::wait(std::chrono::milliseconds timeout) { @@ -141,6 +199,9 @@ bool ProcessGroupXCCL::WorkXCCL::wait(std::chrono::milliseconds timeout) { return true; } +constexpr const char* MULTI_DEVICE_ERROR_MSG = + "Expecting one tensor only but got multiple"; + ProcessGroupXCCL::ProcessGroupXCCL( const c10::intrusive_ptr& store, int rank, @@ -152,6 +213,12 @@ ProcessGroupXCCL::ProcessGroupXCCL( ProcessGroupXCCL::~ProcessGroupXCCL() = default; +void ProcessGroupXCCL::setSequenceNumberForGroup() {} + +uint64_t ProcessGroupXCCL::getSequenceNumberForGroup() { + return seqCollective_; +} + c10::intrusive_ptr ProcessGroupXCCL::initWork( at::Device& device, int rank, @@ -171,12 +238,19 @@ c10::intrusive_ptr ProcessGroupXCCL::initWork( std::shared_ptr ProcessGroupXCCL::getXCCLComm( const std::string& deviceKey, - at::Device& device) { - TORCH_CHECK_WITH( - DistBackendError, - !deviceKey.empty(), - "Not able to create/get " - "XCCL Communicator since the devices are empty "); + at::Device& device, + OpType opType, + int p2pRank, + bool isSendRecvSelf) { + if (deviceKey.empty()) { + C10_THROW_ERROR( + DistBackendError, + "Not able to create/get the XCCL Communicator since " + "the devices are empty "); + } + + usedDeviceIdxs_.insert(device.index()); + { std::lock_guard lock(mutex_); if (devXCCLCommMap_.find(deviceKey) != devXCCLCommMap_.end()) { @@ -184,9 +258,24 @@ std::shared_ptr ProcessGroupXCCL::getXCCLComm( } } + std::shared_ptr XCCLComm; + + bool batchP2P = xcclActiveGroupCounter_ > 0; + bool singleP2POp = isP2POp(opType, batchP2P); + + at::xpu::OptionalXPUGuard gpuGuard(device); + int numRanks, rank; - numRanks = getSize(); - rank = getRank(); + if (!singleP2POp) { + numRanks = getSize(); + rank = getRank(); + } else if (isSendRecvSelf) { + numRanks = 1; + rank = 0; + } else { + numRanks = 2; + rank = p2pRank; + } c10::impl::VirtualGuardImpl impl(device.type()); c10::Stream stream = @@ -197,10 +286,23 @@ std::shared_ptr ProcessGroupXCCL::getXCCLComm( ccl::vector_class> devs_rank; devs_rank.emplace_back(rank, ccl::create_device(q.get_device())); - auto xccl_kvs = get_kvs(rank_, *store_); + auto xccl_kvs = get_kvs(rank_, *store_, singleP2POp, deviceKey, p2pRank); auto comms = ccl::create_communicators(numRanks, devs_rank, ctx, xccl_kvs); - std::shared_ptr XCCLComm = - std::make_shared(std::move(comms[0])); + XCCLComm = std::make_shared(std::move(comms[0])); + + RECORD_PARAM_COMMS( + 0, // seq + std::make_tuple(pg_uid_, pg_desc_), // PG name tuple + rank, // rank + "init", // collective name + 0, // inNelems + 0, // outNelems + at::kByte, // dType + std::vector(), // inSplitSizes + std::vector(), // outSplitSizes + -1, // globalRankStart + -1, // globalRankStride + size_); // worldSize std::lock_guard lock(mutex_); 
devXCCLCommMap_.emplace(deviceKey, XCCLComm); @@ -210,6 +312,63 @@ std::shared_ptr ProcessGroupXCCL::getXCCLComm( return XCCLComm; } +void ProcessGroupXCCL::groupStart() { + ccl::group_start(); + ++xcclActiveGroupCounter_; +} + +void ProcessGroupXCCL::groupEnd() { + ccl::group_end(); + --xcclActiveGroupCounter_; +} + +static constexpr int CoalActive = 0x01, CoalColl = 0x02, CoalP2P = 0x04; +void ProcessGroupXCCL::startCoalescing() { + if (coalescing_state_ & CoalP2P) { + seqP2P_++; + } else { + seqCollective_++; + } + coalescedDevice_.set_index(-1); + coalescedComm_ = nullptr; + coalescing_state_ |= CoalActive; + groupStart(); +} + +c10::intrusive_ptr ProcessGroupXCCL::endCoalescing(OpType optype) { + if (coalescedComm_ == nullptr) { + // There is no actual work being coalesced, return here + groupEnd(); + coalescing_state_ = 0; + return nullptr; + } + TORCH_CHECK( + coalescedDevice_.index() >= 0, + "Somthing went wrong. Did you call end_coalescing before start_coalescing?"); + + auto comm = coalescedComm_; + auto device = coalescedDevice_; + + const auto key = std::to_string(device.index()); + auto stream = xcclStreamsMap_.at(key); + + auto work = initWork(device, rank_, optype); + work->blockingWait_ = blockingWait_; + + groupEnd(); + + work->xcclEndEvent_->record(stream); + + coalescing_state_ = 0; + coalescedComm_ = nullptr; + return work; +} + +c10::intrusive_ptr ProcessGroupXCCL::endCoalescing() { + // Default OpType to COALESCED if not specified + return endCoalescing(OpType::COALESCED); +} + template c10::intrusive_ptr ProcessGroupXCCL::collective( std::vector& inputs, @@ -220,28 +379,49 @@ c10::intrusive_ptr ProcessGroupXCCL::collective( OpType opType, const char* profilingTitle) { seqCollective_++; - auto device = inputs[0].device(); const auto key = std::to_string(device.index()); - auto comm = getXCCLComm(key, device); + auto comm = getXCCLComm(key, device, opType); + + if (coalescing_state_ & CoalActive) { + coalescing_state_ |= CoalColl; + if (coalescedDevice_.index() < 0) { + coalescedDevice_ = device; + } else { + TORCH_CHECK( + coalescedDevice_.index() == device.index(), MULTI_DEVICE_ERROR_MSG); + } + if (coalescedComm_ == nullptr) { + coalescedComm_ = comm; + } else { + TORCH_CHECK(coalescedComm_ == comm, MULTI_DEVICE_ERROR_MSG); + } + } auto stream = xcclStreamsMap_.at(key); syncStream(device, xcclEventsMap_[key], stream); c10::intrusive_ptr work; - work = initWork(device, rank_, opType, profilingTitle); + work = initWork(device, rank_, opType); + work->outputs_ = std::make_shared>(outputs); at::xpu::OptionalXPUGuard gpuGuard(device); + pre(stream, work); + for (const auto i : c10::irange(inputs.size())) { c10::xpu::XPUCachingAllocator::recordStream( inputs[i].storage().data_ptr(), stream); fn(inputs[i], outputs[i], *comm, stream); } + post(stream, work); - work->xcclEndEvent_->record(stream); + if (!coalescing_state_) { + work->xcclEndEvent_->record(stream); + } + std::vector streams = {stream.unwrap()}; c10::MultiStreamGuard streamGuard(streams); std::vector devices{device}; @@ -253,13 +433,63 @@ c10::intrusive_ptr ProcessGroupXCCL::collective( return work; } +c10::intrusive_ptr ProcessGroupXCCL::allreduce_impl( + at::Tensor& tensor, + const AllreduceOptions& opts) { + return collective( + tensor, + tensor, + [&](at::Tensor& input, + at::Tensor& output, + xcclComm_t& comm, + at::xpu::XPUStream& stream) { + auto xcclDataType = getXcclDataType(input.scalar_type(), true); + auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input); + auto ccl_stream = 
ccl::create_stream(stream.queue()); + ccl::allreduce( + input.data_ptr(), + output.data_ptr(), + (size_t)input.numel(), + xcclDataType, + xcclReduceOp, + comm, + ccl::create_stream(stream.queue())); + // Use SUM emu AVG due to oneCCL not support AVG + // oneCCL is expected to support avg in basekit 2025.2 release. + if (opts.reduceOp == ReduceOp::AVG) { + auto divisor = getSize(); + output.div_(divisor); + } + return; + }, + OpType::ALLREDUCE, + "xccl:all_reduce"); +} + c10::intrusive_ptr ProcessGroupXCCL::allreduce( std::vector& tensors, const AllreduceOptions& opts) { - TORCH_CHECK( - tensors.size() == 1, "Expecting one tensor only but got multiple"); + TORCH_CHECK(tensors.size() == 1, MULTI_DEVICE_ERROR_MSG); auto tensor = tensors.back(); - checkXPUTensor(tensor); + checkSingleTensor(tensor); + + // @lint-ignore CLANGTIDY + RECORD_PARAM_COMMS_DATA( + static_cast( + this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective + std::make_tuple(pg_uid_, pg_desc_), // PG name tuple + tensors, // inputTensors + tensors, // outputTensors + rank_, // rank + "allreduce", // collective name + tensor.numel(), // inNelems + tensor.numel(), // outNelems + tensor.scalar_type(), // dType + std::vector(), // inSplitSizes + std::vector(), // outSplitSizes + -1, // globalRankStart + -1, // globalRankStride + size_); // worldSize return collective( tensor, @@ -270,7 +500,6 @@ c10::intrusive_ptr ProcessGroupXCCL::allreduce( at::xpu::XPUStream& stream) { auto xcclDataType = getXcclDataType(input.scalar_type(), true); auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input); - auto ccl_stream = ccl::create_stream(stream.queue()); ccl::allreduce( input.data_ptr(), output.data_ptr(), @@ -278,13 +507,587 @@ c10::intrusive_ptr ProcessGroupXCCL::allreduce( xcclDataType, xcclReduceOp, comm, - ccl_stream); + ccl::create_stream(stream.queue())); + // Use SUM emu AVG due to oneCCL not support AVG + // oneCCL is expected to support avg in basekit 2025.2 release. + if (opts.reduceOp == ReduceOp::AVG) { + auto divisor = getSize(); + output.div_(divisor); + } return; }, OpType::ALLREDUCE, "xccl:all_reduce"); } +c10::intrusive_ptr ProcessGroupXCCL::allreduce_coalesced( + std::vector& tensors, + const AllreduceCoalescedOptions& opts) { + auto total_numel = checkTensorOnSameDevice(tensors); + + // @lint-ignore CLANGTIDY + RECORD_PARAM_COMMS_DATA( + static_cast( + this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective + std::make_tuple(pg_uid_, pg_desc_), // PG name tuple + tensors, // inputTensors + tensors, // outputTensors + rank_, // rank + "allreduce_coalesced", // collective name + total_numel, // inNelems + total_numel, // outNelems + tensors[0].scalar_type(), // dType + std::vector(), // inSplitSizes + std::vector(), // outSplitSizes + -1, // globalRankStart + -1, // globalRankStride + this->getSize()); // worldSize + + return collectiveCoalesced( + tensors, + tensors, + [&](at::Tensor& input, + at::Tensor& output, + xcclComm_t& comm, + at::xpu::XPUStream& stream) { + auto xcclDataType = getXcclDataType(input.scalar_type(), true); + auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input); + ccl::allreduce( + input.data_ptr(), + output.data_ptr(), + (size_t)input.numel(), + xcclDataType, + xcclReduceOp, + comm, + ccl::create_stream(stream.queue())); + // Use SUM emu AVG due to oneCCL not support AVG + // oneCCL is expected to support avg in basekit 2025.2 release. 
+ if (opts.reduceOp == ReduceOp::AVG) { + auto divisor = getSize(); + output.div_(divisor); + } + return; + }, + OpType::COALESCED, + "xccl:allreduce_coalesced"); +} + +c10::intrusive_ptr ProcessGroupXCCL::broadcast( + std::vector& tensors, + const BroadcastOptions& opts) { + TORCH_CHECK(tensors.size() == 1, MULTI_DEVICE_ERROR_MSG); + auto tensor = tensors.back(); + checkSingleTensor(tensor); + + // @lint-ignore CLANGTIDY + RECORD_PARAM_COMMS_DATA( + static_cast( + this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective + std::make_tuple(pg_uid_, pg_desc_), // PG name tuple + tensors, // inputTensors + tensors, // outputTensors + opts.rootRank, // root rank + "broadcast", // collective name + tensor.numel(), // inNelems + tensor.numel(), // outNelems + tensor.scalar_type(), // dType + std::vector(), // inSplitSizes + std::vector(), // outSplitSizes + -1, // globalRankStart + -1, // globalRankStride + this->getSize()); // worldSize + + const auto root = opts.rootRank + opts.rootTensor; + + return collective( + tensor, + tensor, + [&](at::Tensor& input, + at::Tensor& output, + xcclComm_t& comm, + at::xpu::XPUStream& stream) { + auto xcclDataType = getXcclDataType(input.scalar_type()); + ccl::broadcast( + input.data_ptr(), + (size_t)input.numel(), + xcclDataType, + root, + comm, + ccl::create_stream(stream.queue())); + return; + }, + OpType::BROADCAST, + "nccl:broadcast"); +} + +c10::intrusive_ptr ProcessGroupXCCL::_broadcast_oop( + at::Tensor& outputTensor, + at::Tensor& inputTensor, + const BroadcastOptions& opts) { + if (outputTensor.numel() != inputTensor.numel()) { + C10_THROW_ERROR( + ValueError, + "Tensor input and output of _broadcast_oop must have the same number of elements "); + } + const auto root = opts.rootRank + opts.rootTensor; + return collective( + inputTensor, + outputTensor, + [&](at::Tensor& input, + at::Tensor& output, + xcclComm_t& comm, + at::xpu::XPUStream& stream) { + auto xcclDataType = getXcclDataType(input.scalar_type()); + ccl::broadcast( + input.data_ptr(), + output.data_ptr(), + (size_t)input.numel(), + xcclDataType, + root, + comm, + ccl::create_stream(stream.queue())); + return; + }, + OpType::BROADCAST, + "xccl:_broadcast_oop"); +} + +c10::intrusive_ptr ProcessGroupXCCL::_reduce_oop( + at::Tensor& outputTensor, + at::Tensor& inputTensor, + const ReduceOptions& opts) { + TORCH_CHECK_WITH( + ValueError, + outputTensor.numel() == inputTensor.numel(), + "Tensor input and output of _reduce_oop must have the same number of elements"); + return collective( + inputTensor, + outputTensor, + [&](at::Tensor& input, + at::Tensor& output, + xcclComm_t& comm, + at::xpu::XPUStream& stream) { + const int root = opts.rootRank + opts.rootTensor; + const auto xcclDataType = getXcclDataType(input.scalar_type(), true); + const auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input); + ccl::reduce( + input.data_ptr(), + output.data_ptr(), + (size_t)input.numel(), + xcclDataType, + xcclReduceOp, + root, + comm, + ccl::create_stream(stream.queue())); + // Use SUM emu AVG due to oneCCL not support AVG + // oneCCL is expected to support avg in basekit 2025.2 release. 
+ if (opts.reduceOp == ReduceOp::AVG && getRank() == root) { + auto divisor = getSize(); + output.div_(divisor); + } + return; + }, + OpType::REDUCE, + "xccl:_reduce_oop"); +} + +c10::intrusive_ptr ProcessGroupXCCL::allgather( + std::vector>& outputTensors, + std::vector& inputTensors, + const AllgatherOptions& opts) { + TORCH_CHECK(inputTensors.size() == 1, MULTI_DEVICE_ERROR_MSG); + // @lint-ignore CLANGTIDY + auto inputTensor = inputTensors.back(); + checkSingleTensor(inputTensor); + // @lint-ignore CLANGTIDY + std::vector& outputTensors_ = outputTensors.back(); + + RECORD_PARAM_COMMS_DATA( + static_cast( + this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective + std::make_tuple(pg_uid_, pg_desc_), // PG name tuple + inputTensors, // inputTensors + outputTensors, // outputTensors + rank_, // rank + "all_gather", // collective name + inputTensor.numel(), // inNelems + inputTensor.numel() * // outNelems + this->getSize(), + inputTensor.scalar_type(), // dType + std::vector(), // inSplitSizes + std::vector(), // outSplitSize + -1, // globalRankStart + -1, // globalRankStride + this->getSize()); // worldSize + + bool same_size = checkSameSize(outputTensors_); + if (same_size) { + // Flatten a vector of tensors into a single, stacked tensor. + at::Tensor outputFlattened = newLikeFlat(outputTensors_); + + return collective( + inputTensor, + outputFlattened, + [&](at::Tensor& input, + at::Tensor& output, + xcclComm_t& comm, + at::xpu::XPUStream& stream) { + c10::xpu::XPUCachingAllocator::recordStream( + output.storage().data_ptr(), stream); + auto xcclDataType = getXcclDataType(input.scalar_type()); + ccl::allgather( + input.data_ptr(), + output.data_ptr(), + (size_t)input.numel(), + xcclDataType, + comm, + ccl::create_stream(stream.queue())); + return; + }, + [](at::xpu::XPUStream&, + c10::intrusive_ptr& work) {}, + [&](at::xpu::XPUStream& Stream, + c10::intrusive_ptr& work) { + // Copy the flattened output tensors to the outputs. + c10::StreamGuard guard(Stream); + for (const auto j : c10::irange(outputTensors_.size())) { + c10::xpu::XPUCachingAllocator::recordStream( + outputTensors_[j].storage().data_ptr(), Stream); + outputTensors_[j].copy_(outputFlattened[j], true); + } + }, + OpType::ALLGATHER, + "xccl:all_gather"); + } else { + const auto num_reduces = outputTensors_.size(); + startCoalescing(); + for (const int i : c10::irange(num_reduces)) { + auto& output = outputTensors_[i]; + auto& input = (i == rank_) ? 
inputTensor : output; + auto broadcastOpts = BroadcastOptions{ + static_cast(i), static_cast(0), opts.timeout}; + _broadcast_oop(output, input, broadcastOpts); + } + auto work = endCoalescing(OpType::ALLGATHER); + return work; + } +} + +c10::intrusive_ptr ProcessGroupXCCL::_allgather_base( + at::Tensor& output_tensor, + at::Tensor& input_tensor, + const AllgatherOptions& opts) { + checkSingleTensor(input_tensor); + checkSingleTensor(output_tensor); + + TORCH_CHECK_WITH( + TypeError, + input_tensor.dtype() == output_tensor.dtype(), + "output tensor must have the same type as input tensor"); + TORCH_CHECK_WITH( + ValueError, + input_tensor.numel() * size_ == output_tensor.numel(), + "output tensor size must be equal to world_size times input tensor size"); + + RECORD_PARAM_COMMS_DATA( + static_cast( + this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective + std::make_tuple(pg_uid_, pg_desc_), // PG name tuple + input_tensor, // inputTensors + output_tensor, // outputTensors + rank_, // rank + "_allgather_base", // collective name + input_tensor.numel(), // inNelems + output_tensor.numel(), // outNelems + output_tensor.scalar_type(), // dType + std::vector(), // inSplitSizes + std::vector(), // outSplitSize + -1, // globalRankStart + -1, // globalRankStride + this->getSize()); // worldSize + + return collective( + input_tensor, + output_tensor, + [&](at::Tensor& input, + at::Tensor& output, + xcclComm_t& comm, + at::xpu::XPUStream& stream) { + c10::xpu::XPUCachingAllocator::recordStream( + output.storage().data_ptr(), stream); + auto xcclDataType = getXcclDataType(input.scalar_type()); + ccl::allgather( + input.data_ptr(), + output.data_ptr(), + (size_t)input.numel(), + xcclDataType, + comm, + ccl::create_stream(stream.queue())); + return; + }, + OpType::_ALLGATHER_BASE, + "xccl:_all_gather_base"); +} + +c10::intrusive_ptr ProcessGroupXCCL::allgather_into_tensor_coalesced( + std::vector& outputs, + std::vector& inputs, + const AllgatherOptions& opts) { + return collectiveCoalesced( + inputs, + outputs, + [&](at::Tensor& input, + at::Tensor& output, + xcclComm_t& comm, + at::xpu::XPUStream& stream) { + auto xcclDataType = getXcclDataType(input.scalar_type()); + ccl::allgather( + input.data_ptr(), + output.data_ptr(), + (size_t)input.numel(), + xcclDataType, + comm, + ccl::create_stream(stream.queue())); + return; + }, + OpType::COALESCED, + "xccl:all_gather_into_tensor_coalesced"); +} + +c10::intrusive_ptr ProcessGroupXCCL::reduce_scatter( + std::vector& outputTensors, + std::vector>& inputTensors, + const ReduceScatterOptions& opts) { + TORCH_CHECK(outputTensors.size() == 1, MULTI_DEVICE_ERROR_MSG); + // @lint-ignore CLANGTIDY + auto outputTensor = outputTensors.back(); + checkSingleTensor(outputTensor); + // @lint-ignore CLANGTIDY + auto inputTensors_ = inputTensors.back(); + + RECORD_PARAM_COMMS_DATA( + static_cast( + this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective + std::make_tuple(pg_uid_, pg_desc_), // PG name tuple + inputTensors, // inputTensors + outputTensors, // outputTensors + rank_, // rank + "reduce_scatter", // collective name + outputTensor.numel() * this->getSize(), // inNelems + outputTensor.numel(), // outNelems + outputTensor.scalar_type(), // dType + std::vector(), // inSplitSizes + std::vector(), // outSplitSizes + -1, // globalRankStart + -1, // globalRankStride + this->getSize()); // worldSize + + bool same_size = checkSameSize(inputTensors_); + if (same_size) { + // Flatten a vector of tensors into a single, stacked tensor. 
+ at::Tensor inputFlattened = newLikeFlat(inputTensors_); + return collective( + inputFlattened, + outputTensor, + [&](at::Tensor& input, + at::Tensor& output, + xcclComm_t& comm, + at::xpu::XPUStream& stream) { + c10::xpu::XPUCachingAllocator::recordStream( + output.storage().data_ptr(), stream); + auto xcclDataType = getXcclDataType(input.scalar_type(), true); + auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input); + ccl::reduce_scatter( + input.data_ptr(), + output.data_ptr(), + (size_t)output.numel(), + xcclDataType, + xcclReduceOp, + comm, + ccl::create_stream(stream.queue())); + // Use SUM emu AVG due to oneCCL not support AVG + // oneCCL is expected to support avg in basekit 2025.2 release. + if (opts.reduceOp == ReduceOp::AVG) { + auto divisor = getSize(); + output.div_(divisor); + } + return; + }, + [&](at::xpu::XPUStream& Stream, + c10::intrusive_ptr& work) { + // Copy the input tensors to the flattened inputs. + c10::StreamGuard guard(Stream); + for (const auto j : c10::irange(inputTensors_.size())) { + c10::xpu::XPUCachingAllocator::recordStream( + inputTensors_[j].storage().data_ptr(), Stream); + inputFlattened[j].copy_(inputTensors_[j], true); + } + }, + [&](at::xpu::XPUStream&, + c10::intrusive_ptr&) {}, + OpType::REDUCE_SCATTER, + "xccl:reduce_scatter"); + } else { + const auto num_reduces = inputTensors_.size(); + startCoalescing(); + for (const int i : c10::irange(num_reduces)) { + auto& input = inputTensors_[i]; + auto& output = (i == rank_) ? outputTensor : input; + auto reduceOpts = ReduceOptions{ + opts.reduceOp, + static_cast(i), + static_cast(0), + opts.timeout}; + _reduce_oop(output, input, reduceOpts); + } + auto work = endCoalescing(OpType::REDUCE_SCATTER); + return work; + } +} + +c10::intrusive_ptr ProcessGroupXCCL::_reduce_scatter_base( + at::Tensor& outputTensor, + at::Tensor& inputTensor, + const ReduceScatterOptions& opts) { + TORCH_CHECK_WITH( + TypeError, + inputTensor.dtype() == outputTensor.dtype(), + "input tensor must be the same type as the output tensor."); + TORCH_CHECK_WITH( + ValueError, + inputTensor.numel() == outputTensor.numel() * size_, + "input tensor must be the same size as output size times world size"); + + RECORD_PARAM_COMMS_DATA( + static_cast( + this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective + std::make_tuple(pg_uid_, pg_desc_), // PG name tuple + inputTensor, // inputTensor + outputTensor, // outputTensor + rank_, // rank + "_reduce_scatter_base", // collective name + inputTensor.numel(), // inNelems + outputTensor.numel(), // outNelems + outputTensor.scalar_type(), // dtype + std::vector(), // inSplitSizes + std::vector(), // outSplitSizes + -1, // globalRankStart + -1, // globalRankStride + this->getSize()); // worldSize + + return collective( + inputTensor, + outputTensor, + [&](at::Tensor& input, + at::Tensor& output, + xcclComm_t& comm, + at::xpu::XPUStream& stream) { + c10::xpu::XPUCachingAllocator::recordStream( + output.storage().data_ptr(), stream); + auto xcclDataType = getXcclDataType(input.scalar_type(), true); + auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input); + ccl::reduce_scatter( + input.data_ptr(), + output.data_ptr(), + (size_t)output.numel(), + xcclDataType, + xcclReduceOp, + comm, + ccl::create_stream(stream.queue())); + // Use SUM emu AVG due to oneCCL not support AVG + // oneCCL is expected to support avg in basekit 2025.2 release. 
+ if (opts.reduceOp == ReduceOp::AVG) { + auto divisor = getSize(); + output.div_(divisor); + } + return; + }, + OpType::_REDUCE_SCATTER_BASE, + "xccl:_reduce_scatter_base"); +} + +c10::intrusive_ptr ProcessGroupXCCL::reduce_scatter_tensor_coalesced( + std::vector& outputs, + std::vector& inputs, + const ReduceScatterOptions& opts) { + return collectiveCoalesced( + inputs, + outputs, + [&](at::Tensor& input, + at::Tensor& output, + xcclComm_t& comm, + at::xpu::XPUStream& stream) { + c10::xpu::XPUCachingAllocator::recordStream( + output.storage().data_ptr(), stream); + auto xcclDataType = getXcclDataType(input.scalar_type(), true); + auto xcclReduceOp = getXcclReduceOp(opts.reduceOp, input); + ccl::reduce_scatter( + input.data_ptr(), + output.data_ptr(), + (size_t)output.numel(), + xcclDataType, + xcclReduceOp, + comm, + ccl::create_stream(stream.queue())); + // Use SUM emu AVG due to oneCCL not support AVG + // oneCCL is expected to support avg in basekit 2025.2 release. + if (opts.reduceOp == ReduceOp::AVG) { + auto divisor = getSize(); + output.div_(divisor); + } + return; + }, + OpType::COALESCED, + "xccl:reduce_scatter_tensor_coalesced"); +} + +c10::intrusive_ptr ProcessGroupXCCL::barrier(const BarrierOptions& opts) { + RECORD_PARAM_COMMS( + static_cast( + this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective + std::make_tuple(pg_uid_, pg_desc_), // PG name tuple + rank_, // rank + "barrier", // collective name + 0, // inNelems + 0, // outNelems + at::kByte, // dType + std::vector(), // inSplitSizes + std::vector(), // outSplitSizes + -1, // globalRankStart + -1, // globalRankStride + this->getSize()); // worldSize + // Device to use for barrier + int barDevIdx = -1; + + // See nccl barrier comments + if (!opts.device_ids.empty()) { + barDevIdx = opts.device_ids[0]; + } else if (getBoundDeviceId()) { + barDevIdx = (*getBoundDeviceId()).index(); + } else if (!usedDeviceIdxs_.empty()) { + barDevIdx = *usedDeviceIdxs_.begin(); + } else { + barDevIdx = + static_cast(rank_ % at::detail::getXPUHooks().getNumGPUs()); + } + + TORCH_CHECK_WITH( + ValueError, + barDevIdx >= 0, + "Failed to infer a GPU device id to perform barrier. 
"); + auto barDevice = at::Device(at::DeviceType::XPU, barDevIdx); + + at::Tensor barrierTensor = + at::zeros({1}, at::TensorOptions().device(barDevice).dtype(at::kFloat)); + + auto work = allreduce_impl(barrierTensor); + + auto xcclWork = dynamic_cast(work.get()); + TORCH_CHECK(xcclWork); + xcclWork->barrierTensor_ = std::move(barrierTensor); + return work; +} + } // namespace c10d #endif // USE_C10D_XCCL diff --git a/src/xccl/ProcessGroupXCCL.hpp b/src/xccl/ProcessGroupXCCL.hpp index 21269bd6f..0a80b4b17 100644 --- a/src/xccl/ProcessGroupXCCL.hpp +++ b/src/xccl/ProcessGroupXCCL.hpp @@ -67,6 +67,7 @@ class TORCH_API ProcessGroupXCCL : public Backend { protected: at::Device device_; std::shared_ptr xcclEndEvent_; + at::Tensor barrierTensor_; bool blockingWait_ = false; std::chrono::time_point workStartTime_; uint64_t seq_; @@ -93,9 +94,18 @@ class TORCH_API ProcessGroupXCCL : public Backend { return std::string(XCCL_BACKEND_NAME); } + void startCoalescing() override; + + c10::intrusive_ptr endCoalescing() override; + + c10::intrusive_ptr endCoalescing(OpType optype); + std::shared_ptr getXCCLComm( const std::string& deviceKey, - at::Device& device); + at::Device& device, + OpType opType, + int p2pRank = 0, + bool isSendRecvSelf = false); virtual c10::intrusive_ptr initWork( at::Device& device, @@ -112,8 +122,39 @@ class TORCH_API ProcessGroupXCCL : public Backend { Fn fn, OpType opType, const char* profilingTitle = nullptr) { + return collective( + input, + output, + fn, + [](at::xpu::XPUStream&, + c10::intrusive_ptr&) {}, + [](at::xpu::XPUStream&, + c10::intrusive_ptr&) {}, + opType, + profilingTitle); + } + + template + c10::intrusive_ptr collective( + at::Tensor& input, + at::Tensor& output, + Fn fn, + PreProcess pre, + PostProcess post, + OpType opType, + const char* profilingTitle = nullptr) { auto inputs = std::vector{input}; auto outputs = std::vector{output}; + return collective(inputs, outputs, fn, pre, post, opType, profilingTitle); + } + + template + c10::intrusive_ptr collective( + std::vector& inputs, + std::vector& outputs, + Fn fn, + OpType opType, + const char* profilingTitle = nullptr) { return collective( inputs, outputs, @@ -122,7 +163,8 @@ class TORCH_API ProcessGroupXCCL : public Backend { c10::intrusive_ptr&) {}, [](at::xpu::XPUStream&, c10::intrusive_ptr&) {}, - opType); + opType, + profilingTitle); } template @@ -135,14 +177,106 @@ class TORCH_API ProcessGroupXCCL : public Backend { OpType opType, const char* profilingTitle = nullptr); + template + c10::intrusive_ptr collectiveCoalesced( + std::vector& input, + std::vector& output, + Fn fn, + OpType opType, + const char* profilingTitle = nullptr) { + return collective( + input, + output, + fn, + [](at::xpu::XPUStream&, + c10::intrusive_ptr&) { + // There are two types of coalesce that require `group_start/end`: + // 1. **Fast Pass for Operations**: For example, + // `allreduce_coalesced`. In this case, the backend has control, so + // the initial group API `ccl::group` is called. + // 2. **User-Specified Groups**: The user specifies a series of + // operations as a group in the frontend by calling the coalesce + // manager. To avoid incorrect judgments of the p2p state, the + // `xcclActiveGroupCounter_` is introduced to track group calls made + // in the frontend. In this scenario, the `groupStart` wrap API is + // used. 
+ ccl::group_start(); + }, + [](at::xpu::XPUStream&, + c10::intrusive_ptr&) { + ccl::group_end(); + }, + opType, + profilingTitle); + } + + c10::intrusive_ptr allreduce_impl( + at::Tensor& tensor, + const AllreduceOptions& opts = AllreduceOptions()); + c10::intrusive_ptr allreduce( std::vector& tensors, const AllreduceOptions& opts = AllreduceOptions()) override; - void setSequenceNumberForGroup() override {} - uint64_t getSequenceNumberForGroup() override { - return seqCollective_; - } + c10::intrusive_ptr allreduce_coalesced( + std::vector& tensors, + const AllreduceCoalescedOptions& opts = + AllreduceCoalescedOptions()) override; + + c10::intrusive_ptr _reduce_oop( + at::Tensor& outputTensors, + at::Tensor& inputTensors, + const ReduceOptions& opts = ReduceOptions()); + + c10::intrusive_ptr broadcast( + std::vector& tensors, + const BroadcastOptions& opts = BroadcastOptions()) override; + + c10::intrusive_ptr _broadcast_oop( + at::Tensor& outputTensor, + at::Tensor& inputTensor, + const BroadcastOptions& opts); + + c10::intrusive_ptr allgather( + std::vector>& outputTensors, + std::vector& inputTensors, + const AllgatherOptions& opts = AllgatherOptions()) override; + + c10::intrusive_ptr _allgather_base( + at::Tensor& outputbuffer, + at::Tensor& inputbuffer, + const AllgatherOptions& opts = AllgatherOptions()) override; + + c10::intrusive_ptr allgather_into_tensor_coalesced( + std::vector& outputs, + std::vector& inputs, + const AllgatherOptions& opts = AllgatherOptions()) override; + + c10::intrusive_ptr reduce_scatter( + std::vector& outputTensors, + std::vector>& inputTensors, + const ReduceScatterOptions& opts = ReduceScatterOptions()) override; + + c10::intrusive_ptr _reduce_scatter_base( + at::Tensor& outputTensor, + at::Tensor& inputTensor, + const ReduceScatterOptions& opts = ReduceScatterOptions()) override; + + c10::intrusive_ptr reduce_scatter_tensor_coalesced( + std::vector& outputs, + std::vector& inputs, + const ReduceScatterOptions& opts = ReduceScatterOptions()) override; + + c10::intrusive_ptr barrier( + const BarrierOptions& opts = BarrierOptions()) override; + + void groupStart(); + + void groupEnd(); + + void setSequenceNumberForGroup() override; + + uint64_t getSequenceNumberForGroup() override; protected: std::unordered_map xcclStreamsMap_; @@ -151,8 +285,14 @@ class TORCH_API ProcessGroupXCCL : public Backend { c10::intrusive_ptr store_; uint64_t xcclCommCounter_{0}; std::mutex mutex_; + std::set usedDeviceIdxs_; + int coalescing_state_ = 0; + at::Device coalescedDevice_ = at::Device("xpu"); + std::shared_ptr coalescedComm_ = nullptr; bool blockingWait_ = false; + static thread_local uint64_t xcclActiveGroupCounter_; uint64_t seqCollective_{0}; + uint64_t seqP2P_{0}; private: std::mutex kvs_mutex;
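Note on getXCCLComm(): when the communicator is created for a single point-to-point operation rather than a collective, the patch remaps the communicator size and rank before calling ccl::create_communicators. A minimal sketch of that mapping; the names CommShape and comm_shape_for are hypothetical and not part of the patch:

// Mirrors the numRanks/rank selection in getXCCLComm(): collectives keep the
// process-group size and rank, a send/recv to self becomes a 1-rank
// communicator, and any other single P2P op becomes a 2-rank communicator
// addressed by p2pRank (0 or 1).
struct CommShape {
  int numRanks;
  int rank;
};

inline CommShape comm_shape_for(
    bool singleP2POp,
    bool isSendRecvSelf,
    int groupSize,
    int groupRank,
    int p2pRank) {
  if (!singleP2POp) {
    return {groupSize, groupRank};
  }
  if (isSendRecvSelf) {
    return {1, 0};
  }
  return {2, p2pRank};
}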
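The startCoalescing()/endCoalescing() pair introduced above brackets several collectives on one device and communicator into a single oneCCL group, with the end event recorded once for the whole batch. A sketch of the intended call pattern, assuming an already-constructed process group; the include path and the helper name are illustrative, not taken from the patch:

#include <vector>
#include <ATen/ATen.h>
#include <xccl/ProcessGroupXCCL.hpp> // path assumed from this repo's layout

// Issue one allreduce per tensor inside a coalescing window. Every call made
// between startCoalescing() and endCoalescing() lands in the same ccl group
// (groupStart()/groupEnd() maintain xcclActiveGroupCounter_), and a single
// Work object is returned for the whole batch.
void coalesced_allreduce_sketch(
    c10d::ProcessGroupXCCL& pg,
    std::vector<at::Tensor>& xpu_tensors) {
  pg.startCoalescing();
  for (auto& t : xpu_tensors) {
    std::vector<at::Tensor> one{t};
    pg.allreduce(one); // queued; the end event is recorded by endCoalescing()
  }
  auto work = pg.endCoalescing(); // defaults to OpType::COALESCED
  if (work) {
    work->wait();
  }
}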
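Several collectives above emulate ReduceOp::AVG by running the oneCCL SUM reduction and dividing afterwards. The division point differs by collective: allreduce and reduce_scatter divide on every rank, while the rooted _reduce_oop divides only on the root rank. A minimal sketch of that post-processing rule; the function name is hypothetical:

#include <ATen/ATen.h>

// Post-step applied after a SUM-based collective when AVG was requested.
// `output` already holds the summed values; divide in place by world size.
// For a rooted reduce only the root rank holds meaningful data, so only the
// root divides; allreduce/reduce_scatter results are valid on every rank.
void apply_avg_post_step(
    at::Tensor& output,
    int64_t world_size,
    bool rooted_reduce,
    int rank,
    int root) {
  if (!rooted_reduce || rank == root) {
    output.div_(world_size);
  }
}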
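barrier() above runs an allreduce on a one-element dummy tensor and later synchronizes the current stream in wait(); the device for that tensor is chosen through a chain of fallbacks. A small standalone sketch of the selection order, with hypothetical parameter names standing in for opts.device_ids, the bound device, and usedDeviceIdxs_:

#include <cstdint>
#include <optional>
#include <set>
#include <vector>

// Same precedence as barrier(): an explicitly requested device id, then the
// device bound to this process group, then any device the group has already
// used, and finally rank modulo the number of visible XPU devices.
inline int infer_barrier_device_index(
    const std::vector<int64_t>& requested_device_ids,
    std::optional<int> bound_device_index,
    const std::set<int>& used_device_indices,
    int rank,
    int num_devices) {
  if (!requested_device_ids.empty()) {
    return static_cast<int>(requested_device_ids[0]);
  }
  if (bound_device_index.has_value()) {
    return *bound_device_index;
  }
  if (!used_device_indices.empty()) {
    return *used_device_indices.begin();
  }
  return rank % num_devices;
}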