diff --git a/CHANGELOG.md b/CHANGELOG.md index 480b0dfc3b..245527e1a3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,6 @@ # NVIDIA CUTLASS Changelog -## 3.5 (2024-03-18) +## [3.5.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.5.0) (2024-04-09) - Implicit GEMM Convolutions targeting Hopper SM90A via WGMMA + [TMA im2col](./include/cute/atom/copy_traits_sm90_im2col.hpp) + Native implementation in CUTLASS 3.x using CuTe, mirroring the [same design hierarchy as that of GEMMs](./media/docs/gemm_api_3x.md). @@ -12,8 +12,13 @@ - [Ampere gather/scatter convolution example](./examples/59_ampere_gather_scatter_gemm/README.md) in CuTe and CUTLASS 3.x + Showcasing how custom kernels can be written and optimized using CUTLASS 3.x and CuTe and the general strategy for implementing convolutions as specializations of GETTs. + Implementation of a coarse grained sparse gather/scatter kernel achieving peak performance on Ampere class tensor cores. +- 32x and 16x tile sizes are added to CUTLASS 2.x to improve the performance of narrow-tall and wide-short matrices. + + [Ampere FP16 TN](./test/unit/gemm/device/gemm_f16t_f16n_f16t_tensor_op_f32_sm80.cu) and [NT](./test/unit/gemm/device/gemm_f16n_f16t_f16t_tensor_op_f32_sm80.cu#L227-L301), [Ampere INT8 TN](./test/unit/gemm/device/gemm_s8t_s8n_s8t_tensor_op_s32_sm80.cu#L392-L1342), [Ampere INT4 TN](./test/unit/gemm/device/gemm_s4t_s4n_s4t_tensor_op_s32_sm80.cu#L372-L934). + + [Turing FP16 TN](./test/unit/gemm/device/gemm_f16t_f16n_f16t_tensor_op_f32_sm75.cu#L55-L394), [Turing INT8 TN](./test/unit/gemm/device/gemm_s8t_s8n_s8t_tensor_op_s32_sm75.cu#L166-L537), [Turing INT4 TN](./test/unit/gemm/device/gemm_s4t_s4n_s4t_tensor_op_s32_sm75.cu#L310-L564). - Updates to CuTe documentation for [`cute::Tensor<>`](./media/docs/cute/03_tensor.md), [MMA atoms](./media/docs/cute/0t_mma_atom.md), and an overhauled [CuTe GEMM tutorial series](./examples/cute/tutorial). - Extensions to CuTe to support [L2 prefetching](./include/cute/algorithm/prefetch.hpp) and [TMA store+reductions](./include/cute/arch/copy_sm90_tma.hpp#L1337). +- Remove C++11 requirement on a few CUTLASS 2.x API header files. All CUTLASS files now require C++17. +- Fixes to greatly reduce build warnings. - Updates and bugfixes from the community (thanks!) ## [3.4.1](https://github.com/NVIDIA/cutlass/releases/tag/v3.4.1) (2024-02-14) diff --git a/CMakeLists.txt b/CMakeLists.txt index dd06a60552..4933b6c929 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -38,7 +38,7 @@ else() endif() message(STATUS "CMake Version: ${CMAKE_VERSION}") -set(IMPLICIT_CMAKE_CXX_STANDARD OFF CACHE BOOL "Do not explicitly specify -std=c++11 if set") +set(IMPLICIT_CMAKE_CXX_STANDARD OFF CACHE BOOL "Do not explicitly specify -std=c++17 if set") # To reduce duplicate version locations, parse the version out of the # main versions.h file and reuse it here. 
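The changelog entries above drop the remaining C++11 accommodations and make C++17 the floor for every CUTLASS translation unit (the CMake help string is updated from `-std=c++11` to `-std=c++17` accordingly). As a small editorial illustration of what that floor means for downstream code, the hypothetical guard below fails early on older host compilers; CUTLASS itself enforces the requirement through its build system rather than through this exact check.

```c++
// Hypothetical guard, not taken from the CUTLASS sources: fail early on
// pre-C++17 host compilers now that all CUTLASS headers assume C++17.
#if !defined(__cplusplus) || __cplusplus < 201703L
#  error "CUTLASS 3.5 headers require a C++17 (or newer) host compiler."
#endif

#include <type_traits>

// C++17 features such as `if constexpr` and *_v variable templates are now
// usable even in the 2.x API headers that previously targeted C++11.
template <class T>
constexpr int element_bits() {
  if constexpr (std::is_same_v<T, float>) { return 32; }
  else                                    { return 8 * static_cast<int>(sizeof(T)); }
}

static_assert(element_bits<float>() == 32);
```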
@@ -332,6 +332,18 @@ endif() +# Warnings-as-error exceptions and warning suppressions for Clang builds +if (CMAKE_CXX_COMPILER_ID MATCHES "Clang") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-error=implicit-int-conversion ") + list(APPEND CUTLASS_CUDA_NVCC_FLAGS "-Wno-error=implicit-int-conversion" ) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-error=pass-failed ") + list(APPEND CUTLASS_CUDA_NVCC_FLAGS "-Wno-error=pass-failed" ) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-error=inconsistent-missing-override ") + list(APPEND CUTLASS_CUDA_NVCC_FLAGS "-Wno-error=inconsistent-missing-override" ) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-sign-conversion ") + list(APPEND CUTLASS_CUDA_NVCC_FLAGS "-Wno-sign-conversion" ) +endif() + if (NOT MSVC AND CUTLASS_NVCC_KEEP) # MSVC flow handles caching already, but for other generators we handle it here. set(CUTLASS_NVCC_KEEP_DIR ${CMAKE_CURRENT_BINARY_DIR}/tmp CACHE PATH "Location to store NVCC scratch files") @@ -357,6 +369,7 @@ if (CUTLASS_ENABLE_OPENMP_TESTS) message(WARNING "CUTLASS_ENABLE_OPENMP_TESTS set but OpenMP not found.") endif() endif() + if(UNIX) list(APPEND CUTLASS_CUDA_NVCC_FLAGS -Xcompiler=-Wconversion) list(APPEND CUTLASS_CUDA_NVCC_FLAGS -Xcompiler=-fno-strict-aliasing) diff --git a/README.md b/README.md index 98ddbb0195..865ffb76ff 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ # CUTLASS 3.5 -_CUTLASS 3.5 - March 2024_ +_CUTLASS 3.5 - April 2024_ CUTLASS is a collection of CUDA C++ template abstractions for implementing high-performance matrix-matrix multiplication (GEMM) and related computations at all levels @@ -45,18 +45,21 @@ In addition to GEMMs, CUTLASS implements high-performance convolution via the im CUTLASS 3.5 is an update to CUTLASS adding: -- Implicit GEMM Convolutions targeting Hopper SM90A via WGMMA + [TMA im2col](./include/cute/atom/copy_traits_sm90_im2col.hpp) +- Implicit GEMM Convolutions targeting Hopper SM90A via WGMMA + [TMA im2col](./include/cute/atom/copy_traits_sm90_im2col.hpp). + Native implementation in CUTLASS 3.x using CuTe, mirroring the [same design hierarchy as that of GEMMs](./media/docs/gemm_api_3x.md). + Support for 1D, 2D, and 3D convolutions in a [rank-agnostic fashion](./include/cutlass/conv/convnd_problem_shape.hpp). - + Support for [Fprop](./test/unit/conv/device_3x/fprop/sm90_conv3d_fprop_implicit_gemm_s8_s8_s32_tensorop_s32.cu), [Dgrad](./test/unit/conv/device_3x/dgrad/sm90_conv2d_dgrad_implicit_gemm_f16_f16_f32_tensorop_f16.cu), and [Wgrad](./test/unit/conv/device_3x/wgrad/sm90_conv1d_wgrad_implicit_gemm_f16_f16_f32_tensorop_f16.cu) algorithms + + Support for [Fprop](./test/unit/conv/device_3x/fprop/sm90_conv3d_fprop_implicit_gemm_s8_s8_s32_tensorop_s32.cu), [Dgrad](./test/unit/conv/device_3x/dgrad/sm90_conv2d_dgrad_implicit_gemm_f16_f16_f32_tensorop_f16.cu), and [Wgrad](./test/unit/conv/device_3x/wgrad/sm90_conv1d_wgrad_implicit_gemm_f16_f16_f32_tensorop_f16.cu) algorithms. + [CUTLASS profiler support](./python/cutlass_library/conv3x_emitter.py) for 2D and 3D convolutions implemented via the 3.x API. + NOTE: this is a beta release. Further updates to CUTLASS will include major performance improvements, feature enablement, and possible breaking changes to the API until 3.7 release. Your feedback is welcome on the design! - Support for [Ada (SM89) FP8 tensor cores via the 2.x API](./examples/58_ada_fp8_gemm/ada_fp8_gemm.cu). Requires CUDA 12.4 or newer. 
-- [Ampere gather/scatter convolution example](./examples/59_ampere_gather_scatter_gemm/README.md) in CuTe and CUTLASS 3.x +- [Ampere gather/scatter convolution example](./examples/59_ampere_gather_scatter_gemm/README.md) in CuTe and CUTLASS 3.x. + Showcasing how custom kernels can be written and optimized using CUTLASS 3.x and CuTe and the general strategy for implementing convolutions as specializations of GETTs. + Implementation of a coarse grained sparse gather/scatter kernel achieving peak performance on Ampere class tensor cores. +- 32x and 16x tile sizes are added to CUTLASS 2.x to improve the performance of narrow-tall and wide-short matrices. - Updates to CuTe documentation for [`cute::Tensor<>`](./media/docs/cute/03_tensor.md), [MMA atoms](./media/docs/cute/0t_mma_atom.md), and an overhauled [CuTe GEMM tutorial series](./examples/cute/tutorial). - Extensions to CuTe to support [L2 prefetching](./include/cute/algorithm/prefetch.hpp) and [TMA store+reductions](./include/cute/arch/copy_sm90_tma.hpp#L1337). +- Remove C++11 requirement on a few CUTLASS 2.x API header files. All CUTLASS files now require C++17. +- Fixes to greatly reduce build warnings. - Updates and bugfixes from the community (thanks!) Minimum requirements: diff --git a/examples/16_ampere_tensorop_conv2dfprop/ampere_tensorop_conv2dfprop.cu b/examples/16_ampere_tensorop_conv2dfprop/ampere_tensorop_conv2dfprop.cu index 7a800c0b32..c0395f5899 100644 --- a/examples/16_ampere_tensorop_conv2dfprop/ampere_tensorop_conv2dfprop.cu +++ b/examples/16_ampere_tensorop_conv2dfprop/ampere_tensorop_conv2dfprop.cu @@ -265,6 +265,10 @@ constexpr int NumStages = 3; // Which iterator algorithm to use: Analytic or Optimized static cutlass::conv::IteratorAlgorithm const IteratorAlgorithm = cutlass::conv::IteratorAlgorithm::kOptimized; +// Is the output packed or strided +// Use kStride if using strided output +static cutlass::conv::StrideSupport const OutputStride = cutlass::conv::StrideSupport::kUnity; + // The epilogue part of the kernel using EpilogueOp = cutlass::epilogue::thread::LinearCombination< ElementOutput, // Data type of output matrix. @@ -289,7 +293,8 @@ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< SwizzleThreadBlock, NumStages, cutlass::arch::OpMultiplyAdd, - IteratorAlgorithm + IteratorAlgorithm, + OutputStride >::Kernel; // Type of the actual kernel diff --git a/examples/27_ampere_3xtf32_fast_accurate_tensorop_gemm/27_ampere_3xtf32_fast_accurate_tensorop_gemm.cu b/examples/27_ampere_3xtf32_fast_accurate_tensorop_gemm/27_ampere_3xtf32_fast_accurate_tensorop_gemm.cu index 2c3c9011b2..1ecd38ee9b 100644 --- a/examples/27_ampere_3xtf32_fast_accurate_tensorop_gemm/27_ampere_3xtf32_fast_accurate_tensorop_gemm.cu +++ b/examples/27_ampere_3xtf32_fast_accurate_tensorop_gemm/27_ampere_3xtf32_fast_accurate_tensorop_gemm.cu @@ -36,7 +36,7 @@ implicitly to tf32 inside the GEMM kernel which means no change is needed to acc fp32 data by using NVIDIA Ampere architecture. We can use the tf32 mode of tensor core to emulate a fast accurate SGEMM kernel which is accelerated -using Ampere Tensor Cores (see include/cutlass/gemm/warp/mma_tensor_op_fast_f32.h). +using Ampere Tensor Cores (see include/cutlass/gemm/warp/mma_tensor_op_fast_f32.h). The trick is very simple a x b = (a_big + a_small) x (b_big + b_small) = a_big x b_big + a_big x b_small + a_small x b_big @@ -45,11 +45,11 @@ The trick is very simple a_small x b_small is discarded because they are too small. 
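The decomposition spelled out in the example's comment above lends itself to a tiny scalar demonstration. The sketch below is an editorial illustration, not part of the example: it emulates the TF32 "big" part by truncating a float's mantissa to 10 bits and shows how the three retained products recover most of the accuracy lost by a single truncated product. The exact rounding performed by the tensor cores may differ from this simple truncation.

```c++
// Scalar sketch of the 3xTF32 trick: a*b ~= a_big*b_big + a_big*b_small + a_small*b_big.
// The mantissa truncation below is an assumption used only for illustration.
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstring>

static float tf32_big_part(float x) {
  uint32_t u;
  std::memcpy(&u, &x, sizeof(u));
  u &= 0xffffe000u;            // keep sign, exponent, and the top 10 mantissa bits
  std::memcpy(&x, &u, sizeof(x));
  return x;
}

int main() {
  float a = 1.2345678f, b = 7.6543210f;
  float a_big = tf32_big_part(a), a_small = a - a_big;
  float b_big = tf32_big_part(b), b_small = b - b_big;

  double ref        = double(a) * double(b);
  double one_tf32   = double(a_big) * double(b_big);                 // 1xTF32
  double three_tf32 = one_tf32 + double(a_big) * double(b_small)
                               + double(a_small) * double(b_big);    // 3xTF32; a_small*b_small dropped
  std::printf("1xTF32 |err| = %.3e, 3xTF32 |err| = %.3e\n",
              std::fabs(one_tf32 - ref), std::fabs(three_tf32 - ref));
  return 0;
}
```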
-This example demonstrates usage of this kernel, along with accuracy measurements w.r.t. actual FP32 +This example demonstrates usage of this kernel, along with accuracy measurements w.r.t. actual FP32 results (SGEMM using SIMT) and against FP64 results (DGEMM) -To enable this feature, the only change needs to make is to change the default OpMultiplyAdd to -OpMultiplyAddFastF32. +To enable this feature, the only change needs to make is to change the default OpMultiplyAdd to +OpMultiplyAddFastF32. Now, we have several different flavors of sgemm now in the profiler for Ampere. Here are the difference @@ -97,14 +97,14 @@ struct Result { double l2_norm_fp32_vs_fp64; // ctor - Result( + Result( int m, int n, int k, double runtime_ms, double gflops, double l2_norm_3xtf32_vs_fp64, double l2_norm_1xtf32_vs_fp64, - double l2_norm_fp32_vs_fp64) : + double l2_norm_fp32_vs_fp64) : m(m), n(n), k(k), - runtime_ms(runtime_ms), gflops(gflops), + runtime_ms(runtime_ms), gflops(gflops), l2_norm_3xtf32_vs_fp64(l2_norm_3xtf32_vs_fp64), l2_norm_1xtf32_vs_fp64(l2_norm_1xtf32_vs_fp64), l2_norm_fp32_vs_fp64(l2_norm_fp32_vs_fp64) {} @@ -147,7 +147,7 @@ struct Options { int iterations; int seed; bool benchmark; - + Options(): help(false), problem_size({3456, 4096, 4096}), @@ -190,7 +190,7 @@ struct Options { cmd.get_cmd_line_argument("alpha", alpha); cmd.get_cmd_line_argument("beta", beta); - + cmd.get_cmd_line_argument("iterations", iterations); cmd.get_cmd_line_argument("seed", seed); cmd.get_cmd_line_argument("rand_mode", rand_mode); @@ -227,9 +227,9 @@ struct Options { /// Compute performance in GFLOP/s double gflops(double runtime_s) const { - // Number of real-valued multiply-adds + // Number of real-valued multiply-adds int64_t fmas = problem_size.product(); - + // Two flops per multiply-add return 2.0 * double(fmas) / double(1.0e9) / runtime_s; } @@ -272,10 +272,10 @@ using EpilogueOp = cutlass::epilogue::thread::LinearCombination< // Number of pipelines you want to use constexpr int NumStages = 3; -// Alignment +// Alignment constexpr int Alignment = 4; -// +// // Gemm Operators (Gemm_3xTF32, Gemm_1xTF32, GEMM_F32, GEMM_F64) // @@ -296,7 +296,7 @@ using Gemm_3xTF32 = cutlass::gemm::device::Gemm< EpilogueOp, SwizzleThreadBlock, NumStages, - Alignment, + Alignment, Alignment, false, cutlass::arch::OpMultiplyAddFastF32>; @@ -318,7 +318,7 @@ using Gemm_1xTF32 = cutlass::gemm::device::Gemm< EpilogueOp, SwizzleThreadBlock, NumStages, - Alignment, + Alignment, Alignment, false, cutlass::arch::OpMultiplyAdd>; @@ -356,7 +356,7 @@ bool run(Options &options) { cutlass::HostTensor tensor_a_F32(problem_size.mk()); // <- Create matrix A with dimensions M x K cutlass::HostTensor tensor_b_F32(problem_size.kn()); // <- Create matrix B with dimensions K x N cutlass::HostTensor tensor_c_F32(problem_size.mn()); // <- Create matrix C with dimensions M x N - cutlass::HostTensor tensor_d_F32(problem_size.mn()); // <- Create matrix D with dimensions M x N + cutlass::HostTensor tensor_d_F32(problem_size.mn()); // <- Create matrix D with dimensions M x N if (options.rand_mode == "uniform") { const float min = -1; @@ -397,7 +397,7 @@ bool run(Options &options) { } cutlass::reference::host::TensorFill( tensor_d_F32.host_view()); // <- fill matrix D on host with zeros - + // Copy data from host to GPU tensor_a_F32.sync_device(); tensor_b_F32.sync_device(); @@ -411,7 +411,7 @@ bool run(Options &options) { cutlass::HostTensor tensor_a_F64(problem_size.mk()); // <- Create matrix A with dimensions M x K cutlass::HostTensor 
tensor_b_F64(problem_size.kn()); // <- Create matrix B with dimensions K x N cutlass::HostTensor tensor_c_F64(problem_size.mn()); // <- Create matrix C with dimensions M x N - + // Gemm output (D) for GEMM_F64 cutlass::HostTensor tensor_d_F64(problem_size.mn()); // <- Create matrix D with dimensions M x N // Gemm output (D) for GEMM_3xTF32 @@ -426,7 +426,7 @@ bool run(Options &options) { cutlass::reference::host::TensorCopy(tensor_d_F64.host_view(), tensor_d_F32.host_view()); cutlass::reference::host::TensorCopy(tensor_d_3xTF32.host_view(), tensor_d_F32.host_view()); cutlass::reference::host::TensorCopy(tensor_d_1xTF32.host_view(), tensor_d_F32.host_view()); - + // Copy data from host to GPU tensor_a_F64.sync_device(); tensor_b_F64.sync_device(); @@ -464,7 +464,7 @@ bool run(Options &options) { // Instantiate CUTLASS kernel depending on templates Gemm_3xTF32 gemm_op_3xTF32; - // Check the problem size is supported or not + // Check the problem size is supported or not cutlass::Status status_3xtf32 = gemm_op_3xTF32.can_implement(arguments_3xtf32); CUTLASS_CHECK(status_3xtf32); @@ -568,7 +568,7 @@ bool run(Options &options) { // Instantiate CUTLASS kernel depending on templates Gemm_1xTF32 gemm_op_1xtf32; - // Check the problem size is supported or not + // Check the problem size is supported or not cutlass::Status status_1xtf32 = gemm_op_1xtf32.can_implement(arguments_1xtf32); CUTLASS_CHECK(status_1xtf32); @@ -627,7 +627,7 @@ bool run(Options &options) { tensor_d_F32.sync_host(); //////////////////////////////////////////////////////////////////////////////// - /////// Compute l2 norms + /////// Compute l2 norms //////////////////////////////////////////////////////////////////////////////// // l2 norm 3xTF32 vs F64 @@ -664,7 +664,7 @@ bool run(Options &options) { std::cout << "GFLOPs: " << result.gflops << std::endl; std::cout << "Normalized L2 norm of" << std::endl; std::cout.precision(8); - std::cout << std::scientific + std::cout << std::scientific << " - 3xTF32 error with FP64 reference : " << result.l2_norm_3xtf32_vs_fp64 << std::endl << " - 1xTF32 error with FP64 reference : " << result.l2_norm_1xtf32_vs_fp64 << std::endl << " - FP32 error with FP64 reference : " << result.l2_norm_fp32_vs_fp64 << std::endl; @@ -673,11 +673,11 @@ bool run(Options &options) { } int main(int argc, const char **argv) { - + bool notSupported = false; // Ampere Tensor Core operations exposed with mma.sync and ldmatrix are first available - // in CUDA 11.0. + // in CUDA 11.0. // // CUTLASS must be compiled with CUDA 11.0 Toolkit to run these examples. if (!(__CUDACC_VER_MAJOR__ >= 11)) { @@ -690,7 +690,7 @@ int main(int argc, const char **argv) { cudaError_t error = cudaGetDeviceProperties(&props, 0); if (error != cudaSuccess) { std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; - return false; + return -1; } if (!((props.major * 10 + props.minor) >= 80)) { @@ -716,17 +716,17 @@ int main(int argc, const char **argv) { if (options.benchmark) { for (int k = 4; k <= 65536; k *= 2) { - + options.problem_size[2] = k; - + printf("Gemm problem size: %d x %d x %d\n", \ options.problem_size.m(), options.problem_size.n(), options.problem_size.k()); - + if (!options.valid()) { std::cerr << "Invalid problem." 
<< std::endl; return -1; } - + result &= run(options); } } else { diff --git a/examples/29_ampere_3xtf32_fast_accurate_tensorop_complex_gemm/29_3xtf32_complex_gemm.cu b/examples/29_ampere_3xtf32_fast_accurate_tensorop_complex_gemm/29_3xtf32_complex_gemm.cu index c4e6e958a6..18375f6dd3 100644 --- a/examples/29_ampere_3xtf32_fast_accurate_tensorop_complex_gemm/29_3xtf32_complex_gemm.cu +++ b/examples/29_ampere_3xtf32_fast_accurate_tensorop_complex_gemm/29_3xtf32_complex_gemm.cu @@ -34,7 +34,7 @@ difference is that this example uses 3xtf32 on complex gemm. To enable this feature, the only change needs to make is to change OpMultiplyAddComplex - to OpMultiplyAddComplexFastF32. + to OpMultiplyAddComplexFastF32. */ #include @@ -74,14 +74,14 @@ struct Result { double l2_norm_fp32_vs_fp64; // ctor - Result( + Result( int m, int n, int k, double runtime_ms, double gflops, double l2_norm_3xtf32_vs_fp64, double l2_norm_1xtf32_vs_fp64, - double l2_norm_fp32_vs_fp64) : + double l2_norm_fp32_vs_fp64) : m(m), n(n), k(k), - runtime_ms(runtime_ms), gflops(gflops), + runtime_ms(runtime_ms), gflops(gflops), l2_norm_3xtf32_vs_fp64(l2_norm_3xtf32_vs_fp64), l2_norm_1xtf32_vs_fp64(l2_norm_1xtf32_vs_fp64), l2_norm_fp32_vs_fp64(l2_norm_fp32_vs_fp64) {} @@ -124,7 +124,7 @@ struct Options { int iterations; int seed; bool benchmark; - + Options(): help(false), problem_size({3456, 4096, 4096}), @@ -153,7 +153,7 @@ struct Options { cmd.get_cmd_line_argument("alpha", alpha); cmd.get_cmd_line_argument("beta", beta); - + cmd.get_cmd_line_argument("iterations", iterations); cmd.get_cmd_line_argument("seed", seed); cmd.get_cmd_line_argument("rand_mode", rand_mode); @@ -190,9 +190,9 @@ struct Options { /// Compute performance in GFLOP/s double gflops(double runtime_s) const { - // Number of real-valued multiply-adds + // Number of real-valued multiply-adds int64_t fmas = problem_size.product(); - + // Two flops per multiply-add return 2.0 * double(fmas) / double(1.0e9) / runtime_s; } @@ -239,7 +239,7 @@ constexpr int NumStages = 3; constexpr cutlass::ComplexTransform TransformA = cutlass::ComplexTransform::kNone; constexpr cutlass::ComplexTransform TransformB = cutlass::ComplexTransform::kNone; -// +// // Gemm Operators (Gemm_3xTF32, Gemm_1xTF32, GEMM_F32, GEMM_F64) // @@ -260,7 +260,7 @@ using Gemm_3xTF32 = cutlass::gemm::device::GemmComplex< EpilogueOp, SwizzleThreadBlock, NumStages, - TransformA, + TransformA, TransformB, cutlass::arch::OpMultiplyAddComplexFastF32>; @@ -281,7 +281,7 @@ using Gemm_1xTF32 = cutlass::gemm::device::GemmComplex< EpilogueOp, SwizzleThreadBlock, NumStages, - TransformA, + TransformA, TransformB, cutlass::arch::OpMultiplyAddComplex>; @@ -296,7 +296,7 @@ bool run(Options &options) { cutlass::HostTensor, LayoutInputA> tensor_a_F32(problem_size.mk()); // <- Create matrix A with dimensions M x K cutlass::HostTensor, LayoutInputB> tensor_b_F32(problem_size.kn()); // <- Create matrix B with dimensions K x N cutlass::HostTensor, LayoutOutput> tensor_c_F32(problem_size.mn()); // <- Create matrix C with dimensions M x N - cutlass::HostTensor, LayoutOutput> tensor_d_F32(problem_size.mn()); // <- Create matrix D with dimensions M x N + cutlass::HostTensor, LayoutOutput> tensor_d_F32(problem_size.mn()); // <- Create matrix D with dimensions M x N if (options.rand_mode == "uniform") { const float min = -1; @@ -337,7 +337,7 @@ bool run(Options &options) { } cutlass::reference::host::TensorFill( tensor_d_F32.host_view()); // <- fill matrix D on host with zeros - + // Copy data from host to GPU 
tensor_a_F32.sync_device(); tensor_b_F32.sync_device(); @@ -351,7 +351,7 @@ bool run(Options &options) { cutlass::HostTensor, LayoutInputA> tensor_a_F64(problem_size.mk()); // <- Create matrix A with dimensions M x K cutlass::HostTensor, LayoutInputB> tensor_b_F64(problem_size.kn()); // <- Create matrix B with dimensions K x N cutlass::HostTensor, LayoutOutput> tensor_c_F64(problem_size.mn()); // <- Create matrix C with dimensions M x N - + // Gemm output (D) for GEMM_F64 cutlass::HostTensor, LayoutOutput> tensor_d_F64(problem_size.mn()); // <- Create matrix D with dimensions M x N // Gemm output (D) for GEMM_3xTF32 @@ -366,7 +366,7 @@ bool run(Options &options) { cutlass::reference::host::TensorCopy(tensor_d_F64.host_view(), tensor_d_F32.host_view()); cutlass::reference::host::TensorCopy(tensor_d_3xTF32.host_view(), tensor_d_F32.host_view()); cutlass::reference::host::TensorCopy(tensor_d_1xTF32.host_view(), tensor_d_F32.host_view()); - + // Copy data from host to GPU tensor_a_F64.sync_device(); tensor_b_F64.sync_device(); @@ -404,7 +404,7 @@ bool run(Options &options) { // Instantiate CUTLASS kernel depending on templates Gemm_3xTF32 gemm_op; - // Check the problem size is supported or not + // Check the problem size is supported or not cutlass::Status status_3xtf32 = gemm_op.can_implement(arguments_3xtf32); CUTLASS_CHECK(status_3xtf32); @@ -508,7 +508,7 @@ bool run(Options &options) { // Instantiate CUTLASS kernel depending on templates Gemm_1xTF32 gemm_op_1xtf32; - // Check the problem size is supported or not + // Check the problem size is supported or not cutlass::Status status_1xtf32 = gemm_op_1xtf32.can_implement(arguments_1xtf32); CUTLASS_CHECK(status_1xtf32); @@ -569,7 +569,7 @@ bool run(Options &options) { tensor_d_F32.sync_host(); //////////////////////////////////////////////////////////////////////////////// - /////// Compute l2 norms + /////// Compute l2 norms //////////////////////////////////////////////////////////////////////////////// // l2 norm 3xTF32 vs F64 @@ -606,7 +606,7 @@ bool run(Options &options) { std::cout << "GFLOPs: " << result.gflops << std::endl; std::cout << "Normalized L2 norm of" << std::endl; std::cout.precision(8); - std::cout << std::scientific + std::cout << std::scientific << " - 3xTF32 error with FP64 reference : " << result.l2_norm_3xtf32_vs_fp64 << std::endl << " - 1xTF32 error with FP64 reference : " << result.l2_norm_1xtf32_vs_fp64 << std::endl << " - FP32 error with FP64 reference : " << result.l2_norm_fp32_vs_fp64 << std::endl; @@ -615,11 +615,11 @@ bool run(Options &options) { } int main(int argc, const char **argv) { - + bool notSupported = false; // Ampere Tensor Core operations exposed with mma.sync and ldmatrix are first available - // in CUDA 11.0. + // in CUDA 11.0. // // CUTLASS must be compiled with CUDA 11.0 Toolkit to run these examples. 
if (!(__CUDACC_VER_MAJOR__ >= 11)) { @@ -632,7 +632,7 @@ int main(int argc, const char **argv) { cudaError_t error = cudaGetDeviceProperties(&props, 0); if (error != cudaSuccess) { std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; - return false; + return -1; } if (!((props.major * 10 + props.minor) >= 80)) { @@ -658,17 +658,17 @@ int main(int argc, const char **argv) { if (options.benchmark) { for (int k = 4; k <= 65536; k *= 2) { - + options.problem_size[2] = k; - + printf("Gemm problem size: %d x %d x %d\n", \ options.problem_size.m(), options.problem_size.n(), options.problem_size.k()); - + if (!options.valid()) { std::cerr << "Invalid problem." << std::endl; return -1; } - + result &= run(options); } } else { diff --git a/examples/33_ampere_3xtf32_tensorop_symm/ampere_3xtf32_tensorop_symm.cu b/examples/33_ampere_3xtf32_tensorop_symm/ampere_3xtf32_tensorop_symm.cu index 098ca8a288..4863ed93e7 100644 --- a/examples/33_ampere_3xtf32_tensorop_symm/ampere_3xtf32_tensorop_symm.cu +++ b/examples/33_ampere_3xtf32_tensorop_symm/ampere_3xtf32_tensorop_symm.cu @@ -36,7 +36,7 @@ implicitly to tf32 inside the SYMM kernel which means no change is needed to acc F32 data by using NVIDIA Ampere architecture. We can use the tf32 mode of tensor core to emulate a fast accurate SYMM kernel which is accelerated -using Ampere Tensor Cores (see include/cutlass/gemm/warp/mma_tensor_op_fast_f32.h). +using Ampere Tensor Cores (see include/cutlass/gemm/warp/mma_tensor_op_fast_f32.h). The trick is very simple a x b = (a_big + a_small) x (b_big + b_small) = a_big x b_big + a_big x b_small + a_small x b_big @@ -45,11 +45,11 @@ The trick is very simple a_small x b_small is discarded because they are too small. -This example demonstrates usage of this kernel, along with accuracy measurements w.r.t. actual F32 +This example demonstrates usage of this kernel, along with accuracy measurements w.r.t. actual F32 results (SSYMM from cuBLAS) and against F64 results (DSYMM from CUTLASS) -To enable this feature, the only change needs to make is to change the default OpMultiplyAdd to -OpMultiplyAddFastF32. +To enable this feature, the only change needs to make is to change the default OpMultiplyAdd to +OpMultiplyAddFastF32. 
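This SYMM example and the two 3xTF32 GEMM examples above all report a "Normalized L2 norm" of each result against an FP64 reference via `cutlass::reference::host::TensorRelativeErrorMetric`. As a rough editorial illustration of what such a metric measures, the sketch below uses the common definition ||x - ref||_2 / ||ref||_2 on plain arrays; the library routine's exact normalization may differ in detail.

```c++
// Illustrative relative L2 error between a computed result and a reference.
// The definition here (||x - ref|| / ||ref||) is an assumption for exposition.
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <vector>

double relative_l2_error(std::vector<double> const& x, std::vector<double> const& ref) {
  double num = 0.0, den = 0.0;
  for (std::size_t i = 0; i < x.size(); ++i) {
    double d = x[i] - ref[i];
    num += d * d;
    den += ref[i] * ref[i];
  }
  return std::sqrt(num) / std::sqrt(den);
}

int main() {
  std::vector<double> ref    = {1.0, 2.0, 3.0};
  std::vector<double> approx = {1.0 + 1e-3, 2.0, 3.0 - 2e-3};
  std::printf("relative L2 error: %.3e\n", relative_l2_error(approx, ref));
  return 0;
}
```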
Now, we have two different flavors of SSYMM in the profiler for Ampere: @@ -95,7 +95,7 @@ struct Options { float beta; std::string rand_mode; int seed; - + Options(): help(false), problem_size({4096, 4096, 4096}), @@ -137,7 +137,7 @@ struct Options { cmd.get_cmd_line_argument("alpha", alpha); cmd.get_cmd_line_argument("beta", beta); - + cmd.get_cmd_line_argument("seed", seed); cmd.get_cmd_line_argument("rand_mode", rand_mode); @@ -207,10 +207,10 @@ using EpilogueOp = cutlass::epilogue::thread::LinearCombination< // Number of pipelines you want to use constexpr int NumStages = 3; -// Alignment +// Alignment constexpr int Alignment = 4; -// +// // CUTLASS Symm Operators (SSYM: Symm_3xTF32, Symm_1xTF32, DSYMM: Symm_F64) // @@ -233,7 +233,7 @@ using Symm_3xTF32 = cutlass::gemm::device::Symm< EpilogueOp, SwizzleThreadBlock, NumStages, - 1, // Symmetric matrix is always align 1 + 1, // Symmetric matrix is always align 1 Alignment, false, cutlass::arch::OpMultiplyAddFastF32>; @@ -257,7 +257,7 @@ using Symm_1xTF32 = cutlass::gemm::device::Symm< EpilogueOp, SwizzleThreadBlock, NumStages, - 1, // Symmetric matrix is always align 1 + 1, // Symmetric matrix is always align 1 Alignment, false, cutlass::arch::OpMultiplyAdd>; @@ -298,7 +298,7 @@ bool run(Options &options) { cutlass::HostTensor tensor_a_F32(problem_size.mk()); // <- Create matrix A with dimensions M x K cutlass::HostTensor tensor_b_F32(problem_size.kn()); // <- Create matrix B with dimensions K x N cutlass::HostTensor tensor_c_F32(problem_size.mn()); // <- Create matrix C with dimensions M x N - cutlass::HostTensor tensor_d_F32(problem_size.mn()); // <- Create matrix D with dimensions M x N + cutlass::HostTensor tensor_d_F32(problem_size.mn()); // <- Create matrix D with dimensions M x N if (options.rand_mode == "uniform") { const float min = -1; @@ -339,7 +339,7 @@ bool run(Options &options) { } cutlass::reference::host::TensorFill( tensor_d_F32.host_view()); // <- fill matrix D on host with zeros - + // Copy data from host to GPU tensor_a_F32.sync_device(); tensor_b_F32.sync_device(); @@ -353,7 +353,7 @@ bool run(Options &options) { cutlass::HostTensor tensor_a_F64(problem_size.mk()); // <- Create matrix A with dimensions M x K cutlass::HostTensor tensor_b_F64(problem_size.kn()); // <- Create matrix B with dimensions K x N cutlass::HostTensor tensor_c_F64(problem_size.mn()); // <- Create matrix C with dimensions M x N - + // Symm output (D) for SYMM_3xTF32 cutlass::HostTensor tensor_d_3xTF32(problem_size.mn()); // <- Create matrix D with dimensions M x N // Symm output (D) for SYMM_1xTF32 @@ -375,7 +375,7 @@ bool run(Options &options) { #if CUTLASS_ENABLE_CUBLAS cutlass::reference::host::TensorCopy(tensor_d_cublasF32.host_view(), tensor_d_F32.host_view()); #endif - + // Copy data from host to GPU tensor_a_F64.sync_device(); tensor_b_F64.sync_device(); @@ -430,7 +430,7 @@ bool run(Options &options) { // Instantiate CUTLASS kernel depending on templates Symm_3xTF32 symm_op_3xtf32; - // Check the problem size is supported or not + // Check the problem size is supported or not cutlass::Status status_3xtf32 = symm_op_3xtf32.can_implement(arguments_3xtf32); CUTLASS_CHECK(status_3xtf32); @@ -477,7 +477,7 @@ bool run(Options &options) { // Instantiate CUTLASS kernel depending on templates Symm_1xTF32 symm_op_1xtf32; - // Check the problem size is supported or not + // Check the problem size is supported or not cutlass::Status status_1xtf32 = symm_op_1xtf32.can_implement(arguments_1xtf32); CUTLASS_CHECK(status_1xtf32); @@ -524,7 +524,7 @@ bool 
run(Options &options) { // Instantiate CUTLASS kernel depending on templates Symm_F64 symm_op_f64; - // Check the problem size is supported or not + // Check the problem size is supported or not cutlass::Status status_f64 = symm_op_f64.can_implement(arguments_f64); CUTLASS_CHECK(status_f64); @@ -568,7 +568,7 @@ bool run(Options &options) { static_cast(&beta), static_cast(tensor_d_cublasF32.device_data()), int(tensor_d_cublasF32.layout().stride(0)) - ); + ); cudaDeviceSynchronize(); @@ -576,7 +576,7 @@ bool run(Options &options) { #endif //////////////////////////////////////////////////////////////////////////////// - /// 7. Compute l2 norms + /// 7. Compute l2 norms //////////////////////////////////////////////////////////////////////////////// #if CUTLASS_ENABLE_CUBLAS @@ -605,20 +605,20 @@ bool run(Options &options) { double l2_norm_3xtf32_vs_cublasf32 = cutlass::reference::host::TensorRelativeErrorMetric( tensor_d_3xTF32.host_view(), tensor_d_cublasF32.host_view()); #endif - + // l2 norm 3xTF32 vs 1xTF32 double l2_norm_3xtf32_vs_1xtf32 = cutlass::reference::host::TensorRelativeErrorMetric( tensor_d_3xTF32.host_view(), tensor_d_1xTF32.host_view()); /////////////////////////////////////////////////////////////////////////////// - // Print kernel info and L2 norms + // Print kernel info and L2 norms std::cout << "Problem Size: (" << problem_size.m() << "," << problem_size.n() << "," << problem_size.k() << ") " << "Alpha: " << alpha << "," << " Beta: " << beta << std::endl; std::cout << std::fixed; std::cout << "Normalized L2 norm of" << std::endl; std::cout.precision(8); - std::cout << std::scientific + std::cout << std::scientific #if CUTLASS_ENABLE_CUBLAS << " - cuBLAS F32 error with F64 reference : " << l2_norm_cublasf32_vs_f64 << std::endl #endif @@ -633,11 +633,11 @@ bool run(Options &options) { } int main(int argc, const char **argv) { - + bool notSupported = false; // Ampere Tensor Core operations exposed with mma.sync and ldmatrix are first available - // in CUDA 11.0. + // in CUDA 11.0. // // CUTLASS must be compiled with CUDA 11.0 Toolkit to run these examples. 
if (!(__CUDACC_VER_MAJOR__ >= 11)) { @@ -650,7 +650,7 @@ int main(int argc, const char **argv) { cudaError_t error = cudaGetDeviceProperties(&props, 0); if (error != cudaSuccess) { std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; - return false; + return -1; } if (!((props.major * 10 + props.minor) >= 80)) { diff --git a/include/cute/algorithm/cooperative_copy.hpp b/include/cute/algorithm/cooperative_copy.hpp index c0337aba7b..7873071084 100644 --- a/include/cute/algorithm/cooperative_copy.hpp +++ b/include/cute/algorithm/cooperative_copy.hpp @@ -71,85 +71,103 @@ cooperative_copy(uint32_t const& tid, // Precondition on tid in DEBUG assert(tid < NumThreads); - // Precondition on pointer alignment in DEBUG - assert(is_byte_aligned(raw_pointer_cast(src.data()))); - assert(is_byte_aligned(raw_pointer_cast(dst.data()))); - // - // Determine val+thr vectorization based on src/dst size and number of threads - // NOTE: This heuristic promotes parallelization over vectorization - // - - constexpr int elem_bits = sizeof_bits_v; - - // The number of elements that can be vectorized in values - constexpr int common_elem = decltype(max_common_vector(src, dst))::value; - constexpr int common_bits = common_elem * elem_bits; - constexpr int total_elem = decltype(size(src))::value; - constexpr int total_bits = total_elem * elem_bits; - static_assert(total_bits % NumThreads == 0); - constexpr int total_bits_per_thr = total_bits / NumThreads; - // If there are too many threads to allow a full elem copy, trunc the thrs and use elem_bits - constexpr int max_vec_bits_by_thr = cute::max(elem_bits, total_bits_per_thr); - - // Cap the vectorization to the common bits, the max_vec_bits_by_thr, and the MaxVecBits - constexpr int vec_bits = cute::min(common_bits, max_vec_bits_by_thr, static_cast(MaxVecBits)); - // Convert back to number of elements, safe_div - static_assert((vec_bits % elem_bits) == 0); - constexpr int vec_elem = vec_bits / elem_bits; - - // Use only part of threads if there's not enough work for all threads - constexpr int vec_thrs = (total_elem % (vec_elem * NumThreads) == 0) - ? 
NumThreads - : (total_elem / vec_elem); - - // The common layout of the two tensors that can be vectorized over threads - // vidx -> coord - auto common_layout = max_common_layout(get_nonswizzle_portion(src.layout()), - get_nonswizzle_portion(dst.layout())); - - // Scale up the common_layout to cover the entire tensors - // vidx -> coord - auto full_perm = tile_to_shape(make_layout(common_layout), size(src)); - - // Create the Tiler - // ((vid,tid),iter) - auto layout_vt = logical_divide(full_perm, Layout, Int>>{}); - - // Apply and slice - Tensor src_v = src.compose(layout_vt)(make_coord(_,tid),_); - Tensor dst_v = dst.compose(layout_vt)(make_coord(_,tid),_); - - // Should account for vec_bits < 8 and/or vec_elem <= 1 - // And also account for subbyte types, which could cause race conditions - // Want to ENFORCE sufficient vectorization in those cases - static_assert((vec_bits >= 8), "No support for subbyte copying"); - using VecType = uint_bit_t; + + // Fallback - slow path, naive copy, vectorization disabled + if constexpr(size(SrcLayout{}) % NumThreads != 0) { + int index = static_cast(tid); + CUTE_UNROLL + for(int i = 0; i < ceil_div(size(SrcLayout{}), NumThreads); i++) { + if(index < size(SrcLayout{})) { + dst[index] = src[index]; + } + index += NumThreads; + } + } else { + // Fast path with vectorization + + // Precondition on pointer alignment in DEBUG + assert(is_byte_aligned(raw_pointer_cast(src.data()))); + assert(is_byte_aligned(raw_pointer_cast(dst.data()))); + constexpr int elem_bits = sizeof_bits_v; + + // + // Determine val+thr vectorization based on src/dst size and number of threads + // NOTE: This heuristic promotes parallelization over vectorization + // + + // The number of elements that can be vectorized in values + constexpr int common_elem = decltype(max_common_vector(src, dst))::value; + constexpr int common_bits = common_elem * elem_bits; + constexpr int total_elem = decltype(size(src))::value; + constexpr int total_bits = total_elem * elem_bits; + static_assert(total_bits % NumThreads == 0); + constexpr int total_bits_per_thr = total_bits / NumThreads; + // If there are too many threads to allow a full elem copy, trunc the thrs and use elem_bits + constexpr int max_vec_bits_by_thr = cute::max(elem_bits, total_bits_per_thr); + + // Cap the vectorization to the common bits, the max_vec_bits_by_thr, and the MaxVecBits + constexpr int vec_bits = cute::min(common_bits, max_vec_bits_by_thr, static_cast(MaxVecBits)); + // Convert back to number of elements, safe_div + static_assert((vec_bits % elem_bits) == 0); + constexpr int vec_elem = vec_bits / elem_bits; + + // Use only part of threads if there's not enough work for all threads + constexpr int vec_thrs = (total_elem % (vec_elem * NumThreads) == 0) + ? 
NumThreads + : (total_elem / vec_elem); + static_assert(vec_thrs <= NumThreads); + + // The common layout of the two tensors that can be vectorized over threads + // vidx -> coord + auto common_layout = max_common_layout(get_nonswizzle_portion(src.layout()), + get_nonswizzle_portion(dst.layout())); + + // Scale up the common_layout to cover the entire tensors + // vidx -> coord + auto full_perm = tile_to_shape(make_layout(common_layout), size(src)); + + // Create the Tiler + // ((vid,tid),iter) + auto layout_vt = logical_divide(full_perm, Layout, Int>>{}); + + // Apply and slice + Tensor src_v = src.compose(layout_vt)(make_coord(_,tid),_); + Tensor dst_v = dst.compose(layout_vt)(make_coord(_,tid),_); + + // Should account for vec_bits < 8 and/or vec_elem <= 1 + // And also account for subbyte types, which could cause race conditions + // Want to ENFORCE sufficient vectorization in those cases + static_assert((vec_bits >= 8), "No support for subbyte copying"); + using VecType = uint_bit_t; #if 0 - if (thread0()) { - print(" "); print("NumThreads: "); print(NumThreads); print("\n"); - print(" "); print("src: "); print(src); print("\n"); - print(" "); print("dst: "); print(dst); print("\n"); - print(" "); print("common_layout: "); print(common_layout); print("\n"); - print(" "); print("full_perm: "); print(full_perm); print("\n"); - print(" "); print("Used vector: "); print(vec_elem); print("\n"); - print(" "); print("Used threads: "); print(vec_thrs); print("\n"); - print(" "); print("layout_vt: "); print(layout_vt); print("\n"); - print(" "); print("src.compose(layout_vt): "); print(src.compose(layout_vt)); print("\n"); - print(" "); print("dst.compose(layout_vt): "); print(dst.compose(layout_vt)); print("\n"); - print(" "); print("src_v: "); print(src_v); print("\n"); - print(" "); print("dst_v: "); print(dst_v); print("\n"); - print(" "); print("recast(src_v): "); print(recast(src_v)); print("\n"); - print(" "); print("recast(dst_v): "); print(recast(dst_v)); print("\n"); - } + if (thread0()) { + print(" "); print("cooperative_copy -- vec\n"); + print(" "); print("NumThreads: "); print(NumThreads); print("\n"); + print(" "); print("MaxVecBits: "); print(MaxVecBits); print("\n"); + print(" "); print("src: "); print(src); print("\n"); + print(" "); print("dst: "); print(dst); print("\n"); + print(" "); print("common_layout: "); print(common_layout); print("\n"); + print(" "); print("full_perm: "); print(full_perm); print("\n"); + print(" "); print("Used vector: "); print(vec_elem); print("\n"); + print(" "); print("Used threads: "); print(vec_thrs); print("\n"); + print(" "); print("layout_vt: "); print(layout_vt); print("\n"); + print(" "); print("src.compose(layout_vt): "); print(src.compose(layout_vt)); print("\n"); + print(" "); print("dst.compose(layout_vt): "); print(dst.compose(layout_vt)); print("\n"); + print(" "); print("src_v: "); print(src_v); print("\n"); + print(" "); print("dst_v: "); print(dst_v); print("\n"); + print(" "); print("recast(src_v): "); print(recast(src_v)); print("\n"); + print(" "); print("recast(dst_v): "); print(recast(dst_v)); print("\n"); + } #ifdef __CUDA_ARCH__ - __syncthreads(); + __syncthreads(); #endif #endif - // If we're using all threads (static) or the tid is in in-range (dynamic) - if (vec_thrs >= NumThreads or tid < vec_thrs) { - return copy_if(TrivialPredTensor{}, recast(src_v), recast(dst_v)); + // If we're using all threads (static) or the tid is in in-range (dynamic) + if (vec_thrs >= NumThreads or tid < vec_thrs) { + return 
copy_if(TrivialPredTensor{}, recast(src_v), recast(dst_v)); + } } } diff --git a/include/cute/algorithm/cooperative_gemm.hpp b/include/cute/algorithm/cooperative_gemm.hpp index 32cec54b92..b83881590b 100644 --- a/include/cute/algorithm/cooperative_gemm.hpp +++ b/include/cute/algorithm/cooperative_gemm.hpp @@ -35,6 +35,7 @@ #include +#include #include #include @@ -44,40 +45,37 @@ namespace cute { // -// Collective Shared-Memory GEMMs +// Cooperative Shared-Memory GEMMs // +namespace detail { + +// Predicated Cooperative GEMM template ::value && BLayout::rank == 2 && is_smem::value && CLayout::rank == 2 && is_smem::value)> CUTE_HOST_DEVICE void -cooperative_gemm(ThrMMA const& thr_mma, - Alpha const& alpha, - Tensor sA, - Tensor sB, - Beta const& beta, - Tensor sC, - ALoadTransformOp const& sA_load_op /* transforms A values before used in GEMM */, - BLoadTransformOp const& sB_load_op /* transforms B values before used in GEMM */) +cooperative_gemm_predication(ThrMMA const& thr_mma, + Alpha const& alpha, + Tensor sA, + Tensor sB, + Beta const& beta, + Tensor sC, + ALoadTransformOp const& sA_load_op, // transforms A values before use in GEMM + BLoadTransformOp const& sB_load_op, // transforms B values before use in GEMM + CLoadTransformOp const& sC_load_op, // transforms C values before use in GEMM + CStoreTransformOp const& sC_store_op) // transforms results before they are stored to C { - CUTE_STATIC_ASSERT_V(size<0>(sA) == size<0>(sC)); // AM == CM - CUTE_STATIC_ASSERT_V(size<0>(sB) == size<1>(sC)); // BN == CN - CUTE_STATIC_ASSERT_V(size<1>(sA) == size<1>(sB)); // AK == BK - using TypeA = typename TA::value_type; using TypeB = typename TB::value_type; using TypeC = typename TC::value_type; - static_assert(is_same_v>, TypeA>, - "ALoadTransformOp functor must accept and return value of type TA::value_type"); - static_assert(is_same_v>, TypeB>, - "BLoadTransformOp functor must accept and return value of type TB::value_type"); - // Original, static size of the problem auto M = size<0>(sC); auto N = size<1>(sC); @@ -88,39 +86,14 @@ cooperative_gemm(ThrMMA const& thr_mma, auto BLK_N = tile_size<1>(thr_mma); auto BLK_K = tile_size<2>(thr_mma); - // Compute the "residues" - auto m_residue = M - BLK_M * (ceil_div(M, BLK_M) - Int<1>{}); // (0,BLK_M] - auto n_residue = N - BLK_N * (ceil_div(N, BLK_N) - Int<1>{}); // (0,BLK_N] - auto k_residue = K - BLK_K * (ceil_div(K, BLK_K) ); // (-BLK_K,0] - - // Shift the origin so k_residue is zeroth tile - sA.data() = &sA(0,k_residue); - sB.data() = &sB(0,k_residue); - -#if 0 - if (thread0()) { - printf("%d in BLK_M (%d)\n", int(m_residue), int(BLK_M)); - printf("%d in BLK_N (%d)\n", int(n_residue), int(BLK_N)); - printf("%d in BLK_K (%d)\n", int(k_residue), int(BLK_K)); - } -#endif - // // MMA Partitioning // - // Round the layout extents up to BLK_X - Tensor rounded_sA = sA.compose(make_shape(ceil_div(M, BLK_M) * BLK_M, ceil_div(K, BLK_K) * BLK_K)); - Tensor rounded_sB = sB.compose(make_shape(ceil_div(N, BLK_N) * BLK_N, ceil_div(K, BLK_K) * BLK_K)); - Tensor rounded_sC = sC.compose(make_shape(ceil_div(M, BLK_M) * BLK_M, ceil_div(N, BLK_N) * BLK_N)); - -#if 0 - if (thread0()) { - print("rounded_sA: "); print(rounded_sA); print("\n"); - print("rounded_sB: "); print(rounded_sB); print("\n"); - print("rounded_sC: "); print(rounded_sC); print("\n"); - } -#endif + // Round the layout extents up to BLK_X to satisfy MMA partitioning safety + Tensor rounded_sA = sA.compose(make_shape(round_up(M, BLK_M), round_up(K, BLK_K))); + Tensor rounded_sB = 
sB.compose(make_shape(round_up(N, BLK_N), round_up(K, BLK_K))); + Tensor rounded_sC = sC.compose(make_shape(round_up(M, BLK_M), round_up(N, BLK_N))); // Partition the sA and sB tiles across the threads for the MMA Tensor tCsA = thr_mma.partition_A(rounded_sA); // (MMA,MMA_M,MMA_K) @@ -133,6 +106,13 @@ cooperative_gemm(ThrMMA const& thr_mma, #if 0 if (thread0()) { + print(" sA: "); print( sA); print("\n"); + print(" sB: "); print( sB); print("\n"); + print(" sC: "); print( sC); print("\n"); + print("r_sA: "); print(rounded_sA); print("\n"); + print("r_sB: "); print(rounded_sB); print("\n"); + print("r_sC: "); print(rounded_sC); print("\n"); + print(thr_mma); print("tCsA: "); print(tCsA); print("\n"); print("tCsB: "); print(tCsB); print("\n"); print("tCsC: "); print(tCsC); print("\n"); @@ -146,56 +126,59 @@ cooperative_gemm(ThrMMA const& thr_mma, // PREDICATION // - // Allocate the preds for only the MMA-mode of tCsA and tCsB - Tensor tCpA = make_tensor(size<0>(tCsA)); - Tensor tCpB = make_tensor(size<0>(tCsB)); - - // Create coordinate tensors on a single compute block for predication - Tensor cA = make_identity_tensor(make_shape(BLK_M, BLK_K)); // (BLK_M,BLK_K) -> (blk_m,blk_k) - Tensor cB = make_identity_tensor(make_shape(BLK_N, BLK_K)); // (BLK_M,BLK_K) -> (blk_n,blk_k) + // Create coordinate tensors for the problem + Tensor cA = make_identity_tensor(shape(rounded_sA)); // (M,K) -> (m,k) + Tensor cB = make_identity_tensor(shape(rounded_sB)); // (N,K) -> (n,k) // Repeat partitioning with thr_mma - Tensor tCcA = thr_mma.partition_A(cA); // (MMA,1,1) -> (blk_m,blk_k) - Tensor tCcB = thr_mma.partition_B(cB); // (MMA,1,1) -> (blk_n,blk_k) + Tensor tCcA = thr_mma.partition_A(cA); // (MMA,MMA_M,MMA_K) -> (m,k) + Tensor tCcB = thr_mma.partition_B(cB); // (MMA,MMA_N,MMA_K) -> (n,k) - // Populate the m and n predicates + // Allocate the preds for MMA- and MMA_MN-modes + Tensor tCpA = make_tensor(make_shape(size<0>(tCsA), size<1>(tCsA))); + Tensor tCpB = make_tensor(make_shape(size<0>(tCsB), size<1>(tCsB))); + + // Populate the predicates on M and N CUTE_UNROLL for (int i = 0; i < size(tCpA); ++i) { - tCpA(i) = elem_less(get<0>(tCcA(i)), m_residue); + tCpA(i) = elem_less(get<0>(tCcA(_,_,Int<0>{})(i)), shape<0>(sA)); } CUTE_UNROLL for (int i = 0; i < size(tCpB); ++i) { - tCpB(i) = elem_less(get<0>(tCcB(i)), n_residue); + tCpB(i) = elem_less(get<0>(tCcB(_,_,Int<0>{})(i)), shape<0>(sB)); } #if 0 - printf("Thr %d: A(%d,%d):%d B(%d,%d):%d\n", - threadIdx.x, - int(get<0>(tCcA(0))), int(get<1>(tCcA(0))), int(tCpA(0)), - int(get<0>(tCcB(0))), int(get<1>(tCcB(0))), int(tCpB(0))); + if (thread0()) { + print(" cA: "); print( cA); print("\n"); + print(" cB: "); print( cB); print("\n"); + print("tCcA: "); print(tCcA); print("\n"); + print("tCcB: "); print(tCcB); print("\n"); + print_tensor(tCpA); + print_tensor(tCpB); + } #endif // - // PREFETCH k_block = 0 (with k-predication) + // PREFETCH k_block = 0 + // Condition the k-predication on (static) k_block == K_BLOCK_MAX-1, the last k_block + // Assumes the MMA-tiling in K is trivial // + constexpr int K_BLOCK_MAX = size<2>(tCrA); + CUTE_UNROLL - for (int i = 0; i < size<0>(tCsA); ++i) { // Copy MMA_I - if (k_residue == 0 || get<1>(tCcA(i)) >= -k_residue) { // k_block = 0, predicated on k - CUTE_UNROLL - for (int m = 0; m < size<1>(tCsA); ++m) { // Copy MMA_M, predicated on m - tCrA(i,m,0) = (m_residue == BLK_M || m < size<1>(tCsA)-1 || tCpA(i)) ? 
sA_load_op(tCsA(i,m,0)) : TypeA{}; - } + for (int m = 0; m < size<1>(tCrA); ++m) { // Copy MMA_M + CUTE_UNROLL + for (int i = 0; i < size<0>(tCrA); ++i) { // Copy MMA_I + tCrA(i,m,0) = (tCpA(i,m) && (0 < K_BLOCK_MAX-1 || elem_less(get<1>(tCcA(i,m,0)), shape<1>(sA)))) ? sA_load_op(tCsA(i,m,0)) : TypeA{}; } } - CUTE_UNROLL - for (int i = 0; i < size<0>(tCsB); ++i) { // Copy MMA_I - if (k_residue == 0 || get<1>(tCcB(i)) >= -k_residue) { // k_block = 0, predicated on k - CUTE_UNROLL - for (int n = 0; n < size<1>(tCsB); ++n) { // Copy MMA_N, predicated on n - tCrB(i,n,0) = (n_residue == BLK_N || n < size<1>(tCsB)-1 || tCpB(i)) ? sB_load_op(tCsB(i,n,0)) : TypeB{}; - } + for (int n = 0; n < size<1>(tCrB); ++n) { // Copy MMA_N + CUTE_UNROLL + for (int i = 0; i < size<0>(tCrB); ++i) { // Copy MMA_I + tCrB(i,n,0) = (tCpB(i,n) && (0 < K_BLOCK_MAX-1 || elem_less(get<1>(tCcB(i,n,0)), shape<1>(sB)))) ? sB_load_op(tCsB(i,n,0)) : TypeB{}; } } // @@ -205,34 +188,31 @@ cooperative_gemm(ThrMMA const& thr_mma, // Clear accumulators clear(tCrC); - constexpr int K_BLOCK_MAX = size<2>(tCrA); - CUTE_UNROLL for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block) { - // static-if load the next k_block. No k-predication required on these loads. - if (k_block < K_BLOCK_MAX-1) + if (k_block < K_BLOCK_MAX-1) // static-if not the last k_block { - // Load the next k_block - int k_next = k_block + 1; + int k_next = k_block + 1; // Load k_next block + + // Condition the k-predication on (static) k_block == K_BLOCK_MAX-1, the last k_block + // Assumes the MMA-tiling in K is trivial CUTE_UNROLL - for (int m = 0; m < size<1>(tCsA); ++m) { // Copy MMA_M + for (int m = 0; m < size<1>(tCrA); ++m) { // Copy MMA_M CUTE_UNROLL - for (int i = 0; i < size<0>(tCsA); ++i) { // Copy_if MMA_I predicated on m - tCrA(i,m,k_next) = (m_residue == BLK_M || m < size<1>(tCsA)-1 || tCpA(i)) ? sA_load_op(tCsA(i,m,k_next)) : TypeA{}; + for (int i = 0; i < size<0>(tCrA); ++i) { // Copy MMA_I + tCrA(i,m,k_next) = (tCpA(i,m) && (k_next < K_BLOCK_MAX-1 || elem_less(get<1>(tCcA(i,m,k_next)), shape<1>(sA)))) ? sA_load_op(tCsA(i,m,k_next)) : TypeA{}; } } - CUTE_UNROLL - for (int n = 0; n < size<1>(tCsB); ++n) { // Copy MMA_N + for (int n = 0; n < size<1>(tCrB); ++n) { // Copy MMA_N CUTE_UNROLL - for (int i = 0; i < size<0>(tCsB); ++i) { // Copy MMA_I predicated on n - tCrB(i,n,k_next) = (n_residue == BLK_N || n < size<1>(tCsB)-1 || tCpB(i)) ? sB_load_op(tCsB(i,n,k_next)) : TypeB{}; + for (int i = 0; i < size<0>(tCrB); ++i) { // Copy MMA_I + tCrB(i,n,k_next) = (tCpB(i,n) && (k_next < K_BLOCK_MAX-1 || elem_less(get<1>(tCcB(i,n,k_next)), shape<1>(sB)))) ? 
sB_load_op(tCsB(i,n,k_next)) : TypeB{}; } } } - // GEMM on k_block in registers gemm(thr_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC); } @@ -241,73 +221,288 @@ cooperative_gemm(ThrMMA const& thr_mma, // Epilogue // - Tensor cC = make_identity_tensor(make_shape(BLK_M, BLK_N)); // (BLK_M,BLK_N) -> (blk_m,blk_n) - Tensor tCcC = thr_mma.partition_C(cC); // (MMA, 1, 1) -> (blk_m,blk_n) + // Create coordinate tensors for the problem + Tensor cC = make_identity_tensor(shape(rounded_sC)); // (M,N) -> (m,n) + // Repeat partitioning with thr_mma + Tensor tCcC = thr_mma.partition_C(cC); // (MMA,MMA_M,MMA_N) -> (m,n) const bool isBetaZero = (beta == Beta{}); // Custom axpby_if for now CUTE_UNROLL - for (int m = 0; m < size<1>(tCsC); ++m) + for (int i = 0; i < size(tCrC); ++i) { - CUTE_UNROLL - for (int n = 0; n < size<2>(tCsC); ++n) + if (elem_less(tCcC(i), shape(sC))) { - CUTE_UNROLL - for (int i = 0; i < size<0>(tCsC); ++i) - { - if ((m_residue == BLK_M || m < size<1>(tCrC)-1 || get<0>(tCcC(i)) < m_residue) && - (n_residue == BLK_N || n < size<2>(tCrC)-1 || get<1>(tCcC(i)) < n_residue)) - { - tCsC(i,m,n) = isBetaZero ? alpha * static_cast(tCrC(i,m,n)) : alpha * static_cast(tCrC(i,m,n)) + beta * static_cast(tCsC(i,m,n)); - } - } + tCsC(i) = sC_store_op(isBetaZero ? alpha * static_cast(tCrC(i)) + : alpha * static_cast(tCrC(i)) + + beta * static_cast(sC_load_op(tCsC(i)))); } } } +// Slow fallback path template ::value && + BLayout::rank == 2 && is_smem::value && + CLayout::rank == 2 && is_smem::value)> +CUTE_HOST_DEVICE +void +cooperative_gemm_predication(uint32_t thread_idx, + TiledMMA const& tiled_mma, + Alpha const& alpha, + Tensor sA, + Tensor sB, + Beta const& beta, + Tensor sC, + ALoadTransformOp const& sA_load_op, // transforms A values before use in GEMM + BLoadTransformOp const& sB_load_op, // transforms B values before use in GEMM + CLoadTransformOp const& sC_load_op, // transforms C values before use in GEMM + CStoreTransformOp const& sC_store_op) // transforms results before they are stored to C +{ + // ThrMMA + auto thr_mma = tiled_mma.get_thread_slice(thread_idx); + cooperative_gemm_predication(thr_mma, alpha, sA, sB, beta, sC, sA_load_op, sB_load_op, sC_load_op, sC_store_op); +} + +// Unpredicated Cooperative GEMM +template ::value && + BLayout::rank == 2 && is_smem::value && + CLayout::rank == 2 && is_smem::value)> +CUTE_HOST_DEVICE +void +cooperative_gemm_no_predication(uint32_t thread_idx, + TiledMMA const& tiled_mma, + Alpha const& alpha, + Tensor sA, + Tensor sB, + Beta const& beta, + Tensor sC, + ALoadTransformOp const& sA_load_op, // transforms A values before use in GEMM + BLoadTransformOp const& sB_load_op, // transforms B values before use in GEMM + CLoadTransformOp const& sC_load_op, // transforms C values before use in GEMM + CStoreTransformOp const& sC_store_op) // transforms results before they are stored to C +{ + using TypeA = typename TA::value_type; + using TypeB = typename TB::value_type; + using TypeC = typename TC::value_type; + + // ThrMMA + auto thr_mma = tiled_mma.get_thread_slice(thread_idx); + + // + // MMA Partitioning + // + + Tensor tCsC = thr_mma.partition_C(sC); + // Create register tensors for the MMA to operate on + Tensor tCrA = thr_mma.partition_fragment_A(sA); // (MMA,MMA_M,MMA_K) + Tensor tCrB = thr_mma.partition_fragment_B(sB); // (MMA,MMA_N,MMA_K) + Tensor tCrC = thr_mma.make_fragment_C(tCsC); // (MMA,MMA_M,MMA_N) + + using CopyOpAType = SmemCopyOpA; + using CopyOpBType = SmemCopyOpB; + + auto smem_tiled_copy_A = make_tiled_copy_A(Copy_Atom{}, 
thr_mma); + auto smem_thr_copy_A = smem_tiled_copy_A.get_thread_slice(thread_idx); + Tensor tCsA = smem_thr_copy_A.partition_S(sA); + Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA); + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view)); // CPY_M + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCrA_copy_view)); // CPY_K + + auto smem_tiled_copy_B = make_tiled_copy_B(Copy_Atom{}, thr_mma); + auto smem_thr_copy_B = smem_tiled_copy_B.get_thread_slice(thread_idx); + Tensor tCsB = smem_thr_copy_B.partition_S(sB); + Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB); + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // CPY_N + CUTE_STATIC_ASSERT_V(size<2>(tCsB) == size<2>(tCrB_copy_view)); // CPY_K + +#if 0 + if (thread0()) { + print(" sA: "); print(sA); print("\n"); + print(" sB: "); print(sB); print("\n"); + print(" sC: "); print(sC); print("\n"); + print(thr_mma); print("\n"); + print("tCsC: "); print(tCsC); print("\n"); + print("tCrA: "); print(tCrA); print("\n"); + print("tCrB: "); print(tCrB); print("\n"); + print("tCrC: "); print(tCrC); print("\n"); + print(smem_thr_copy_A); print("\n"); + print("tCsA: "); print(tCsA); print("\n"); + print("tCrA_copy_view: "); print(tCrA_copy_view); print("\n"); + print(smem_thr_copy_B); print("\n"); + print("tCsB: "); print(tCsB); print("\n"); + print("tCrB_copy_view: "); print(tCrB_copy_view); print("\n"); + } +#endif + + // + // PREFETCH + // + + copy(smem_tiled_copy_A, tCsA(_,_,Int<0>{}), tCrA_copy_view(_,_,Int<0>{})); + copy(smem_tiled_copy_B, tCsB(_,_,Int<0>{}), tCrB_copy_view(_,_,Int<0>{})); + // + // MAINLOOP + // + + // Clear accumulators + clear(tCrC); + + constexpr int K_BLOCK_MAX = size<2>(tCrA); + + CUTE_UNROLL + for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block) + { + // static-if load the next k_block. No k-predication required on these loads. 
+ if (k_block < K_BLOCK_MAX-1) + { + // Load the next k_block + int k_next = k_block + 1; // statically unrolled + copy(smem_tiled_copy_A, tCsA(_,_,k_next), tCrA_copy_view(_,_,k_next)); + copy(smem_tiled_copy_B, tCsB(_,_,k_next), tCrB_copy_view(_,_,k_next)); + } + + // Transform A and B, relying on the compiler to remove in case of identity ops + cute::transform(tCrA(_,_,k_block), sA_load_op); + cute::transform(tCrB(_,_,k_block), sB_load_op); + + // GEMM on k_block in registers + gemm(thr_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC); + } + + // + // Epilogue + // + + auto isBetaZero = [&] () { + if constexpr (is_complex::value) { + return beta.real() == Int<0>{} && beta.imag() == Int<0>{}; + } + else { + return beta == Int<0>{}; + } + CUTE_GCC_UNREACHABLE; + } (); + + using CopyOpCType = SmemCopyOpC; + Tensor tCrD = thr_mma.make_fragment_C(tCsC); + if(!isBetaZero) { + copy(CopyOpCType{}, tCsC, tCrD); + // Transform C on/after load + cute::transform(tCrD, sC_load_op); + } + // C = alpha * (A * B) + beta * C + axpby(alpha, tCrC, beta, tCrD); + // Transform C before/on store + cute::transform(tCrD, sC_store_op); + copy(CopyOpCType{}, tCrD, tCsC); +} + +} // end namespace detail + +template ::value && BLayout::rank == 2 && is_smem::value && CLayout::rank == 2 && is_smem::value)> CUTE_HOST_DEVICE void -cooperative_gemm(ThrMMA const& thr_mma, +cooperative_gemm(uint32_t thread_idx, + TiledMMA const& tiled_mma, Alpha const& alpha, Tensor sA, Tensor sB, Beta const& beta, - Tensor sC) + Tensor sC, + ALoadTransformOp const& sA_load_op = {}, // transforms A values before use in GEMM + BLoadTransformOp const& sB_load_op = {}, // transforms B values before use in GEMM + CLoadTransformOp const& sC_load_op = {}, // transforms C values before use in GEMM + CStoreTransformOp const& sC_store_op = {}) // transforms results before they are stored to C { - cooperative_gemm(thr_mma, alpha, sA, sB, beta, sC, identity() /* sA_load_op */, identity() /* sB_load_op */); + CUTE_STATIC_ASSERT_V(size<0>(sA) == size<0>(sC)); // AM == CM + CUTE_STATIC_ASSERT_V(size<0>(sB) == size<1>(sC)); // BN == CN + CUTE_STATIC_ASSERT_V(size<1>(sA) == size<1>(sB)); // AK == BK + + using TypeA = typename TA::value_type; + using TypeB = typename TB::value_type; + using TypeC = typename TC::value_type; + + static_assert(is_convertible_v>, TypeA>, + "ALoadTransformOp functor must accept value of type TA::value_type and return value convertible to type TA::value_type"); + static_assert(is_convertible_v>, TypeB>, + "BLoadTransformOp functor must accept value of type TB::value_type and return value convertible to type TB::value_type"); + static_assert(is_convertible_v>, TypeC>, + "CLoadTransformOp functor must accept value of type TC::value_type and return value convertible to type TC::value_type"); + static_assert(is_convertible_v>, TypeC>, + "CStoreTransformOp functor must accept value of type TC::value_type and return value convertible to type TC::value_type"); + + static constexpr bool compat = weakly_compatible(tile_shape(TiledMMA{}), + make_shape(size<0>(sA), size<0>(sB), size<1>(sA))); + if constexpr (compat) { + detail::cooperative_gemm_no_predication( + thread_idx, tiled_mma, alpha, sA, sB, beta, sC, + sA_load_op, sB_load_op, sC_load_op, sC_store_op + ); + } else { + detail::cooperative_gemm_predication( + thread_idx, tiled_mma, alpha, sA, sB, beta, sC, + sA_load_op, sB_load_op, sC_load_op, sC_store_op + ); + } } template ::value && BLayout::rank == 2 && is_smem::value && CLayout::rank == 2 && is_smem::value)> CUTE_HOST_DEVICE 
void -gemm(ThrMMA const& thr_mma, - Alpha const& alpha, - Tensor sA, - Tensor sB, - Beta const& beta, - Tensor sC, - ALoadTransformOp const& sA_load_op /* transforms A values before used in GEMM */, - BLoadTransformOp const& sB_load_op /* transforms B values before used in GEMM */) +cooperative_gemm(uint32_t thread_idx, + TiledMMA const& tiled_mma, + Alpha const& alpha, + Tensor sA, + Tensor sB, + Beta const& beta, + Tensor sC, + ALoadTransformOp const& sA_load_op = {}, // transforms A values before use in GEMM + BLoadTransformOp const& sB_load_op = {}, // transforms B values before use in GEMM + CLoadTransformOp const& sC_load_op = {}, // transforms C values before use in GEMM + CStoreTransformOp const& sC_store_op = {}) // transforms results before they are stored to C { - cooperative_gemm(thr_mma, alpha, sA, sB, beta, sC, sA_load_op, sB_load_op); + using CopyOpA = AutoVectorizingCopyWithAssumedAlignment>; + using CopyOpB = AutoVectorizingCopyWithAssumedAlignment>; + using CopyOpC = AutoVectorizingCopyWithAssumedAlignment>; + cooperative_gemm( + thread_idx, tiled_mma, alpha, sA, sB, beta, sC, + sA_load_op, sB_load_op, sC_load_op, sC_store_op + ); } +// Legacy overload of cute::gemm for backwards-compatibility template ::value && BLayout::rank == 2 && is_smem::value && CLayout::rank == 2 && is_smem::value)> @@ -318,9 +513,17 @@ gemm(ThrMMA const& thr_mma, Tensor sA, Tensor sB, Beta const& beta, - Tensor sC) + Tensor sC, + ALoadTransformOp const& sA_load_op = {}, // transforms A values before use in GEMM + BLoadTransformOp const& sB_load_op = {}, // transforms B values before use in GEMM + CLoadTransformOp const& sC_load_op = {}, // transforms C values before use in GEMM + CStoreTransformOp const& sC_store_op = {}) // transforms results before they are stored to C { - cooperative_gemm(thr_mma, alpha, sA, sB, beta, sC, identity() /* sA_load_op */, identity() /* sB_load_op */); + // Goes directly to the slow path to avoid getting thread_idx from thr_mma + detail::cooperative_gemm_predication( + thr_mma, alpha, sA, sB, beta, sC, + sA_load_op, sB_load_op, sC_load_op, sC_store_op + ); } } // end namespace cute diff --git a/include/cute/arch/copy_sm50.hpp b/include/cute/arch/copy_sm50.hpp new file mode 100644 index 0000000000..9cf0efcdf5 --- /dev/null +++ b/include/cute/arch/copy_sm50.hpp @@ -0,0 +1,72 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
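For orientation, here is a minimal usage sketch of the reworked `cooperative_gemm` entry points above: the caller now passes an explicit thread index and a `TiledMMA` (rather than a `ThrMMA`), and may optionally supply element-wise load/store transforms that default to identity. The wrapper name, tile shapes, and the single include below are illustrative assumptions, not part of this patch.

```cpp
#include <cute/tensor.hpp>  // assumed to pull in the cooperative_gemm algorithm header

// Hypothetical device-side helper: all threads of a CTA cooperate on one shared-memory
// tile multiply, with sA:(BLK_M,BLK_K), sB:(BLK_N,BLK_K), sC:(BLK_M,BLK_N) in smem.
template <class TiledMma, class TensorA, class TensorB, class TensorC>
__device__ void block_gemm(TiledMma const& tiled_mma,
                           TensorA const& sA, TensorB const& sB, TensorC& sC)
{
  // Default-transform overload: copy atoms and identity transforms are chosen automatically.
  cute::cooperative_gemm(threadIdx.x, tiled_mma, 1.0f, sA, sB, 0.0f, sC);

  // Same call with an explicit A-load transform (e.g. negate A for floating-point types).
  cute::cooperative_gemm(threadIdx.x, tiled_mma, 1.0f, sA, sB, 0.0f, sC,
                         [](auto const& a) { return -a; });  // sA_load_op
}
```

Compared with the legacy `ThrMMA` overload kept at the end of the hunk, this form lets the implementation dispatch between the unpredicated and predicated paths based on tile compatibility.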
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 500 + #define CUTE_ARCH_WARP_SHUFFLE_ENABLED 1 +#endif + +namespace cute +{ + +struct SM50_Shuffle_U32_2x2Trans +{ + using SRegisters = uint32_t[2]; + using DRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, uint32_t& dst0, uint32_t& dst1) + { +#if defined(CUTE_ARCH_WARP_SHUFFLE_ENABLED) + uint32_t x0 = src0; + uint32_t y0 = __shfl_xor_sync(0xffffffff, x0, 1); + + uint32_t x1 = src1; + uint32_t y1 = __shfl_xor_sync(0xffffffff, x1, 1); + + if (threadIdx.x % 2 == 0) { + dst1 = y0; + } + else { + dst0 = y1; + } +#else + CUTE_INVALID_CONTROL_PATH("Trying to use __shfl_xor_sync without CUTE_ARCH_WARP_SHUFFLE_ENABLED."); +#endif + } +}; + + +} // end namespace cute diff --git a/include/cute/arch/util.hpp b/include/cute/arch/util.hpp index 06add577e8..92e342510a 100644 --- a/include/cute/arch/util.hpp +++ b/include/cute/arch/util.hpp @@ -117,7 +117,7 @@ cast_smem_ptr_to_uint(void const* const ptr) uint32_t smem_ptr; asm( - "{ .reg .u64 smem_ptr; cvta.to.shared.u64 smem_ptr, %1; cvt.u32.u64 %0, smem_ptr; }\n" + "{ .reg .u64 smem_ptr; cvta.to.shared.u64 smem_ptr, %1; cvt.u32.u64 %0, smem_ptr; }\n" : "=r"(smem_ptr) : "l"(ptr)); return smem_ptr; @@ -132,11 +132,47 @@ cast_smem_ptr_to_uint(void const* const ptr) #endif } +namespace detail { + // -// Utility for pointer interfaces +// Wrapper for MMAOp::fma // -namespace detail { +template +struct CallFMA { + template + CUTE_HOST_DEVICE constexpr void + operator()(Args&&... args) const { + return MmaOp::fma(static_cast(args)...); + } +}; + +// +// Wrapper for CopyOp::copy +// + +template +struct CallCOPY { + template + CUTE_HOST_DEVICE constexpr void + operator()(Args&&... args) const { + return CopyOp::copy(static_cast(args)...); + } +}; + +// +// Utility for exploding pointers/arrays/tensors into functions +// + +template +CUTE_HOST_DEVICE constexpr +void +explode(Fn fn, + PtrA&& a, int_sequence) +{ + return fn(a[I]...); +} template + class PtrE, int... Ie> CUTE_HOST_DEVICE constexpr void -explode_with_d_scaling(Fn fn, +explode(Fn fn, + PtrD&& d, int_sequence, PtrA&& a, int_sequence, PtrB&& b, int_sequence, PtrC&& c, int_sequence, - ParamType&& p0) + PtrE&& e, int_sequence) { - return fn(a[Ia]..., b[Ib]..., c[Ic]..., p0); + return fn(d[Id]..., a[Ia]..., b[Ib]..., c[Ic]..., e[Ie]...); } template + class PtrD, int... Id, + class PtrA, int... Ia, + class PtrB, int... Ib, + class PtrC, int... Ic, + class PtrSFA, int... Isfa, + class PtrSFB, int... 
Isfb> CUTE_HOST_DEVICE constexpr void -explode_with_d_scaling(Fn fn, - PtrD&& d, int_sequence, - PtrA&& a, int_sequence, - PtrB&& b, int_sequence, - PtrC&& c, int_sequence, - ParamType&& p0) +explode(Fn fn, + PtrD&& d, int_sequence, + PtrA&& a, int_sequence, + PtrB&& b, int_sequence, + PtrC&& c, int_sequence, + PtrSFA&& sfa, int_sequence, + PtrSFB&& sfb, int_sequence) { - return fn(d[Id]..., a[Ia]..., b[Ib]..., c[Ic]..., p0); + return fn(d[Id]..., a[Ia]..., b[Ib]..., c[Ic]..., sfa[Isfa]..., sfb[Isfb]...); } +// +// Utility for exploding tuples into functions +// -} // end namespace detail - -template +template CUTE_HOST_DEVICE constexpr void -explode(Fn fn, PtrS&& s, PtrD&& d) +explode_tuple(Fn fn, + TupleA&& a, int_sequence) { - return detail::explode(fn, - s, make_int_sequence{}, - d, make_int_sequence{}); + return fn(get(a)...); } -template +template CUTE_HOST_DEVICE constexpr void -explode(Fn fn, PtrA&& a, PtrB&& b, PtrC&& c) +explode_tuple(Fn fn, + TupleA&& a, int_sequence, + TupleB&& b, int_sequence) { - return detail::explode(fn, - a, make_int_sequence{}, - b, make_int_sequence{}, - c, make_int_sequence{}); + return fn(get(a)..., get(b)...); } -template +template CUTE_HOST_DEVICE constexpr void -explode(Fn fn, PtrD&& d, PtrA&& a, PtrB&& b, PtrC&& c) +explode_tuple(Fn fn, + TupleA&& a, int_sequence, + TupleB&& b, int_sequence, + TupleC&& c, int_sequence) { - return detail::explode(fn, - d, make_int_sequence{}, - a, make_int_sequence{}, - b, make_int_sequence{}, - c, make_int_sequence{}); + return fn(get(a)..., get(b)..., get(c)...); } +} // end namespace detail + } // end namespace cute diff --git a/include/cute/atom/copy_atom.hpp b/include/cute/atom/copy_atom.hpp index d1cd3d4b71..48a5fd168b 100644 --- a/include/cute/atom/copy_atom.hpp +++ b/include/cute/atom/copy_atom.hpp @@ -756,6 +756,7 @@ print_latex_copy(LayoutS const& S, ThrIDS const& TS, // (m,n) -> (tid,vid) and //////////////////////////////////////////////////////////////////////////////////////////////////// +#include #include #include #include diff --git a/include/cute/atom/copy_traits.hpp b/include/cute/atom/copy_traits.hpp index b6259b589e..2aa3ba5774 100644 --- a/include/cute/atom/copy_traits.hpp +++ b/include/cute/atom/copy_traits.hpp @@ -92,59 +92,6 @@ struct Copy_Traits> using RefLayout = SrcLayout; }; -namespace detail { - -// Utility for exploding pointers, arrays, or tensors into Operation::copy -template -CUTE_HOST_DEVICE constexpr -void -copy_explode_index(PtrSrc&& s, int_sequence, - PtrDst&& d, int_sequence) -{ - return Operation::copy(s[Is]..., d[Id]...); -} - -// Utility for exploding tuples into ::copy -template -CUTE_HOST_DEVICE constexpr -void -copy_explode(TupleArg&& t, int_sequence) -{ - return Operation::copy(get(static_cast(t))...); -} - -template -CUTE_HOST_DEVICE constexpr -void -copy_explode(TupleSrc&& s, int_sequence, - TupleDst&& d, int_sequence) -{ - return Operation::copy(get(static_cast(s))..., - get(static_cast(d))...); -} - -template -CUTE_HOST_DEVICE constexpr -void -copy_explode(TupleAux&& a, int_sequence, - TupleSrc&& s, int_sequence, - TupleDst&& d, int_sequence) -{ - return Operation::copy(get(static_cast(a))..., - get(static_cast(s))..., - get(static_cast(d))...); -} - -} // end namespace detail - // // Generic copy_unpack for common argument-based Copy_Traits // @@ -177,8 +124,9 @@ copy_unpack(Copy_Traits const&, CUTE_STATIC_ASSERT_V(size(rD) == Int{}, "Copy_Traits: dst failed to vectorize into registers. 
Layout is incompatible with this CopyOp."); - detail::copy_explode_index(rS, make_int_sequence{}, - rD, make_int_sequence{}); + detail::explode(detail::CallCOPY{}, + rS, make_int_sequence{}, + rD, make_int_sequence{}); } // diff --git a/include/cute/atom/copy_traits_sm50.hpp b/include/cute/atom/copy_traits_sm50.hpp new file mode 100644 index 0000000000..8be0ef7bba --- /dev/null +++ b/include/cute/atom/copy_traits_sm50.hpp @@ -0,0 +1,58 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include +#include + +#include + +namespace cute +{ + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (one-thread) + using ThrID = Layout<_32>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride<_64, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout,Shape <_32, _2>>, + Stride,Stride< _1, _64>>>; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; +}; + +} // end namespace cute diff --git a/include/cute/atom/copy_traits_sm90_im2col.hpp b/include/cute/atom/copy_traits_sm90_im2col.hpp index 34e71ed665..15d9979c92 100644 --- a/include/cute/atom/copy_traits_sm90_im2col.hpp +++ b/include/cute/atom/copy_traits_sm90_im2col.hpp @@ -73,14 +73,16 @@ struct TMA_LOAD_IM2COL_Unpack CUTE_STATIC_ASSERT_V(rank<1>(src_coord_offset) == rank<3>(src_coord_offset)); if constexpr (detail::is_prefetch) { - return detail::copy_explode(traits.opargs_, tuple_seq{}, - src_coord_cwhdn_offset_srt, tuple_seq{}); + return detail::explode_tuple(detail::CallCOPY{}, + traits.opargs_, tuple_seq{}, + src_coord_cwhdn_offset_srt, tuple_seq{}); } else { static_assert(is_smem::value, "SM90_TMA_LOAD_IM2COL requires the destination be shared memory."); void* dst_ptr = cute::raw_pointer_cast(dst.data()); - return detail::copy_explode(traits.opargs_, tuple_seq{}, - make_tuple(dst_ptr), seq<0>{}, - src_coord_cwhdn_offset_srt, tuple_seq{}); + return detail::explode_tuple(detail::CallCOPY{}, + traits.opargs_, tuple_seq{}, + make_tuple(dst_ptr), seq<0>{}, + src_coord_cwhdn_offset_srt, tuple_seq{}); } } }; @@ -349,8 +351,9 @@ struct Copy_Traits void const* const src_ptr = cute::raw_pointer_cast(src.data()); auto dst_coord = flatten(take<0,3>(dst(Int<0>{}))); - return detail::copy_explode(make_tuple(desc_ptr, src_ptr), seq<0,1>{}, - dst_coord, tuple_seq{}); + return detail::explode_tuple(detail::CallCOPY{}, + make_tuple(desc_ptr, src_ptr), seq<0,1>{}, + dst_coord, tuple_seq{}); } }; @@ -537,26 +540,9 @@ make_im2col_tma_copy_desc( return cute::make_tuple(tma_desc, tma_tensor); } -/// Make a TiledCopy for im2col TMA load. -/// -/// @param copy_op The copy implementation: either -/// SM90_TMA_LOAD_IM2COL or SM90_TMA_LOAD_IM2COL_MULTICAST. -/// -/// @param tensor_cwhdn The global tensor to use for im2col TMA loads. -/// For Fprop convolutions, this is the activation tensor. This is -/// the "original tensor that points to global memory, not the -/// coordinate (im2col-transformed) tensor. -/// -/// @param slayout Layout of shared memory tile. -/// -/// @param stride_whd The traversal strides convolution -/// parameter. -/// -/// @return TiledCopy specialization for im2col TMA loads. 
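To make the intent of the new SM50 warp-shuffle op and its `Copy_Traits` specialization above concrete, here is a rough, hypothetical kernel exercising the raw op: each lane pair `(t, t^1)` exchanges one of its two 32-bit registers via `__shfl_xor_sync`, i.e. a 2x2 transpose per lane pair. Kernel and buffer names are illustrative.

```cpp
#include <cstdint>
#include <cute/arch/copy_sm50.hpp>

// Launch with a multiple of 32 threads so every lane of a warp participates.
__global__ void shuffle_2x2_demo(uint32_t* out)
{
  uint32_t r0 = 2 * threadIdx.x + 0;  // two registers "owned" by this lane
  uint32_t r1 = 2 * threadIdx.x + 1;

  uint32_t d0 = r0;
  uint32_t d1 = r1;
  cute::SM50_Shuffle_U32_2x2Trans::copy(r0, r1, d0, d1);

  // Even lanes receive their partner's r0 into d1; odd lanes receive the partner's
  // r1 into d0. The other destination register is left untouched by the op.
  out[2 * threadIdx.x + 0] = d0;
  out[2 * threadIdx.x + 1] = d1;
}
```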
template CUTE_HOST_RTC auto -make_tma_copy_im2col(CopyOp const& copy_op, - Tensor const& gtensor, - SLayout const& slayout, - Layout const& cta_t_map, // CTA tid -> logical TMA tid - Layout const& cta_v_map, // CTA vid -> gmem coord +make_tma_atom_im2col(CopyOp, + Tensor const& gtensor, // Full GMEM Tensor: ((w, h, d, n), c) + SLayout const& slayout, // CTA Tile of SMEM, potentially swizzled + int32_t const& num_multicast, // The number of CTAs involved in multicasting + Layout const& cta_v_map, // V: CTA val idx -> gmem mode LowerCornerStride const& lower_corner_whd, UpperCornerStride const& upper_corner_whd, LowerPaddingStride const& lower_padding_whd, UpperPaddingStride const& upper_padding_whd, - TraversalStride const& stride_whd, // traversal stride + TraversalStride const& stride_whd, // traversal stride LowerSRTStride const& lower_srt, - DilationStride const& stride_srt) // dilation + DilationStride const& stride_srt) // dilation { // // TMA parameter checking @@ -586,8 +572,6 @@ make_tma_copy_im2col(CopyOp const& copy_op, CUTE_STATIC_ASSERT_V(product_each(shape(slayout)) == product_each(shape(cta_v_map)), "TMA requires CTA_Tile and SLayout top-level shape equivalence."); - CUTE_STATIC_ASSERT_V(size(slayout) % cosize(cta_t_map) == Int<0>{}, - "Number of active CTAs in TMA must divide domain size of slayout."); // // TMA slayout manipulation @@ -632,7 +616,7 @@ make_tma_copy_im2col(CopyOp const& copy_op, auto tma_layout_trunc = take<0,smem_tma_rank>(tma_layout_full); // Split according to the portion each multicast CTA will be responsible for - auto tma_layout_vt = logical_divide(tma_layout_trunc, shape_div(size(tma_layout_trunc), cosize(cta_t_map))); + auto tma_layout_vt = logical_divide(tma_layout_trunc, shape_div(size(tma_layout_trunc), num_multicast)); #if 0 print("glayout_basis : "); print(glayout_basis); print("\n"); @@ -668,41 +652,106 @@ make_tma_copy_im2col(CopyOp const& copy_op, // using T = typename GEngine::value_type; - constexpr int num_bits_per_tma = decltype(size<0>(tma_layout_vt))::value * sizeof(T) * 8; + constexpr int num_bits_per_tma = decltype(size(tma_layout_trunc))::value * sizeof(T) * 8; using Traits = Copy_Traits, decltype(tma_tensor)>; + using Atom = Copy_Atom; #if 0 - print("num_bits : "); print(NumBitsPerTMA{}); print("\n"); + print("num_bits : "); print(num_bits_per_tma); print("\n"); #endif Traits tma_traits{tma_desc, tma_tensor}; + // Return the Copy_Atom + return Atom{tma_traits}; +} + +/// Make a TiledCopy for im2col TMA load. +/// +/// @param copy_op The copy implementation: either +/// SM90_TMA_LOAD_IM2COL or SM90_TMA_LOAD_IM2COL_MULTICAST. +/// +/// @param tensor_cwhdn The global tensor to use for im2col TMA loads. +/// For Fprop convolutions, this is the activation tensor. This is +/// the "original tensor that points to global memory, not the +/// coordinate (im2col-transformed) tensor. +/// +/// @param slayout Layout of shared memory tile. +/// +/// @param stride_whd The traversal strides convolution +/// parameter. +/// +/// @return TiledCopy specialization for im2col TMA loads. 
+template +CUTE_HOST_RTC +auto +make_tma_copy_im2col(CopyOp const& copy_op, + Tensor const& gtensor, + SLayout const& slayout, + Layout const& cta_t_map, // CTA tid -> logical TMA tid + Layout const& cta_v_map, // CTA vid -> gmem coord + LowerCornerStride const& lower_corner_whd, + UpperCornerStride const& upper_corner_whd, + LowerPaddingStride const& lower_padding_whd, + UpperPaddingStride const& upper_padding_whd, + TraversalStride const& stride_whd, // traversal stride + LowerSRTStride const& lower_srt, + DilationStride const& stride_srt) // dilation +{ + // + // TMA parameter checking + // + + CUTE_STATIC_ASSERT_V(size(slayout) % cosize(cta_t_map) == Int<0>{}, + "Number of active CTAs in TMA must divide domain size of slayout."); + + Copy_Atom atom = make_tma_atom_im2col(copy_op, gtensor, slayout, cosize(cta_t_map), cta_v_map, + lower_corner_whd, upper_corner_whd, lower_padding_whd, + upper_padding_whd, stride_whd, lower_srt, stride_srt); + // // Construct the TiledCopy // auto cta_tiler = product_each(shape(cta_v_map)); - // (CTA V, CTA T) -> smem_coord - auto layout_vt = composition(inv_smem_layout, make_layout(shape(tma_layout_vt))); + auto num_elems_per_tma = size<1>(typename decltype(atom)::RefLayout{}) / static_value>(); + + // smem idx -> smem coord + auto inv_smem_layout = right_inverse(get_nonswizzle_portion(slayout)); + // CTA V -> smem_coord + auto layout_v = composition(inv_smem_layout, num_elems_per_tma); // Scale that up to cover all of the smem_coords - // - // The smem vector might not cover all of the tile, - // so multiply it up to cover the entire tile. - // "T" here (the parallel index) is a CTA index. - auto layout_VT = tile_to_shape(layout_vt, make_shape(size(cta_v_map)/size<1>(layout_vt), size<1>(layout_vt))); - // Flip it and change the domain of the T from logical thr to thr_idx - auto layout_TV = make_layout(composition(layout<1>(layout_VT), cta_t_map), layout<0>(layout_VT)); + auto layout_V = tile_to_shape(make_layout(layout_v), size(cta_v_map)); + // CTA T -> smem idx + auto layout_t = make_layout(cosize(cta_t_map), shape_div(num_elems_per_tma, cosize(cta_t_map))); + // CTA TID -> smem coord + auto layout_T = composition(inv_smem_layout, composition(layout_t, cta_t_map)); + // Combine with the T mapping + [[maybe_unused]] auto layout_TV = make_layout(layout_T, layout_V); #if 0 print("cta_tiler : "); print(cta_tiler); print("\n"); - print("layout_VT : "); print(layout_VT); print("\n"); + print("layout_v : "); print(layout_v); print("\n"); + print("layout_V : "); print(layout_V); print("\n"); + print("layout_t : "); print(layout_t); print("\n"); + print("layout_T : "); print(layout_T); print("\n"); print("layout_TV : "); print(layout_TV); print("\n"); #endif - using T = typename GEngine::value_type; - return TiledCopy, decltype(layout_TV), decltype(cta_tiler)>{tma_traits}; + return TiledCopy{atom}; } /// Make a TiledCopy for im2col TMA with no offsets. 
diff --git a/include/cute/atom/copy_traits_sm90_tma.hpp b/include/cute/atom/copy_traits_sm90_tma.hpp index 16b2a648b9..d42c82c915 100644 --- a/include/cute/atom/copy_traits_sm90_tma.hpp +++ b/include/cute/atom/copy_traits_sm90_tma.hpp @@ -69,8 +69,9 @@ struct TMA_LOAD_Unpack { auto src_coord = src.data().coord_; if constexpr (detail::is_prefetch) { - return detail::copy_explode(traits.opargs_, tuple_seq{}, - src_coord, tuple_seq{}); + return detail::explode_tuple(detail::CallCOPY{}, + traits.opargs_, tuple_seq{}, + src_coord, tuple_seq{}); } else { static_assert(is_smem::value, "SM90_TMA_LOAD requires the destination be shared memory."); void* dst_ptr = cute::raw_pointer_cast(dst.data()); @@ -81,9 +82,10 @@ struct TMA_LOAD_Unpack blockIdx.x, blockIdx.y, blockIdx.z, int32_t(c0), int32_t(c1), int32_t(c2), int32_t(c3), int32_t(c4), dst_ptr); #endif - return detail::copy_explode(traits.opargs_, tuple_seq{}, - make_tuple(dst_ptr), seq<0>{}, - src_coord, tuple_seq{}); + return detail::explode_tuple(detail::CallCOPY{}, + traits.opargs_, tuple_seq{}, + make_tuple(dst_ptr), seq<0>{}, + src_coord, tuple_seq{}); } } }; @@ -337,8 +339,9 @@ struct Copy_Traits blockIdx.x, blockIdx.y, blockIdx.z, int32_t(c0), int32_t(c1), int32_t(c2), int32_t(c3), int32_t(c4), src_ptr); #endif - return detail::copy_explode(make_tuple(desc_ptr, src_ptr), seq<0,1>{}, - dst_coord, tuple_seq{}); + return detail::explode_tuple(detail::CallCOPY{}, + make_tuple(desc_ptr, src_ptr), seq<0,1>{}, + dst_coord, tuple_seq{}); } }; @@ -1278,7 +1281,7 @@ tma_partition(Copy_Atom const& copy_atom, // Factor out the single-instrucion portion Layout tma_layout_v = make_layout(Int::NumValSrc>{}); auto layout_V = make_tile(logical_divide(layout_v, tma_layout_v)); - + // Append with _ until we cover all Rest... modes auto glayout_V = append>(layout_V, _); auto slayout_V = append>(layout_V, _); @@ -1288,39 +1291,45 @@ tma_partition(Copy_Atom const& copy_atom, #if 0 if (thread0()) { - print("gtensor : "); print(gtensor); print("\n"); - print("stensor : "); print(stensor); print("\n"); + print("cta_coord : "); print(cta_coord); print("\n"); + print("cta_layout : "); print(cta_layout); print("\n"); + print("gtensor : "); print(gtensor); print("\n"); + print("stensor : "); print(stensor); print("\n"); print("layout_V : "); print(layout_V); print("\n"); print("gtensor_v : "); print(gtensor_v); print("\n"); print("stensor_v : "); print(stensor_v); print("\n"); } #endif - // Restride the cta-into-tma-instr layout - Layout tma_layout_t = composition(make_layout(Int<1>{}, shape_div(size(tma_layout_v), cosize(cta_layout))), cta_layout); - auto tma_layout_tv = make_tile(make_tile(make_layout(tma_layout_t, tma_layout_v), _)); + // Offset inside the TMA-mode for the multicast + auto multicast_offset = cta_layout(cta_coord) * (size(tma_layout_v) / cosize(cta_layout)); + auto multicast_coord = make_coord(make_coord(multicast_offset, Int<0>{})); + auto scoord = append(multicast_coord, Int<0>{}); + auto gcoord = append(multicast_coord, Int<0>{}); - // Append with _ until we cover all Rest... modes - auto gtma_layout_tv = append>(tma_layout_tv, _); - auto stma_layout_tv = append>(tma_layout_tv, _); + Tensor gresult = domain_offset(gcoord, gtensor_v); + Tensor sresult = domain_offset(scoord, stensor_v); - // Transform TMA mode - Tensor gtensor_tv = gtensor_v.compose(gtma_layout_tv); // (((Thr,Frg),TMA_Iter), Rest...) - Tensor stensor_tv = stensor_v.compose(stma_layout_tv); // (((Thr,Frg),TMA_Iter), Rest...) 
+ return cute::make_tuple(gresult, sresult); +} -#if 0 - if (thread0()) { - print("tma_layout_tv : "); print(tma_layout_tv); print("\n"); - print("gtensor_tv : "); print(gtensor_tv); print("\n"); - print("stensor_tv : "); print(stensor_tv); print("\n"); +// TMA Multicast Masks Calculation +template +CUTE_HOST_DEVICE constexpr +auto +create_tma_multicast_mask(CtaLayout const& cta_layout_vmnk, + CtaCoord const& cta_coord_vmnk) +{ + auto cta_coord_slicer = replace(cta_coord_vmnk, _); + auto [cta_layout, elected_cta] = slice_and_offset(cta_coord_slicer, cta_layout_vmnk); + // Get the instruction code + uint16_t mcast_mask = 0; + for (int i = 0; i < size(cta_layout); ++i) { + mcast_mask |= uint16_t(1) << cta_layout(i); } -#endif - - auto c = make_coord(make_coord(make_coord(cta_coord, _), _)); - auto c_s = append>(c, _); - auto c_g = append>(c, _); - - return cute::make_tuple(group_modes<0,2>(gtensor_tv(c_g)), group_modes<0,2>(stensor_tv(c_s))); + // Shift by the instruction's elected block rank (dynamic) + mcast_mask <<= elected_cta; + return mcast_mask; } } // end namespace cute diff --git a/include/cute/atom/mma_atom.hpp b/include/cute/atom/mma_atom.hpp index 674e3519e8..9e5c93f2ea 100644 --- a/include/cute/atom/mma_atom.hpp +++ b/include/cute/atom/mma_atom.hpp @@ -715,6 +715,7 @@ print(MMA_Atom> const&) using Atom = MMA_Atom>; print("MMA_Atom\n"); print(" ThrID: "); print(typename Atom::ThrID{}); print("\n"); + print(" Shape_MNK: "); print(typename Atom::Shape_MNK{}); print("\n"); print(" LayoutA_TV: "); print(typename Atom::LayoutA_TV{}); print("\n"); print(" LayoutB_TV: "); print(typename Atom::LayoutB_TV{}); print("\n"); print(" LayoutC_TV: "); print(typename Atom::LayoutC_TV{}); print("\n"); diff --git a/include/cute/atom/mma_traits.hpp b/include/cute/atom/mma_traits.hpp index 8c090936c2..34275831b8 100644 --- a/include/cute/atom/mma_traits.hpp +++ b/include/cute/atom/mma_traits.hpp @@ -149,17 +149,17 @@ mma_unpack(MMA_Traits const& traits, //CUTE_STATIC_ASSERT_V(size(rC) == Int{}); if constexpr (detail::supports_output_scaling::value) { - detail::explode_with_d_scaling(MMA_Op::fma, - rA, make_int_sequence{}, - rB, make_int_sequence{}, - rC, make_int_sequence{}, - traits.accumulate_); + detail::explode(MMA_Op::fma, + rA, make_int_sequence{}, + rB, make_int_sequence{}, + rC, make_int_sequence{}, + &(traits.accumulate_), seq<0>{}); } else { detail::explode(MMA_Op::fma, - rA, make_int_sequence{}, - rB, make_int_sequence{}, - rC, make_int_sequence{}); + rA, make_int_sequence{}, + rB, make_int_sequence{}, + rC, make_int_sequence{}); } } else { @@ -169,19 +169,19 @@ mma_unpack(MMA_Traits const& traits, CUTE_STATIC_ASSERT_V(size(rD) == Int{}); CUTE_STATIC_ASSERT_V(size(rC) == Int{}); if constexpr (detail::supports_output_scaling::value) { - detail::explode_with_d_scaling(MMA_Op::fma, + detail::explode(MMA_Op::fma, rD, make_int_sequence{}, rA, make_int_sequence{}, rB, make_int_sequence{}, rC, make_int_sequence{}, - traits.accumulate_); + &(traits.accumulate_), seq<0>{}); } else { detail::explode(MMA_Op::fma, - rD, make_int_sequence{}, - rA, make_int_sequence{}, - rB, make_int_sequence{}, - rC, make_int_sequence{}); + rD, make_int_sequence{}, + rA, make_int_sequence{}, + rB, make_int_sequence{}, + rC, make_int_sequence{}); } } } @@ -198,7 +198,7 @@ template const& traits, - Tensor && D, + Tensor && D, Tensor const& A, Tensor const& B, Tensor const& C) diff --git a/include/cute/atom/mma_traits_sm90_gmma.hpp b/include/cute/atom/mma_traits_sm90_gmma.hpp index 3bbcd1fb35..db6f0fc2e9 100644 --- 
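The new `create_tma_multicast_mask` helper above folds a cluster layout and this CTA's cluster coordinate into the 16-bit CTA mask expected by multicast TMA instructions. A hypothetical sketch follows; note that the template parameter selecting which cluster mode to multicast across appears to have been lost from the hunk above by formatting, so its use here is an assumption.

```cpp
#include <cstdint>
#include <cute/atom/copy_traits_sm90_tma.hpp>

__device__ void multicast_mask_example()
{
  using namespace cute;

  // (V,M,N,K) layout of a 2x2 thread-block cluster, and this CTA's coordinate within it.
  auto cta_layout_vmnk = make_layout(make_shape(Int<1>{}, Int<2>{}, Int<2>{}, Int<1>{}));
  auto cta_coord_vmnk  = make_coord(0, 1, 0, 0);

  // Conventionally, A-tiles are multicast across the N mode and B-tiles across the M mode.
  uint16_t mcast_mask_a = create_tma_multicast_mask<2>(cta_layout_vmnk, cta_coord_vmnk);
  uint16_t mcast_mask_b = create_tma_multicast_mask<1>(cta_layout_vmnk, cta_coord_vmnk);
  (void)mcast_mask_a; (void)mcast_mask_b;
}
```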
a/include/cute/atom/mma_traits_sm90_gmma.hpp +++ b/include/cute/atom/mma_traits_sm90_gmma.hpp @@ -208,7 +208,7 @@ make_gmma_desc(Tensor const& tensor) // Start address (4LSB not included) uint32_t start_address = cast_smem_ptr_to_uint(raw_pointer_cast(u128_tensor.data())); - desc.bitfield.start_address_ = start_address >> 4; + desc.bitfield.start_address_ = static_cast(start_address >> 4); constexpr uint8_t base_offset = 0; desc.bitfield.base_offset_ = base_offset; diff --git a/include/cute/config.hpp b/include/cute/config.hpp index 941f60d7ad..35d4f8fdf0 100644 --- a/include/cute/config.hpp +++ b/include/cute/config.hpp @@ -91,7 +91,7 @@ // It's harmless to use the macro for other GCC versions or other // compilers, but it has no effect. #if ! defined(CUTE_GCC_UNREACHABLE) -# if defined(__clang__) || defined(__GNUC__) +# if defined(__GNUC__) # define CUTE_GCC_UNREACHABLE __builtin_unreachable() # else # define CUTE_GCC_UNREACHABLE diff --git a/include/cute/int_tuple.hpp b/include/cute/int_tuple.hpp index 110e233a36..f8ca467181 100644 --- a/include/cute/int_tuple.hpp +++ b/include/cute/int_tuple.hpp @@ -325,10 +325,21 @@ CUTE_HOST_DEVICE constexpr auto ceil_div(IntTupleA const& a, IntTupleB const& b) { - if constexpr (is_tuple::value && is_tuple::value) { - static_assert(tuple_size::value >= tuple_size::value, "Mismatched ranks"); - constexpr int R = tuple_size::value; // Missing ranks in TupleB are implicitly 1 - return transform(a, append(b,Int<1>{}), [](auto const& x, auto const& y) { return ceil_div(x,y); }); + if constexpr (is_tuple::value) { + if constexpr (is_tuple::value) { // tuple tuple + static_assert(tuple_size::value >= tuple_size::value, "Mismatched ranks"); + constexpr int R = tuple_size::value; // Missing ranks in TupleB are implicitly 1 + return transform(a, append(b,Int<1>{}), [](auto const& x, auto const& y) { return ceil_div(x,y); }); + } else { // tuple int + auto const [result, rest] = fold(a, cute::make_tuple(cute::make_tuple(), b), + [] (auto const& init, auto const& ai) { + return cute::make_tuple(append(get<0>(init), ceil_div(ai, get<1>(init))), ceil_div(get<1>(init), ai)); + }); + return result; + } + } else + if constexpr (is_tuple::value) { // int tuple + return ceil_div(a, product(b)); } else { return (a + b - Int<1>{}) / b; } diff --git a/include/cute/layout.hpp b/include/cute/layout.hpp index 71c4ce138b..b7517a67ce 100644 --- a/include/cute/layout.hpp +++ b/include/cute/layout.hpp @@ -418,8 +418,8 @@ make_layout_like(Layout const& layout) // Make a compact layout with the same shape as @a layout // and strides following the order induced by @a layout.stride(), // except mode-0 is always stride-1 and generated column-major. -// The 0th mode is commonly used for MMA_Atoms or Copy_Atoms -// so this generates the 0th mode with LayoutLeft regardless of the reference layout. 
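The extended `ceil_div` above now covers tuple/int and int/tuple mixtures in addition to the existing tuple/tuple and int/int cases. A few illustrative evaluations, worked out by hand from the fold logic shown in the hunk (not taken from the patch's own tests):

```cpp
#include <cute/int_tuple.hpp>

void ceil_div_examples()
{
  using namespace cute;

  auto r0 = ceil_div(Int<10>{}, Int<4>{});                        // int / int   -> _3

  auto r1 = ceil_div(make_shape(Int<4>{}, Int<6>{}),
                     make_shape(Int<2>{}));                       // tuple / tuple (missing ranks are 1) -> (_2,_6)

  auto r2 = ceil_div(Int<24>{}, make_shape(Int<2>{}, Int<3>{}));  // int / tuple -> divide by the product -> _4

  auto r3 = ceil_div(make_shape(Int<4>{}, Int<6>{}), Int<8>{});   // tuple / int -> divisor folded across modes -> (_1,_3)

  (void)r0; (void)r1; (void)r2; (void)r3;
}
```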
+// The 0th mode is commonly used for MMA_Atoms or Copy_Atoms so this +// generates the 0th mode with LayoutLeft (preserving stride-0s) regardless of the reference layout template CUTE_HOST_DEVICE constexpr auto @@ -427,7 +427,8 @@ make_fragment_like(Layout const& layout) { constexpr int R = Layout::rank; if constexpr (R > 1 && is_static::value) { - return tiled_product(make_layout(shape<0>(layout)), + return tiled_product(make_layout(get<0>(layout.shape()), + compact_col_major(filter_zeros(get<0>(layout.stride()), get<0>(layout.shape())))), make_ordered_layout(take<1,R>(layout.shape()), take<1,R>(layout.stride()))); } else { return make_layout(layout.shape()); @@ -757,7 +758,8 @@ bw_coalesce(OldShape const& old_shape, OldStride const& old_stride, } else if constexpr (is_constant<1, NewShape>::value) { // Replace our shape-1 with anything (Can only happen on input new_shape/new_stride) return bw_coalesce(old_shape, old_stride, get(old_shape), get(old_stride)); - } else if constexpr (is_constant(old_shape) * get(old_stride) == get<0>(new_stride))>::value) { + } else if constexpr (is_static(new_shape))>::value && + is_constant(old_shape) * get(old_stride) == get<0>(new_stride))>::value) { // Merge modes because the shapes and strides match return bw_coalesce(old_shape, old_stride, replace_front(new_shape, get(old_shape) * get<0>(new_shape)), @@ -772,6 +774,45 @@ bw_coalesce(OldShape const& old_shape, OldStride const& old_stride, CUTE_GCC_UNREACHABLE; } +// cute::coalesce promises to not change the Layout as a function from integers to codomain. +// It accomplishes this inside of the Layout's domain, but not always outside of the domain. +// Example: (_4,_1):(_1,_0) coalesces to _4:_1. +// detail::coalesce_x preserves the Layout function inside its domain and outside. +// +// @post depth(@a result) <= 1 +// @post for all i, 0 <= i, @a layout(i) == @a result(i) +template +CUTE_HOST_DEVICE constexpr +auto +coalesce_x(Layout const& layout) +{ + auto flat_shape = flatten(layout.shape()); + auto flat_stride = flatten(layout.stride()); + + constexpr int R = decltype(rank(flat_shape))::value; + if constexpr (is_constant<1, decltype(get(flat_shape))>::value) { + return detail::bw_coalesce(flat_shape, flat_stride, Int<2>{}, get(flat_stride)); + } else { + return detail::bw_coalesce(flat_shape, flat_stride, get(flat_shape), get(flat_stride)); + } +} + +// Apply coalesce_x at the terminals of trg_profile +template +CUTE_HOST_DEVICE constexpr +auto +coalesce_x(Layout const& layout, IntTuple const& trg_profile) +{ + if constexpr (is_tuple::value) { + static_assert(tuple_size::value <= Layout::rank); + return cute::transform_layout(layout, trg_profile, [](auto const& l, auto const& t) { return coalesce_x(l,t); }); + } else { + return coalesce_x(layout); + } + + CUTE_GCC_UNREACHABLE; +} + } // end namespace detail // "Simplify" the layout by combining modes that are possible to combine @@ -807,6 +848,25 @@ coalesce(Layout const& layout, IntTuple const& trg_profile) CUTE_GCC_UNREACHABLE; } +// Combine static and dynamic modes of a shape. 
+// @post size(@a result) == size(@a shape) +// @post depth(@a result) <= 1 +template +CUTE_HOST_DEVICE constexpr +auto +coalesce(Shape const& shape) +{ + static_assert(is_integral::value || is_tuple::value); + + return cute::fold_first(flatten(shape), [](auto const& init, auto const& a) { + if constexpr (is_static::value == is_static::value) { + return replace_back(init, back(init) * a); // Both static or both dynamic, coalesce and replace + } else { + return append(init, a); // Can't coalesce, so append + } + }); +} + // Replace the modes in layout that have a 0-stride with a 1-size template CUTE_HOST_DEVICE constexpr @@ -918,70 +978,64 @@ template CUTE_HOST_DEVICE constexpr auto -composition_impl(Layout const& lhs, +composition_impl(LShape const& lhs_shape, LStride const& lhs_stride, RShape const& rhs_shape, RStride const& rhs_stride) { if constexpr (is_tuple::value) { // Apply the right-distributivity of Layout composition - return transform_layout(rhs_shape, rhs_stride, [&](auto const& s, auto const& d) { return composition_impl(lhs, s, d); }); + return transform_layout(rhs_shape, rhs_stride, [&](auto const& s, auto const& d) { + return composition_impl(lhs_shape, lhs_stride, s, d); + }); } else if constexpr (is_scaled_basis::value) { // Special case for a ScaledBasis stride - return composition_impl(get(lhs), rhs_shape, rhs_stride.value()); + return composition_impl(basis_get(rhs_stride, lhs_shape), basis_get(rhs_stride, lhs_stride), + rhs_shape, basis_value(rhs_stride)); } else - if constexpr (is_integral::value) { - // Integral Rstride (and RShape) - - // NOTE: Should only flatten once for efficiency - auto flat_shape = flatten(lhs.shape()); - [[maybe_unused]] auto flat_stride = flatten(lhs.stride()); - [[maybe_unused]] constexpr int R = rank(flat_shape); - - if constexpr (is_constant<0, RStride>::value) { - // Special case shortcut for any static stride-0 - return Layout{rhs_shape, rhs_stride}; - } else - if constexpr (is_integral::value) { - // Special case shortcut for any integral LShape - auto result_stride = rhs_stride * flat_stride; - return Layout{rhs_shape, result_stride}; - } else - if constexpr (is_constant<1, RStride>::value) { - // Special case shortcut for any static stride-1 - auto result_shape_0 = take<0,R-1>(flat_shape); - - // Mod out the rhs_shape from the lhs.shape() - auto const [result_shape_1, rest_shape] = fold(result_shape_0, cute::make_tuple(cute::make_tuple(), rhs_shape), - [] (auto const& init, auto const& si) { - return cute::make_tuple(append(get<0>(init), shape_min(abs(si), get<1>(init))), shape_div(get<1>(init), abs(si))); - }); - - // Jump into coalesce and append (rest_shape, get(lhs.stride()) - return detail::bw_coalesce(result_shape_1, flat_stride, rest_shape, get(flat_stride)); - } else - { - // General case - auto result_shape_0 = take<0,R-1>(flat_shape); - auto result_stride_0 = take<0,R-1>(flat_stride); - - // Divide out the rhs_stride from the lhs.shape() - auto const [result_shape_1, rest_stride] = fold(result_shape_0, cute::make_tuple(cute::make_tuple(), rhs_stride), - [] (auto const& init, auto const& di) { - return cute::make_tuple(append(get<0>(init), shape_div(di, get<1>(init))), shape_div(get<1>(init), di)); - }); - - // Apply any lhs.shape() changes to the stride - auto result_stride_1 = elem_scale(result_stride_0, shape_div(result_shape_0, result_shape_1)); - - // Mod out the rhs_shape from the lhs.shape() - auto const [result_shape_2, rest_shape] = fold(result_shape_1, cute::make_tuple(cute::make_tuple(), rhs_shape), - [] (auto 
const& init, auto const& si) { - return cute::make_tuple(append(get<0>(init), shape_min(abs(si), get<1>(init))), shape_div(get<1>(init), abs(si))); - }); - - // Jump into coalesce and append (rest_shape, rest_stride * get(lhs.stride()) - return detail::bw_coalesce(result_shape_2, result_stride_1, rest_shape, rest_stride * get(flat_stride)); - } + if constexpr (is_constant<0, RStride>::value) { + // Special case shortcut for any static stride-0 + return Layout{rhs_shape, rhs_stride}; + } else + if constexpr (is_integral::value) { + // Special case shortcut for any integral LShape + return Layout{rhs_shape, rhs_stride * lhs_stride}; + } else + if constexpr (is_constant<1, RStride>::value) { + // Special case shortcut for any static stride-1 + constexpr int R = rank_v; + auto result_shape_0 = take<0,R-1>(lhs_shape); + + // Mod out the rhs_shape from the lhs_shape + auto const [result_shape_1, rest_shape] = fold(result_shape_0, cute::make_tuple(cute::make_tuple(), rhs_shape), + [] (auto const& init, auto const& si) { + return cute::make_tuple(append(get<0>(init), shape_min(abs(si), get<1>(init))), shape_div(get<1>(init), abs(si))); + }); + + // Jump into coalesce and append (rest_shape, get(lhs_stride)) + return detail::bw_coalesce(result_shape_1, lhs_stride, rest_shape, get(lhs_stride)); + } else { + // General case: integral RShape and RStride, tuple LShape and LStride + constexpr int R = rank_v; + auto result_shape_0 = take<0,R-1>(lhs_shape); + auto result_stride_0 = take<0,R-1>(lhs_stride); + + // Divide out the rhs_stride from the lhs_shape + auto const [result_shape_1, rest_stride] = fold(result_shape_0, cute::make_tuple(cute::make_tuple(), rhs_stride), + [] (auto const& init, auto const& di) { + return cute::make_tuple(append(get<0>(init), shape_div(di, get<1>(init))), shape_div(get<1>(init), di)); + }); + + // Apply any lhs_shape changes to the stride + auto result_stride_1 = elem_scale(result_stride_0, shape_div(result_shape_0, result_shape_1)); + + // Mod out the rhs_shape from the lhs_shape + auto const [result_shape_2, rest_shape] = fold(result_shape_1, cute::make_tuple(cute::make_tuple(), rhs_shape), + [] (auto const& init, auto const& si) { + return cute::make_tuple(append(get<0>(init), shape_min(abs(si), get<1>(init))), shape_div(get<1>(init), abs(si))); + }); + + // Jump into coalesce and append (rest_shape, rest_stride * get(lhs_stride)) + return detail::bw_coalesce(result_shape_2, result_stride_1, rest_shape, rest_stride * get(lhs_stride)); } CUTE_GCC_UNREACHABLE; @@ -996,7 +1050,9 @@ auto composition(Layout const& lhs, Layout const& rhs) { - return detail::composition_impl(lhs, rhs.shape(), rhs.stride()); + auto coprofile = repeat_like(decltype(coshape(rhs)){}, Int<0>{}); + auto flat_lhs = detail::coalesce_x(lhs, coprofile); + return detail::composition_impl(flat_lhs.shape(), flat_lhs.stride(), rhs.shape(), rhs.stride()); } template @@ -1012,7 +1068,8 @@ composition(Layout const& lhs, } else if constexpr (is_underscore::value) { return lhs; } else if constexpr (is_integral::value) { - return detail::composition_impl(lhs, rhs, Int<1>{}); + auto flat_lhs = detail::coalesce_x(lhs); + return detail::composition_impl(flat_lhs.shape(), flat_lhs.stride(), rhs, Int<1>{}); } CUTE_GCC_UNREACHABLE; @@ -1032,14 +1089,14 @@ composition(Layout const& lhs, namespace detail { // @pre @a layout has been filtered (flattened and no stride-0 or size-1 modes). 
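The `composition_impl` refactor above pre-flattens the left layout with `detail::coalesce_x` and threads shapes and strides through explicitly, but the user-facing behaviour of `cute::composition` is intended to be unchanged. For reference, a small worked example (result computed by hand, not taken from the patch):

```cpp
#include <cute/layout.hpp>

void composition_example()
{
  using namespace cute;

  auto A = make_layout(make_shape (Int<20>{}, Int<2>{}),
                       make_stride(Int<16>{}, Int<4>{}));
  auto B = make_layout(make_shape (Int< 4>{}, Int<5>{}),
                       make_stride(Int< 1>{}, Int<4>{}));

  // R(i) == A(B(i)) for every i in the domain of B.
  auto R = composition(A, B);   // expected: (_4,_5):(_16,_64)
  (void)R;
}
```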
-template +template CUTE_HOST_DEVICE constexpr auto -complement(Shape const& shape, Stride const& stride, CoSizeHi const& cosize_hi) +complement(Shape const& shape, Stride const& stride, CoTarget const& cotarget) { if constexpr (is_constant<0, Stride>::value) { // Special case for irreducible rank-1 stride-0 layout - return make_layout(cosize_hi); + return make_layout(coalesce(cotarget)); } else { // General case constexpr int R = rank_v; @@ -1055,28 +1112,30 @@ complement(Shape const& shape, Stride const& stride, CoSizeHi const& cosize_hi) { auto [shape, stride, result_shape, result_stride] = init; auto min_stride = cute::min(stride); - auto min_idx = find(stride, min_stride); + auto min_idx = cute::find(stride, min_stride); auto new_shape = min_stride / get(result_stride); - auto new_stride = get(shape) * min_stride; + auto new_stride = min_stride * get(shape); static_assert(not is_constant<0, decltype(new_shape)>::value, "Non-injective Layout detected in complement."); return cute::make_tuple(remove(shape), // Remove the min_idx from shape remove(stride), // Remove the min_idx from stride append(result_shape , new_shape ), // new shape = min_stride / last_stride - append(result_stride, new_stride)); // new stride = curr_shape * min_stride + append(result_stride, new_stride)); // new stride = min_stride * curr_shape }); // Append the last shape mode - auto new_shape = get<0>(stride_) / get(result_stride); + auto new_shape = get<0>(stride_) / get(result_stride); // new shape = min_stride / last_stride static_assert(not is_constant<0, decltype(new_shape)>::value, "Non-injective Layout detected in complement."); - auto result_shape = append(result_shape_, new_shape); // new shape = min_stride / last_stride + auto result_shape = append(result_shape_, new_shape); // Compute the rest_shape and rest_stride - auto rest_stride = get<0>(shape_) * get<0>(stride_); - auto rest_shape = ceil_div(cosize_hi, rest_stride); + auto new_stride = get<0>(stride_) * get<0>(shape_); // new stride = min_stride * curr_shape + auto rest_shape = coalesce(ceil_div(cotarget, new_stride)); + auto rest_stride = compact_col_major(rest_shape, new_stride); - // Jump into coalesce and append (rest_shape, rest_stride) - return detail::bw_coalesce(result_shape, result_stride, rest_shape, rest_stride); + // Coalesce and append (rest_shape, rest_stride) + return coalesce(make_layout(make_shape (result_shape , rest_shape ), + make_stride(result_stride, rest_stride))); } CUTE_GCC_UNREACHABLE; @@ -1084,14 +1143,13 @@ complement(Shape const& shape, Stride const& stride, CoSizeHi const& cosize_hi) } // end namespace detail -template +template CUTE_HOST_DEVICE constexpr auto -complement(Layout const& layout, CoSizeHi const& cosize_hi) +complement(Layout const& layout, CoTarget const& cotarget) { - static_assert(cute::is_integral::value, "Expected integral codomain size in complement."); auto filter_layout = filter(layout); - return detail::complement(filter_layout.shape(), filter_layout.stride(), cosize_hi); + return detail::complement(filter_layout.shape(), filter_layout.stride(), shape(cotarget)); } template @@ -1365,7 +1423,7 @@ auto logical_divide(Layout const& layout, Layout const& tiler) { - return composition(layout, make_layout(tiler, complement(tiler, size(layout)))); + return composition(layout, make_layout(tiler, complement(tiler, shape(layout)))); } template diff --git a/include/cute/layout_composed.hpp b/include/cute/layout_composed.hpp index 93c60898d7..3dbd2cd939 100644 --- a/include/cute/layout_composed.hpp +++ 
b/include/cute/layout_composed.hpp @@ -392,12 +392,12 @@ composition(Layout const& a, // complement // -template +template CUTE_HOST_DEVICE constexpr auto -complement(ComposedLayout const& layout, CoSizeHi const& cosize_hi) +complement(ComposedLayout const& layout, CoTarget const& cotarget) { - return complement(layout.layout_b(), cosize_hi); + return complement(layout.layout_b(), cotarget); } template @@ -610,7 +610,7 @@ recast_layout(ComposedLayout const& layout) else if constexpr (scale::num == 1) { return downcast(layout); } - else if constexpr (scale::den == 1) { + else if constexpr (scale::den == 1) { return upcast(layout); } else { diff --git a/include/cute/numeric/arithmetic_tuple.hpp b/include/cute/numeric/arithmetic_tuple.hpp index 27d1cf8e38..651ff8e887 100644 --- a/include/cute/numeric/arithmetic_tuple.hpp +++ b/include/cute/numeric/arithmetic_tuple.hpp @@ -73,25 +73,17 @@ make_arithmetic_tuple(T const&... t) { return ArithmeticTuple(t...); } -template +template CUTE_HOST_DEVICE constexpr auto -as_arithmetic_tuple(tuple const& t) { - return ArithmeticTuple(t); -} - -template ::value)> -CUTE_HOST_DEVICE constexpr -T const& as_arithmetic_tuple(T const& t) { - return t; -} - -template -CUTE_HOST_DEVICE constexpr -auto -as_arithmetic_tuple(ArithmeticTuple const& t) { - return t; + if constexpr (is_tuple::value) { + return detail::tapply(t, [](auto const& x){ return as_arithmetic_tuple(x); }, + [](auto const&... a){ return make_arithmetic_tuple(a...); }, + tuple_seq{}); + } else { + return t; + } } // @@ -289,6 +281,26 @@ basis_get(SB const& e, Tuple const& t) namespace detail { +template +CUTE_HOST_DEVICE constexpr +auto +to_atuple_i(T const& t, seq) { + return make_arithmetic_tuple((void(I),Int<0>{})..., t); +} + +} // end namespace detail + +// Turn a ScaledBases into a rank-N+1 ArithmeticTuple +// with N prefix 0s: (_0,_0,...N...,_0,T) +template +CUTE_HOST_DEVICE constexpr +auto +as_arithmetic_tuple(ScaledBasis const& t) { + return detail::to_atuple_i(as_arithmetic_tuple(t.value()), make_seq{}); +} + +namespace detail { + template struct Basis; @@ -315,71 +327,6 @@ struct Basis { template using E = typename detail::Basis::type; -namespace detail { - -template -CUTE_HOST_DEVICE constexpr -auto -as_arithmetic_tuple(T const& t, seq, seq) { - return make_arithmetic_tuple((void(I),Int<0>{})..., t, (void(J),Int<0>{})...); -} - -template -CUTE_HOST_DEVICE constexpr -auto -as_arithmetic_tuple(ArithmeticTuple const& t, seq, seq) { - return make_arithmetic_tuple(get(t)..., (void(J),Int<0>{})...); -} - -} // end namespace detail - -// Turn a ScaledBases into a rank-M ArithmeticTuple -// with N prefix 0s: (_0,_0,...N...,_0,T,_0,...,_0,_0) -template -CUTE_HOST_DEVICE constexpr -auto -as_arithmetic_tuple(ScaledBasis const& t) { - static_assert(M > N, "Mismatched ranks"); - return detail::as_arithmetic_tuple(t.value(), make_seq{}, make_seq{}); -} - -// Turn a ScaledBases into a rank-N ArithmeticTuple -// with N prefix 0s: (_0,_0,...N...,_0,T) -template -CUTE_HOST_DEVICE constexpr -auto -as_arithmetic_tuple(ScaledBasis const& t) { - return as_arithmetic_tuple(t); -} - -// Turn an ArithmeticTuple into a rank-M ArithmeticTuple -// with postfix 0s: (t0,t1,t2,...,_0,...,_0,_0) -template -CUTE_HOST_DEVICE constexpr -auto -as_arithmetic_tuple(ArithmeticTuple const& t) { - static_assert(M >= sizeof...(T), "Mismatched ranks"); - return detail::as_arithmetic_tuple(t, make_seq{}, make_seq{}); -} - -template -CUTE_HOST_DEVICE constexpr -auto -safe_div(ScaledBasis const& b, U const& u) -{ - auto t = 
safe_div(b.value(), u); - return ScaledBasis{t}; -} - -template -CUTE_HOST_DEVICE constexpr -auto -shape_div(ScaledBasis const& b, U const& u) -{ - auto t = shape_div(b.value(), u); - return ScaledBasis{t}; -} - template CUTE_HOST_DEVICE constexpr auto @@ -387,8 +334,7 @@ make_basis_like(Shape const& shape) { if constexpr (is_integral::value) { return Int<1>{}; - } - else { + } else { // Generate bases for each rank of shape return transform(tuple_seq{}, shape, [](auto I, auto si) { // Generate bases for each rank of si and add an i on front @@ -408,6 +354,28 @@ make_basis_like(Shape const& shape) CUTE_GCC_UNREACHABLE; } +// +// Arithmetic +// + +template +CUTE_HOST_DEVICE constexpr +auto +safe_div(ScaledBasis const& b, U const& u) +{ + auto t = safe_div(b.value(), u); + return ScaledBasis{t}; +} + +template +CUTE_HOST_DEVICE constexpr +auto +shape_div(ScaledBasis const& b, U const& u) +{ + auto t = shape_div(b.value(), u); + return ScaledBasis{t}; +} + // Equality template CUTE_HOST_DEVICE constexpr @@ -432,7 +400,7 @@ operator==(T const&, ScaledBasis const&) { } // Abs -template +template CUTE_HOST_DEVICE constexpr auto abs(ScaledBasis const& e) { @@ -440,7 +408,7 @@ abs(ScaledBasis const& e) { } // Multiplication -template +template CUTE_HOST_DEVICE constexpr auto operator*(A const& a, ScaledBasis const& e) { @@ -448,7 +416,7 @@ operator*(A const& a, ScaledBasis const& e) { return ScaledBasis{r}; } -template +template CUTE_HOST_DEVICE constexpr auto operator*(ScaledBasis const& e, B const& b) { @@ -457,44 +425,25 @@ operator*(ScaledBasis const& e, B const& b) { } // Addition -template -CUTE_HOST_DEVICE constexpr -auto -operator+(ScaledBasis const& t, ArithmeticTuple const& u) { - constexpr int R = cute::max(N+1, int(sizeof...(U))); - return as_arithmetic_tuple(t) + as_arithmetic_tuple(u); -} - -template -CUTE_HOST_DEVICE constexpr -auto -operator+(ArithmeticTuple const& t, ScaledBasis const& u) { - constexpr int R = cute::max(int(sizeof...(T)), M+1); - return as_arithmetic_tuple(t) + as_arithmetic_tuple(u); -} - -template +template CUTE_HOST_DEVICE constexpr auto -operator+(ScaledBasis const& t, tuple const& u) { - constexpr int R = cute::max(N+1, int(sizeof...(U))); - return as_arithmetic_tuple(t) + as_arithmetic_tuple(u); +operator+(ScaledBasis const& t, ScaledBasis const& u) { + return as_arithmetic_tuple(t) + as_arithmetic_tuple(u); } -template +template CUTE_HOST_DEVICE constexpr auto -operator+(tuple const& t, ScaledBasis const& u) { - constexpr int R = cute::max(int(sizeof...(T)), M+1); - return as_arithmetic_tuple(t) + as_arithmetic_tuple(u); +operator+(ScaledBasis const& t, ArithmeticTuple const& u) { + return as_arithmetic_tuple(t) + u; } -template +template CUTE_HOST_DEVICE constexpr auto -operator+(ScaledBasis const& t, ScaledBasis const& u) { - constexpr int R = cute::max(N+1,M+1); - return as_arithmetic_tuple(t) + as_arithmetic_tuple(u); +operator+(ArithmeticTuple const& t, ScaledBasis const& u) { + return t + as_arithmetic_tuple(u); } template diff --git a/include/cute/numeric/complex.hpp b/include/cute/numeric/complex.hpp index 8cc3625330..5113719dbd 100644 --- a/include/cute/numeric/complex.hpp +++ b/include/cute/numeric/complex.hpp @@ -56,10 +56,10 @@ fma(complex & d, complex const& b, complex const& c) { - d.real(fma( a.real(), b.real(), c.real())); - d.imag(fma( a.real(), b.imag(), c.imag())); - d.real(fma(-a.imag(), b.imag(), d.real())); - d.imag(fma( a.imag(), b.real(), d.imag())); + fma(d.real(), a.real(), b.real(), c.real()); + fma(d.imag(), a.real(), b.imag(), 
c.imag()); + fma(d.real(), -a.imag(), b.imag(), d.real()); + fma(d.imag(), a.imag(), b.real(), d.imag()); } /// Fused multiply-add for triplets diff --git a/include/cute/pointer.hpp b/include/cute/pointer.hpp index 5647f97c12..604477a0d3 100644 --- a/include/cute/pointer.hpp +++ b/include/cute/pointer.hpp @@ -41,7 +41,6 @@ #include #include -#include namespace cute { @@ -102,6 +101,8 @@ template // Found the gmem struct is_gmem> : true_type {}; template // Recurse on ::iterator, if possible struct is_gmem> : is_gmem {}; +template +constexpr bool is_gmem_v = is_gmem
::value; // Idempotent gmem tag on an iterator template @@ -163,6 +164,8 @@ template // Found the smem struct is_smem> : true_type {}; template // Recurse on ::iterator, if possible struct is_smem> : is_smem {}; +template +constexpr bool is_smem_v = is_smem
::value; // Idempotent smem tag on an iterator template @@ -224,6 +227,8 @@ template struct is_rmem : bool_constant::value || is_smem::value)> {}; template struct is_rmem> : true_type {}; +template +constexpr bool is_rmem_v = is_rmem
::value; // Idempotent rmem tag on an iterator template diff --git a/include/cute/pointer_flagged.hpp b/include/cute/pointer_flagged.hpp index aa917d9f7d..08751eb169 100644 --- a/include/cute/pointer_flagged.hpp +++ b/include/cute/pointer_flagged.hpp @@ -89,7 +89,7 @@ downcast(ComposedLayout,Layout> const& layout) // Conversion with swizzle_layout // -template +template CUTE_HOST_DEVICE auto as_position_independent_swizzle_layout(ComposedLayout,Layout> const& layout) @@ -129,6 +129,14 @@ as_position_independent_swizzle_tensor(Tensor&& tensor) // // Capture and cast smem_ptr_flag Layouts to offset-0 layouts +template +CUTE_HOST_DEVICE +void +print_layout(ComposedLayout,Layout> const& layout) +{ + print_layout(as_position_independent_swizzle_layout(layout)); +} + template CUTE_HOST_DEVICE void diff --git a/include/cute/tensor.hpp b/include/cute/tensor.hpp index 28d3ee67a9..71ace9a81c 100644 --- a/include/cute/tensor.hpp +++ b/include/cute/tensor.hpp @@ -316,6 +316,8 @@ template struct is_tensor : false_type {}; template struct is_tensor> : true_type {}; +template +constexpr bool is_tensor_v = is_tensor::value; // Customization point for creation of owning and non-owning Tensors template @@ -1082,7 +1084,6 @@ CUTE_HOST std::ostream& operator<<(std::ostream& os, Tensor const #include #include - // // Tensor Algorithms // diff --git a/include/cute/util/type_traits.hpp b/include/cute/util/type_traits.hpp index a8cab903d1..56cc814ec3 100644 --- a/include/cute/util/type_traits.hpp +++ b/include/cute/util/type_traits.hpp @@ -101,6 +101,9 @@ using CUTE_STL_NAMESPACE::is_lvalue_reference_v; using CUTE_STL_NAMESPACE::is_reference; using CUTE_STL_NAMESPACE::is_trivially_copyable; +using CUTE_STL_NAMESPACE::is_convertible; +using CUTE_STL_NAMESPACE::is_convertible_v; + using CUTE_STL_NAMESPACE::is_same; using CUTE_STL_NAMESPACE::is_same_v; @@ -247,4 +250,15 @@ is_valid(F&&, Args&&...) { return detail::is_valid_impl(int{}); } +template class True, template class False> +struct conditional_template { + template + using type = True; +}; + +template class True, template class False> +struct conditional_template { + template + using type = False; +}; } // end namespace cute diff --git a/include/cutlass/array.h b/include/cutlass/array.h index 48ad7cca98..dcaa1093c8 100644 --- a/include/cutlass/array.h +++ b/include/cutlass/array.h @@ -33,16 +33,6 @@ and is safe to use in a union. */ -/* - Note: CUTLASS 3x increases the host compiler requirements to C++17. However, certain - existing integrations of CUTLASS require C++11 host compilers. - - Until this requirement can be lifted, certain headers with this annotation are required - to be remain consistent with C++11 syntax. - - C++11 compatibility is enforced by `cutlass_test_unit_core_cpp11`. 
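The new `is_gmem_v` / `is_smem_v` / `is_rmem_v` and `is_tensor_v` variable templates above are plain shorthands for the existing `::value` traits. A trivial sketch of the kind of constraint they shorten (the guard function itself is hypothetical):

```cpp
#include <cute/tensor.hpp>

// Hypothetical guard: accept only cute::Tensors whose engine is tagged as shared memory.
template <class Engine, class Layout>
CUTE_HOST_DEVICE void expect_smem(cute::Tensor<Engine, Layout> const& t)
{
  static_assert(cute::is_tensor_v<cute::Tensor<Engine, Layout>>, "Expected a cute::Tensor");
  static_assert(cute::is_smem_v<Engine>, "Expected a shared-memory backed tensor");
  (void)t;
}
```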
-*/ - #pragma once #include "cutlass/cutlass.h" #include "cutlass/functional.h" @@ -57,7 +47,7 @@ template < int N, bool RegisterSized = sizeof_bits::value >= 32 > -class Array; +struct Array; //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -90,8 +80,7 @@ template < typename T, int N > -class Array { -public: +struct Array { /// Storage type using Storage = T; @@ -101,10 +90,10 @@ class Array { /// Number of storage elements //static std::size_t const kStorageElements = N; - static size_t const kStorageElements = N; + static constexpr size_t kStorageElements = N; /// Number of logical elements - static size_t const kElements = N; + static constexpr size_t kElements = N; // // C++ standard members @@ -346,26 +335,9 @@ class Array { } }; -private: - /// Internal storage Storage storage[kElements]; -public: - - #if 0 - CUTLASS_HOST_DEVICE - Array() { } - - CUTLASS_HOST_DEVICE - Array(Array const &x) { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kElements; ++i) { - storage[i] = x.storage[i]; - } - } - #endif - /// Efficient clear method CUTLASS_HOST_DEVICE void clear() { @@ -530,39 +502,25 @@ class Array { template CUTLASS_HOST_DEVICE Array make_Array(Element x) { - Array m; - m[0] = x; - return m; + return {x}; } template CUTLASS_HOST_DEVICE Array make_Array(Element x, Element y) { - Array m; - m[0] = x; - m[1] = y; - return m; + return {x,y}; } template CUTLASS_HOST_DEVICE Array make_Array(Element x, Element y, Element z) { - Array m; - m[0] = x; - m[1] = y; - m[2] = z; - return m; + return {x,y,z}; } template CUTLASS_HOST_DEVICE Array make_Array(Element x, Element y, Element z, Element w) { - Array m; - m[0] = x; - m[1] = y; - m[2] = z; - m[3] = w; - return m; + return {x,y,z,w}; } @@ -1104,6 +1062,58 @@ struct square_and_plus> { } }; +/// Inverse-square-root +template +struct inverse_square_root> { + CUTLASS_HOST_DEVICE + Array operator()(Array const &a) const { + Array result; + inverse_square_root scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(a[i]); + } + return result; + } +}; + +template +struct inverse_square_root> { + CUTLASS_HOST_DEVICE + Array operator()(Array const & a) const { + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 const *a_ptr = reinterpret_cast<__half2 const *>(&a); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = h2rsqrt(a_ptr[i]); + } + + if constexpr (N % 2) { + __half const *a_residual_ptr = reinterpret_cast<__half const *>(&a); + __half d_residual = hrsqrt(a_residual_ptr[N - 1]); + result[N - 1] = reinterpret_cast(d_residual); + } + + #else + + inverse_square_root scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(a[i]); + } + + #endif + + return result; + } +}; + /// Fused multiply-add-relu0 template struct multiply_add_relu0, Array, Array> { @@ -2513,7 +2523,6 @@ struct bit_xor> { } }; - ///////////////////////////////////////////////////////////////////////////////////////////////// // Operator overloads ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -2525,6 +2534,20 @@ Array operator+(Array const &lhs, Array const &rhs) { return op(lhs, rhs); } +template +CUTLASS_HOST_DEVICE +Array operator+(T const &lhs, Array const &rhs) { + plus> op; + return op(lhs, rhs); +} + +template +CUTLASS_HOST_DEVICE +Array operator+(Array const &lhs, T const 
&rhs) { + plus> op; + return op(lhs, rhs); +} + template CUTLASS_HOST_DEVICE Array operator-(Array const &lhs, Array const &rhs) { diff --git a/include/cutlass/array_planar_complex.h b/include/cutlass/array_planar_complex.h index 9fcd4d18e1..2dd8aa84e1 100644 --- a/include/cutlass/array_planar_complex.h +++ b/include/cutlass/array_planar_complex.h @@ -51,13 +51,12 @@ struct ArrayPlanarComplex { using Element = Element_; /// Number of logical elements - static size_t const kElements = N; + static constexpr size_t kElements = N; /// Underlying Fragment of real-valued elemenets - using ArrayReal = Array; + using ArrayReal = cutlass::Array; public: - /// Fragment of real-valued elements representing the real part ArrayReal real; @@ -65,19 +64,6 @@ struct ArrayPlanarComplex { ArrayReal imag; public: - - /// Ctor - CUTLASS_HOST_DEVICE - ArrayPlanarComplex() { } - - /// Ctor - CUTLASS_HOST_DEVICE - ArrayPlanarComplex( - ArrayReal const &real_, - ArrayReal const &imag_ - ): - real(real_), imag(imag_) { } - /// Sets the array to zero efficiently CUTLASS_HOST_DEVICE void clear() { @@ -93,7 +79,7 @@ template CUTLASS_HOST_DEVICE ArrayPlanarComplex make_ArrayPlanarComplex(Array const &real, Array const &imag) { - return ArrayPlanarComplex(real, imag); + return ArrayPlanarComplex{real, imag}; } ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/array_subbyte.h b/include/cutlass/array_subbyte.h index 25bbe355df..eb77a9310e 100644 --- a/include/cutlass/array_subbyte.h +++ b/include/cutlass/array_subbyte.h @@ -32,15 +32,6 @@ \brief Statically sized array of elements that accommodates all CUTLASS-supported numeric types and is safe to use in a union. */ -/* - Note: CUTLASS 3x increases the host compiler requirements to C++17. However, certain - existing integrations of CUTLASS require C++11 host compilers. - - Until this requirement can be lifted, certain headers with this annotation are required - to be remain consistent with C++11 syntax. - - C++11 compatibility is enforced by `cutlass_test_unit_core_cpp11`. 
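The scalar/array `operator+` overloads added above, together with the `constexpr` member constants and the brace-initialized `make_Array`, can be exercised directly from host code. A minimal sketch, assuming CUTLASS's `include/` directory is on the compiler's include path and a C++17 host compiler; the values are arbitrary:

```cpp
#include <iostream>

#include "cutlass/array.h"

int main() {
  // kElements is declared constexpr, so it is usable in constant expressions.
  static_assert(cutlass::Array<float, 4>::kElements == 4, "logical element count");

  // make_Array now returns a brace-initialized aggregate.
  cutlass::Array<float, 4> a = cutlass::make_Array(1.0f, 2.0f, 3.0f, 4.0f);

  // Element-wise addition with a scalar on either side uses the new overloads.
  cutlass::Array<float, 4> b = 0.5f + a;
  cutlass::Array<float, 4> c = a + 0.5f;

  for (int i = 0; i < 4; ++i) {
    std::cout << b[i] << " " << c[i] << "\n";  // 1.5 1.5, 2.5 2.5, ...
  }
  return 0;
}
```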
-*/ #pragma once @@ -57,10 +48,8 @@ template < typename T, int N > -class Array { -public: - - static int const kSizeBits = sizeof_bits::value * N; +struct Array { + static constexpr int kSizeBits = sizeof_bits::value * N; /// Storage type using Storage = typename platform::conditional< @@ -77,16 +66,16 @@ class Array { using Element = T; /// Number of logical elements per stored object - static int const kElementsPerStoredItem = int(sizeof(Storage) * 8) / sizeof_bits::value; + static constexpr int kElementsPerStoredItem = int(sizeof(Storage) * 8) / sizeof_bits::value; /// Number of storage elements - static size_t const kStorageElements = (N + kElementsPerStoredItem - 1) / kElementsPerStoredItem; + static constexpr size_t kStorageElements = (N + kElementsPerStoredItem - 1) / kElementsPerStoredItem; /// Number of logical elements - static size_t const kElements = N; + static constexpr size_t kElements = N; /// Bitmask for covering one item - static Storage const kMask = ((Storage(1) << sizeof_bits::value) - 1); + static constexpr Storage kMask = ((Storage(1) << sizeof_bits::value) - 1); // // C++ standard members with pointer types removed @@ -105,16 +94,14 @@ class Array { /// Reference object inserts or extracts sub-byte items class reference { /// Pointer to storage element - Storage *ptr_; + Storage *ptr_{nullptr}; /// Index into elements packed into Storage object - int idx_; + int idx_{0}; public: - - /// Default ctor - CUTLASS_HOST_DEVICE - reference(): ptr_(nullptr), idx_(0) { } + reference() = default; /// Ctor CUTLASS_HOST_DEVICE @@ -123,11 +110,38 @@ class Array { /// Assignment CUTLASS_HOST_DEVICE reference &operator=(T x) { + // `*ptr_ & kUpdateMask` will read ptr_ before writing to it. + // This means a code pattern like + // + // ```cpp + // Array result; + // result[0] = xxx; + // ``` + // + // will lead to a compiler warning about use of an uninitialized member variable, although we know + // this read of the uninitialized member variable is harmless.
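The comment above explains why the packed store in `reference::operator=` reads storage that may never have been written, which is what the diagnostic guard added just below suppresses. A standalone sketch of the same read-modify-write insertion; `insert4` and its constants are illustrative names, not CUTLASS symbols:

```cpp
#include <cstdint>
#include <cstdio>

// Insert a 4-bit value into slot `idx` of a packed byte, mirroring the
// mask/shift arithmetic of the sub-byte reference assignment.
static std::uint8_t insert4(std::uint8_t storage, int idx, std::uint8_t value) {
  constexpr int kBits = 4;
  std::uint8_t mask = std::uint8_t((1u << kBits) - 1u);
  std::uint8_t update_mask = std::uint8_t(~(mask << (idx * kBits)));
  // The old contents of `storage` are read here even if only one slot has ever
  // been written; in the member function above, that read is what trips
  // -Wuninitialized / -Wmaybe-uninitialized.
  return std::uint8_t((storage & update_mask) | ((value & mask) << (idx * kBits)));
}

int main() {
  std::uint8_t s = 0;
  s = insert4(s, 0, 0x3);
  s = insert4(s, 1, 0xA);
  std::printf("0x%02X\n", s);  // prints 0xA3
  return 0;
}
```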
+ +#if defined(__clang__) +# pragma clang diagnostic push +# pragma clang diagnostic ignored "-Wuninitialized" +#elif defined(__GNUC__) +# pragma GCC diagnostic push +# pragma GCC diagnostic ignored "-Wuninitialized" +# pragma GCC diagnostic ignored "-Wmaybe-uninitialized" +#endif + Storage item = (reinterpret_cast(x) & kMask); Storage kUpdateMask = Storage(~(kMask << (idx_ * sizeof_bits::value))); + *ptr_ = Storage(((*ptr_ & kUpdateMask) | (item << idx_ * sizeof_bits::value))); +#if defined(__clang__) +# pragma clang diagnostic pop +#elif defined(__GNUC__) +# pragma GCC diagnostic pop +#endif + return *this; } @@ -160,16 +174,14 @@ class Array { class const_reference { /// Pointer to storage element - Storage const *ptr_; + Storage const *ptr_{nullptr}; /// Index into elements packed into Storage object - int idx_; + int idx_{0}; public: - /// Default ctor - CUTLASS_HOST_DEVICE - const_reference(): ptr_(nullptr), idx_(0) { } + const_reference() = default; /// Ctor CUTLASS_HOST_DEVICE @@ -209,15 +221,14 @@ class Array { class iterator { /// Pointer to storage element - Storage *ptr_; + Storage *ptr_{nullptr}; /// Index into elements packed into Storage object - int idx_; + int idx_{0}; public: - CUTLASS_HOST_DEVICE - iterator(): ptr_(nullptr), idx_(0) { } + iterator() = default; CUTLASS_HOST_DEVICE iterator(Storage *ptr, int idx = 0): ptr_(ptr), idx_(idx) { } @@ -288,15 +299,14 @@ class Array { class const_iterator { /// Pointer to storage element - Storage const *ptr_; + Storage const *ptr_{nullptr}; /// Index into elements packed into Storage object - int idx_; + int idx_{0}; public: - CUTLASS_HOST_DEVICE - const_iterator(): ptr_(nullptr), idx_(0) { } + const_iterator() = default; CUTLASS_HOST_DEVICE const_iterator(Storage const *ptr, int idx = 0): ptr_(ptr), idx_(idx) { } @@ -367,15 +377,14 @@ class Array { class reverse_iterator { /// Pointer to storage element - Storage *ptr_; + Storage *ptr_{nullptr}; /// Index into elements packed into Storage object - int idx_; + int idx_{0}; public: - CUTLASS_HOST_DEVICE - reverse_iterator(): ptr_(nullptr), idx_(0) { } + reverse_iterator() = default; CUTLASS_HOST_DEVICE reverse_iterator(Storage *ptr, int idx = 0): ptr_(ptr), idx_(idx) { } @@ -385,40 +394,19 @@ class Array { class const_reverse_iterator { /// Pointer to storage element - Storage const *ptr_; + Storage const *ptr_{nullptr}; /// Index into elements packed into Storage object - int idx_; + int idx_{0}; public: - CUTLASS_HOST_DEVICE - const_reverse_iterator(): ptr_(nullptr), idx_(0) { } + const_reverse_iterator() = default; CUTLASS_HOST_DEVICE const_reverse_iterator(Storage const *ptr, int idx = 0): ptr_(ptr), idx_(idx) { } }; -private: - - /// Internal storage - Storage storage[kStorageElements] = {Storage{0}}; - -public: - - #if 0 - CUTLASS_HOST_DEVICE - Array() { } - - CUTLASS_HOST_DEVICE - Array(Array const &x) { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < int(kStorageElements); ++i) { - storage[i] = x.storage[i]; - } - } - #endif - /// Efficient clear method CUTLASS_HOST_DEVICE void clear() { @@ -489,7 +477,6 @@ class Array { return storage; } - CUTLASS_HOST_DEVICE constexpr bool empty() const { return !kElements; @@ -560,10 +547,9 @@ class Array { return const_reverse_iterator(storage); } - // - // Comparison operators - // - +private: + /// Internal storage + Storage storage[kStorageElements]; }; //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/bfloat16.h b/include/cutlass/bfloat16.h index 
75cadbfa43..c2e6cb0de6 100644 --- a/include/cutlass/bfloat16.h +++ b/include/cutlass/bfloat16.h @@ -34,16 +34,6 @@ 8 bits of exponent and 7 bit of mantissa. */ -/* - Note: CUTLASS 3x increases the host compiler requirements to C++17. However, certain - existing integrations of CUTLASS require C++11 host compilers. - - Until this requirement can be lifted, certain headers with this annotation are required - to be remain consistent with C++11 syntax. - - C++11 compatibility is enforced by `cutlass_test_unit_core_cpp11`. -*/ - #pragma once #if defined(__CUDACC_RTC__) diff --git a/include/cutlass/cluster_launch.hpp b/include/cutlass/cluster_launch.hpp index 0996fff87d..36f603e31a 100644 --- a/include/cutlass/cluster_launch.hpp +++ b/include/cutlass/cluster_launch.hpp @@ -229,7 +229,7 @@ struct ClusterLaunchParams { /// void const* kernel_ptr = /// const_cast(reinterpret_cast( /// &kernel)); -/// auto status = launch_on_cluster( +/// auto status = launch_kernel_on_cluster( /// {grid_dims, block_dims, cluster_dims, sizeof(SharedMemory)}, /// kernel_ptr, x, y, z); /// @endcode @@ -243,10 +243,10 @@ launch_kernel_on_cluster(const ClusterLaunchParams& params, // the parameters as an array of raw pointers. if constexpr (sizeof...(Args) == 0) { return cutlass::ClusterLauncher::launch( - params.grid_dims, - params.cluster_dims, + params.grid_dims, + params.cluster_dims, params.block_dims, - params.smem_size_in_bytes, + params.smem_size_in_bytes, params.cuda_stream, kernel_ptr, nullptr); } @@ -255,12 +255,12 @@ launch_kernel_on_cluster(const ClusterLaunchParams& params, detail::checked_addressof(std::forward(args))... }; return cutlass::ClusterLauncher::launch( - params.grid_dims, - params.cluster_dims, + params.grid_dims, + params.cluster_dims, params.block_dims, - params.smem_size_in_bytes, + params.smem_size_in_bytes, params.cuda_stream, - kernel_ptr, + kernel_ptr, kernel_params); } } diff --git a/include/cutlass/complex.h b/include/cutlass/complex.h index 32cfa5f76b..1f92b667e6 100644 --- a/include/cutlass/complex.h +++ b/include/cutlass/complex.h @@ -28,15 +28,6 @@ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ -/* - Note: CUTLASS 3x increases the host compiler requirements to C++17. However, certain - existing integrations of CUTLASS require C++11 host compilers. - - Until this requirement can be lifted, certain headers with this annotation are required - to be remain consistent with C++11 syntax. - - C++11 compatibility is enforced by this unit test: `cutlass_test_unit_core_cpp11`. -*/ #pragma once diff --git a/include/cutlass/conv/conv2d_problem_size.h b/include/cutlass/conv/conv2d_problem_size.h index ec86421976..d2e8952998 100644 --- a/include/cutlass/conv/conv2d_problem_size.h +++ b/include/cutlass/conv/conv2d_problem_size.h @@ -44,15 +44,6 @@ Map tensor sizes (Conv2d -> ImplicitGemm) : implicit_gemm_tensor_[a|b|c]_size(ConvolutionOperator) Map tensor problem sizes (Conv2d -> ImplicitGemm): implicit_gemm_problem_size(ConvolutionOperator) */ -/* - Note: CUTLASS 3x increases the host compiler requirements to C++17. However, certain - existing integrations of CUTLASS require C++11 host compilers. - - Until this requirement can be lifted, certain headers with this annotation are required - to be remain consistent with C++11 syntax. - - C++11 compatibility is enforced by `cutlass_test_unit_core_cpp11`. 
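The corrected doc comment above names `launch_kernel_on_cluster`. The sketch below mirrors that documented call pattern; `kernel_with_cluster`, `SharedStorage`, and the launch dimensions are placeholders, and actually running it requires an SM90 device and a CUDA 12 toolkit:

```cpp
#include "cutlass/cluster_launch.hpp"

// Placeholder shared-memory type and kernel; not part of CUTLASS.
struct SharedStorage { float buffer[128]; };

__global__ void kernel_with_cluster(float const* x, float* y, int n) {
  extern __shared__ char smem_raw[];
  (void)smem_raw;  // SharedStorage would live here in a real kernel
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) { y[i] = x[i]; }
}

cutlass::Status launch_copy(float const* x, float* y, int n) {
  dim3 grid_dims(16, 1, 1);
  dim3 block_dims(128, 1, 1);
  dim3 cluster_dims(2, 1, 1);

  // Type-erase the kernel pointer, then forward the kernel arguments.
  void const* kernel_ptr =
      const_cast<void*>(reinterpret_cast<void const*>(&kernel_with_cluster));

  return cutlass::launch_kernel_on_cluster(
      {grid_dims, block_dims, cluster_dims, sizeof(SharedStorage)},
      kernel_ptr, x, y, n);
}
```

In a full program, `launch_copy` would be called with device pointers obtained from `cudaMalloc`, and the returned `cutlass::Status` checked against `cutlass::Status::kSuccess`.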
-*/ #pragma once @@ -247,9 +238,10 @@ struct Conv2dProblemSize { /// Returns filter extent as Tensor4DCoord CUTLASS_HOST_DEVICE - cutlass::Tensor4DCoord filter_extent() const { + cutlass::Tensor4DCoord filter_extent(bool is_deconv = false) const { - return cutlass::Tensor4DCoord ({K, R, S, C / groups}); + return is_deconv ? cutlass::Tensor4DCoord ({C, R, S, K / groups}) + : cutlass::Tensor4DCoord ({K, R, S, C / groups}); } /// Returns output extent as Tensor4DCoord @@ -340,6 +332,7 @@ cutlass::gemm::GemmCoord implicit_gemm_problem_size( problem_size.K, problem_size.R * problem_size.S * problem_size.C / problem_size.groups ); + case Operator::kDeconv: case Operator::kDgrad: return gemm::GemmCoord( problem_size.N * problem_size.H * problem_size.W, @@ -404,6 +397,7 @@ int implicit_gemm_k_iterations( iterations = problem_size.R * problem_size.S * ((elements_per_split_k_slice + threadblock_K - 1) / threadblock_K); break; + case Operator::kDeconv: case Operator::kDgrad: elements_per_split_k_slice = (problem_size.K + problem_size.split_k_slices - 1) / problem_size.split_k_slices; iterations = problem_size.R * problem_size.S * ((elements_per_split_k_slice + threadblock_K - 1) / threadblock_K); @@ -505,6 +499,7 @@ int implicit_gemm_k_iterations_per_channel( iterations = problem_size.R * problem_size.S; break; + case Operator::kDeconv: case Operator::kDgrad: iterations = problem_size.R * problem_size.S; break; @@ -526,6 +521,7 @@ cutlass::Tensor4DCoord implicit_gemm_tensor_a_extent( Conv2dProblemSize const &problem_size) { switch (conv_operator) { case cutlass::conv::Operator::kFprop: return problem_size.activation_extent(); + case cutlass::conv::Operator::kDeconv: case cutlass::conv::Operator::kDgrad: return problem_size.output_extent(); case cutlass::conv::Operator::kWgrad: return problem_size.output_extent(); default : break; @@ -540,6 +536,7 @@ cutlass::Tensor4DCoord implicit_gemm_tensor_b_extent( Conv2dProblemSize const &problem_size) { switch (conv_operator) { case cutlass::conv::Operator::kFprop: return problem_size.filter_extent(); + case cutlass::conv::Operator::kDeconv: return problem_size.filter_extent(true); case cutlass::conv::Operator::kDgrad: return problem_size.filter_extent(); case cutlass::conv::Operator::kWgrad: return problem_size.activation_extent(); default : break; @@ -554,6 +551,7 @@ cutlass::Tensor4DCoord implicit_gemm_tensor_c_extent( Conv2dProblemSize const &problem_size) { switch (conv_operator) { case cutlass::conv::Operator::kFprop: return problem_size.output_extent(); + case cutlass::conv::Operator::kDeconv: case cutlass::conv::Operator::kDgrad: return problem_size.activation_extent(); case cutlass::conv::Operator::kWgrad: return problem_size.filter_extent(); default : break; @@ -568,6 +566,7 @@ int64_t implicit_gemm_tensor_a_size( Conv2dProblemSize const &problem_size) { switch (conv_operator) { case cutlass::conv::Operator::kFprop: return problem_size.activation_size(); + case cutlass::conv::Operator::kDeconv: case cutlass::conv::Operator::kDgrad: return problem_size.output_size(); case cutlass::conv::Operator::kWgrad: return problem_size.output_size(); default : break; @@ -582,6 +581,7 @@ int64_t implicit_gemm_tensor_b_size( Conv2dProblemSize const &problem_size) { switch (conv_operator) { case cutlass::conv::Operator::kFprop: return problem_size.filter_size(); + case cutlass::conv::Operator::kDeconv: case cutlass::conv::Operator::kDgrad: return problem_size.filter_size(); case cutlass::conv::Operator::kWgrad: return problem_size.activation_size(); default : break; 
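The `is_deconv` flag added to `filter_extent()` above swaps the channel roles of the filter, and the new `kDeconv` cases route the implicit-GEMM B operand through that swapped view. A small host-side sketch using this change, assuming CUTLASS's headers are on the include path; the problem dimensions are arbitrary:

```cpp
#include <iostream>

#include "cutlass/conv/conv2d_problem_size.h"

int main() {
  // N=1, H=W=32, C=64 activation; K=128, R=S=3 filter; pad 1, stride 1, dilation 1.
  cutlass::conv::Conv2dProblemSize problem(
      {1, 32, 32, 64},    // input  (N, H, W, C)
      {128, 3, 3, 64},    // filter (K, R, S, C)
      {1, 1, 1, 1},       // padding
      {1, 1},             // stride
      {1, 1});            // dilation

  cutlass::Tensor4DCoord krsc = problem.filter_extent();                    // (K, R, S, C/groups)
  cutlass::Tensor4DCoord crsk = problem.filter_extent(/*is_deconv=*/true);  // (C, R, S, K/groups)

  std::cout << krsc.n() << "x" << krsc.h() << "x" << krsc.w() << "x" << krsc.c() << "\n";  // 128x3x3x64
  std::cout << crsk.n() << "x" << crsk.h() << "x" << crsk.w() << "x" << crsk.c() << "\n";  // 64x3x3x128

  // The implicit-GEMM B operand for kDeconv picks up the swapped extent.
  cutlass::Tensor4DCoord b = cutlass::conv::implicit_gemm_tensor_b_extent(
      cutlass::conv::Operator::kDeconv, problem);
  std::cout << b.n() << "x" << b.h() << "x" << b.w() << "x" << b.c() << "\n";  // 64x3x3x128
  return 0;
}
```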
@@ -596,6 +596,7 @@ int64_t implicit_gemm_tensor_c_size( Conv2dProblemSize const &problem_size) { switch (conv_operator) { case cutlass::conv::Operator::kFprop: return problem_size.output_size(); + case cutlass::conv::Operator::kDeconv: case cutlass::conv::Operator::kDgrad: return problem_size.activation_size(); case cutlass::conv::Operator::kWgrad: return problem_size.filter_size(); default : break; diff --git a/include/cutlass/conv/conv3d_problem_size.h b/include/cutlass/conv/conv3d_problem_size.h index 56b164232a..9a9514f2d8 100644 --- a/include/cutlass/conv/conv3d_problem_size.h +++ b/include/cutlass/conv/conv3d_problem_size.h @@ -44,15 +44,6 @@ Map tensor sizes (Conv3d -> ImplicitGemm) : implicit_gemm_tensor_[a|b|c]_size(ConvolutionOperator) Map tensor problem sizes (Conv3d -> ImplicitGemm): implicit_gemm_problem_size(ConvolutionOperator) */ -/* - Note: CUTLASS 3x increases the host compiler requirements to C++17. However, certain - existing integrations of CUTLASS require C++11 host compilers. - - Until this requirement can be lifted, certain headers with this annotation are required - to be remain consistent with C++11 syntax. - - C++11 compatibility is enforced by `cutlass_test_unit_core_cpp11`. -*/ #pragma once @@ -277,9 +268,10 @@ struct Conv3dProblemSize : public Conv2dProblemSize { /// Returns filter extent as Tensor5DCoord CUTLASS_HOST_DEVICE - cutlass::Tensor5DCoord filter_extent() const { + cutlass::Tensor5DCoord filter_extent(bool is_deconv = false) const { - return cutlass::Tensor5DCoord ({K, T, R, S, C}); + return is_deconv ? cutlass::Tensor5DCoord ({C, T, R, S, K}) + : cutlass::Tensor5DCoord ({K, T, R, S, C}); } /// Returns output extent as Tensor5DCoord @@ -351,6 +343,7 @@ cutlass::gemm::GemmCoord implicit_gemm_problem_size( problem_size.K, problem_size.T * problem_size.R * problem_size.S * problem_size.C ); + case Operator::kDeconv: case Operator::kDgrad: return gemm::GemmCoord( problem_size.N * problem_size.D * problem_size.H * problem_size.W, @@ -387,7 +380,8 @@ int implicit_gemm_k_iterations( elements_per_split_k_slice = (problem_size.C + problem_size.split_k_slices - 1) / problem_size.split_k_slices; iterations = problem_size.T * problem_size.R * problem_size.S * ((elements_per_split_k_slice + threadblock_K - 1) / threadblock_K); break; - + + case Operator::kDeconv: case Operator::kDgrad: elements_per_split_k_slice = (problem_size.K + problem_size.split_k_slices - 1) / problem_size.split_k_slices; iterations = problem_size.T * problem_size.R * problem_size.S * ((elements_per_split_k_slice + threadblock_K - 1) / threadblock_K); @@ -430,6 +424,7 @@ cutlass::Tensor5DCoord implicit_gemm_tensor_a_extent( Conv3dProblemSize const &problem_size) { switch (conv_operator) { case cutlass::conv::Operator::kFprop: return problem_size.activation_extent(); + case cutlass::conv::Operator::kDeconv: case cutlass::conv::Operator::kDgrad: return problem_size.output_extent(); case cutlass::conv::Operator::kWgrad: return problem_size.output_extent(); default : break; @@ -444,6 +439,7 @@ cutlass::Tensor5DCoord implicit_gemm_tensor_b_extent( Conv3dProblemSize const &problem_size) { switch (conv_operator) { case cutlass::conv::Operator::kFprop: return problem_size.filter_extent(); + case cutlass::conv::Operator::kDeconv: return problem_size.filter_extent(true); case cutlass::conv::Operator::kDgrad: return problem_size.filter_extent(); case cutlass::conv::Operator::kWgrad: return problem_size.activation_extent(); default : break; @@ -458,6 +454,7 @@ cutlass::Tensor5DCoord 
implicit_gemm_tensor_c_extent( Conv3dProblemSize const &problem_size) { switch (conv_operator) { case cutlass::conv::Operator::kFprop: return problem_size.output_extent(); + case cutlass::conv::Operator::kDeconv: case cutlass::conv::Operator::kDgrad: return problem_size.activation_extent(); case cutlass::conv::Operator::kWgrad: return problem_size.filter_extent(); default : break; @@ -472,6 +469,7 @@ int64_t implicit_gemm_tensor_a_size( Conv3dProblemSize const &problem_size) { switch (conv_operator) { case cutlass::conv::Operator::kFprop: return problem_size.activation_size(); + case cutlass::conv::Operator::kDeconv: case cutlass::conv::Operator::kDgrad: return problem_size.output_size(); case cutlass::conv::Operator::kWgrad: return problem_size.output_size(); default : break; @@ -486,6 +484,7 @@ int64_t implicit_gemm_tensor_b_size( Conv3dProblemSize const &problem_size) { switch (conv_operator) { case cutlass::conv::Operator::kFprop: return problem_size.filter_size(); + case cutlass::conv::Operator::kDeconv: case cutlass::conv::Operator::kDgrad: return problem_size.filter_size(); case cutlass::conv::Operator::kWgrad: return problem_size.activation_size(); default : break; @@ -500,6 +499,7 @@ int64_t implicit_gemm_tensor_c_size( Conv3dProblemSize const &problem_size) { switch (conv_operator) { case cutlass::conv::Operator::kFprop: return problem_size.output_size(); + case cutlass::conv::Operator::kDeconv: case cutlass::conv::Operator::kDgrad: return problem_size.activation_size(); case cutlass::conv::Operator::kWgrad: return problem_size.filter_size(); default : break; diff --git a/include/cutlass/conv/convolution.h b/include/cutlass/conv/convolution.h index a61f573e0e..243ee269dd 100644 --- a/include/cutlass/conv/convolution.h +++ b/include/cutlass/conv/convolution.h @@ -70,16 +70,6 @@ Map elements' data types (ImplicitGemm -> Conv): GemmToConvElementMap Map elements' data types (Conv -> ImplicitGemm): ConvToGemmElementMap */ -/* - Note: CUTLASS 3x increases the host compiler requirements to C++17. However, certain - existing integrations of CUTLASS require C++11 host compilers. - - Until this requirement can be lifted, certain headers with this annotation are required - to be remain consistent with C++11 syntax. - - C++11 compatibility is enforced by `cutlass_test_unit_core_cpp11`. -*/ - #pragma once #include "cutlass/cutlass.h" @@ -98,7 +88,8 @@ namespace conv { enum class Operator { kFprop, kDgrad, - kWgrad + kWgrad, + kDeconv }; /// Distinguishes convolution from cross correlation diff --git a/include/cutlass/conv/device/conv_universal_adapter.hpp b/include/cutlass/conv/device/conv_universal_adapter.hpp index 69cfbabae9..603c47e84b 100644 --- a/include/cutlass/conv/device/conv_universal_adapter.hpp +++ b/include/cutlass/conv/device/conv_universal_adapter.hpp @@ -144,6 +144,11 @@ class ConvUniversalAdapter public: + /// Access the Params structure + Params const& params() const { + return params_; + } + /// Determines whether the conv can execute the given problem. 
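With `kDeconv` added to the `Operator` enumeration, the problem-size helpers treat it as falling through to the data-gradient case, so deconvolution reuses the dgrad GEMM mapping while keeping its own filter view. A short check of that equivalence, again assuming CUTLASS headers on the include path and this patch applied; the dimensions are arbitrary:

```cpp
#include <iostream>

#include "cutlass/conv/conv2d_problem_size.h"

int main() {
  cutlass::conv::Conv2dProblemSize problem(
      {2, 56, 56, 64},   // input  (N, H, W, C)
      {64, 3, 3, 64},    // filter (K, R, S, C)
      {1, 1, 1, 1},      // padding
      {1, 1},            // stride
      {1, 1});           // dilation

  // kDeconv and kDgrad map to the same implicit GEMM shape:
  // GEMM M = N*H*W, GEMM N = C, GEMM K = K*R*S.
  cutlass::gemm::GemmCoord deconv = cutlass::conv::implicit_gemm_problem_size(
      cutlass::conv::Operator::kDeconv, problem);
  cutlass::gemm::GemmCoord dgrad = cutlass::conv::implicit_gemm_problem_size(
      cutlass::conv::Operator::kDgrad, problem);

  std::cout << deconv.m() << " " << deconv.n() << " " << deconv.k() << "\n";  // 6272 64 576
  std::cout << (deconv == dgrad ? "same GEMM shape" : "different") << "\n";   // same GEMM shape
  return 0;
}
```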
static Status can_implement(Arguments const& args) { @@ -323,13 +328,12 @@ class ConvUniversalAdapter } } else { - CUTLASS_ASSERT(cuda_adapter == nullptr); void const* kernel = (void const*) device_kernel; - - launch_result = ClusterLauncher::launch( - grid, cluster, block, smem_size, stream, kernel, kernel_params); - + if constexpr (ConvKernel::ArchTag::kMinComputeCapability == 90) { + launch_result = ClusterLauncher::launch( + grid, cluster, block, smem_size, stream, kernel, kernel_params); + } } } else { diff --git a/include/cutlass/conv/device/implicit_gemm_convolution.h b/include/cutlass/conv/device/implicit_gemm_convolution.h index d2319f2c66..62c7e8715d 100644 --- a/include/cutlass/conv/device/implicit_gemm_convolution.h +++ b/include/cutlass/conv/device/implicit_gemm_convolution.h @@ -153,7 +153,7 @@ class ImplicitGemmConvolution { if (kConvolutionalOperator == conv::Operator::kFprop) { if (args.problem_size.K % kAlignmentC) return Status::kErrorMisalignedOperand; - } else if (kConvolutionalOperator == conv::Operator::kDgrad) { + } else if (kConvolutionalOperator == conv::Operator::kDgrad || kConvolutionalOperator == conv::Operator::kDeconv) { if (args.problem_size.C % kAlignmentC) return Status::kErrorMisalignedOperand; } else if (kConvolutionalOperator == conv::Operator::kWgrad) { @@ -161,16 +161,16 @@ class ImplicitGemmConvolution { return Status::kErrorMisalignedOperand; } - // check for unsupported problem sizes for strided dgrad implementation - if (kConvolutionalOperator == conv::Operator::kDgrad && + // check for unsupported problem sizes for strided dgrad / deconv implementation + if ((kConvolutionalOperator == conv::Operator::kDgrad || kConvolutionalOperator == conv::Operator::kDeconv) && kStrideSupport == conv::StrideSupport::kStrided) { - // split-k (serial or parallel) is not supported for strided dgrad + // split-k (serial or parallel) is not supported for strided dgrad / deconv if(args.problem_size.split_k_slices > 1) { return Status::kErrorNotSupported; } - - // dilation > {1x1} is not supported for strided dgrad + + // dilation > {1x1} is not supported for strided dgrad / deconv if(args.problem_size.dilation_h > 1 || args.problem_size.dilation_w > 1) { return Status::kErrorNotSupported; } diff --git a/include/cutlass/conv/kernel/default_conv2d.h b/include/cutlass/conv/kernel/default_conv2d.h index 51310304bf..f629bbb2d0 100644 --- a/include/cutlass/conv/kernel/default_conv2d.h +++ b/include/cutlass/conv/kernel/default_conv2d.h @@ -128,6 +128,28 @@ struct DefaultConvEpilogueWithBroadcastSimt { >::Epilogue; }; +template < + typename ArchTag, + typename Shape, + typename WarpMmaSimt, + typename ElementOutput, + typename ElementTensor, + typename ElementVector, + typename OutputOp, + int ElementsPerAccess +> +struct DefaultConvEpilogueWithBroadcastSimtStridedDgrad { + using Epilogue = typename epilogue::threadblock::DefaultEpilogueWithBroadcastSimtStridedDgrad< + Shape, + WarpMmaSimt, + ElementOutput, + ElementTensor, + ElementVector, + OutputOp, + ElementsPerAccess + >::Epilogue; +}; + template < typename ArchTag, typename Shape, diff --git a/include/cutlass/conv/kernel/default_conv2d_fprop.h b/include/cutlass/conv/kernel/default_conv2d_fprop.h index 1c7f3444c4..9fbd97e585 100644 --- a/include/cutlass/conv/kernel/default_conv2d_fprop.h +++ b/include/cutlass/conv/kernel/default_conv2d_fprop.h @@ -76,7 +76,7 @@ template < int Stages, typename MathOperatorTag, conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized, - conv::StrideSupport StrideSupport = 
StrideSupport::kStrided, + conv::StrideSupport StrideSupport = StrideSupport::kUnity, /// Access granularity of A matrix in units of elements int AlignmentA = 128 / cutlass::sizeof_bits::value, /// Access granularity of B matrix in units of elements @@ -327,7 +327,6 @@ struct DefaultConv2dFprop < >; }; - ///////////////////////////////////////////////////////////////////////////////////////////////// /// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm and two stage @@ -1167,7 +1166,11 @@ struct DefaultConv2dFprop < WarpMmaTensorOp, kPartitionsK, EpilogueOutputOp, - EpilogueOutputOp::kCount + EpilogueOutputOp::kCount, + false, + layout::NoPermute, + StrideSupport, + 4 >::Epilogue; // Define the kernel @@ -1628,7 +1631,11 @@ struct DefaultConv2dFprop < ThreadblockShape, WarpMmaSimtOp, EpilogueOutputOp, - EpilogueOutputOp::kCount + EpilogueOutputOp::kCount, + false, + layout::NoPermute, + StrideSupport, + 4 >::Epilogue; // Define the kernel @@ -1741,7 +1748,11 @@ struct DefaultConv2dFprop < ThreadblockShape, WarpMmaSimtOp, EpilogueOutputOp, - EpilogueOutputOp::kCount + EpilogueOutputOp::kCount, + false, + layout::NoPermute, + StrideSupport, + 4 >::Epilogue; // Define the kernel @@ -1751,7 +1762,6 @@ struct DefaultConv2dFprop < ThreadblockSwizzle, conv::Operator::kFprop >; - }; ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -1853,7 +1863,11 @@ struct DefaultConv2dFprop < ThreadblockShape, WarpMmaSimtOp, EpilogueOutputOp, - EpilogueOutputOp::kCount + EpilogueOutputOp::kCount, + false, + layout::NoPermute, + StrideSupport, + 4 >::Epilogue; // Define the kernel @@ -1967,7 +1981,11 @@ struct DefaultConv2dFprop < ThreadblockShape, WarpMmaSimtOp, EpilogueOutputOp, - EpilogueOutputOp::kCount + EpilogueOutputOp::kCount, + false, + layout::NoPermute, + StrideSupport, + 4 >::Epilogue; // Define the kernel diff --git a/include/cutlass/conv/kernel/default_conv2d_fprop_fusion.h b/include/cutlass/conv/kernel/default_conv2d_fprop_fusion.h index 4100c8dd46..8589ace029 100644 --- a/include/cutlass/conv/kernel/default_conv2d_fprop_fusion.h +++ b/include/cutlass/conv/kernel/default_conv2d_fprop_fusion.h @@ -76,7 +76,7 @@ template < int Stages, typename MathOperatorTag, conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized, - conv::StrideSupport StrideSupport = StrideSupport::kStrided + conv::StrideSupport StrideSupport = StrideSupport::kUnity > struct DefaultConv2dFpropFusion; ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/conv/kernel/default_conv2d_fprop_with_absmax.h b/include/cutlass/conv/kernel/default_conv2d_fprop_with_absmax.h index b0e0ae6592..76bc12886c 100644 --- a/include/cutlass/conv/kernel/default_conv2d_fprop_with_absmax.h +++ b/include/cutlass/conv/kernel/default_conv2d_fprop_with_absmax.h @@ -69,7 +69,7 @@ template < int Stages, typename MathOperatorTag, conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized, - conv::StrideSupport StrideSupport = StrideSupport::kStrided, + conv::StrideSupport StrideSupport = StrideSupport::kUnity, /// Access granularity of A matrix in units of elements int AlignmentA = 128 / cutlass::sizeof_bits::value, /// Access granularity of B matrix in units of elements diff --git a/include/cutlass/conv/kernel/default_conv2d_fprop_with_broadcast.h b/include/cutlass/conv/kernel/default_conv2d_fprop_with_broadcast.h index 6e6127d7f6..0825789ced 100644 --- 
a/include/cutlass/conv/kernel/default_conv2d_fprop_with_broadcast.h +++ b/include/cutlass/conv/kernel/default_conv2d_fprop_with_broadcast.h @@ -31,7 +31,7 @@ /*! \file \brief - Defines a GEMM with Reduction based on an existing UniversalGemm kernel. + Defines a GEMM with Broadcast based on an existing UniversalGemm kernel. */ @@ -71,7 +71,7 @@ template < int Stages, typename MathOperatorTag, conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized, - conv::StrideSupport StrideSupport = StrideSupport::kStrided, + conv::StrideSupport StrideSupport = StrideSupport::kUnity, /// Access granularity of A matrix in units of elements int AlignmentA = 128 / cutlass::sizeof_bits::value, /// Access granularity of B matrix in units of elements @@ -143,6 +143,7 @@ template < typename ThreadblockSwizzle, int Stages, typename MathOperatorTag, + conv::IteratorAlgorithm IteratorAlgorithm, conv::StrideSupport StrideSupport, int AlignmentA, int AlignmentB @@ -164,7 +165,7 @@ struct DefaultConv2dFpropWithBroadcast < ThreadblockSwizzle, Stages, MathOperatorTag, - IteratorAlgorithm::kAnalytic, + IteratorAlgorithm, StrideSupport, AlignmentA, AlignmentB @@ -184,7 +185,7 @@ struct DefaultConv2dFpropWithBroadcast < ThreadblockSwizzle, Stages, MathOperatorTag, - IteratorAlgorithm::kAnalytic, + IteratorAlgorithm, StrideSupport, AlignmentA, AlignmentB diff --git a/include/cutlass/conv/kernel/default_conv2d_fprop_with_reduction.h b/include/cutlass/conv/kernel/default_conv2d_fprop_with_reduction.h index 978d49c95b..e6e8a82209 100644 --- a/include/cutlass/conv/kernel/default_conv2d_fprop_with_reduction.h +++ b/include/cutlass/conv/kernel/default_conv2d_fprop_with_reduction.h @@ -72,7 +72,7 @@ template < int Stages, typename MathOperatorTag, conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized, - conv::StrideSupport StrideSupport = StrideSupport::kStrided, + conv::StrideSupport StrideSupport = StrideSupport::kUnity, /// Access granularity of A matrix in units of elements int AlignmentA = 128 / cutlass::sizeof_bits::value, /// Access granularity of B matrix in units of elements diff --git a/include/cutlass/conv/kernel/default_conv2d_group_fprop.h b/include/cutlass/conv/kernel/default_conv2d_group_fprop.h index 927f70ce80..e2deaf6fe2 100644 --- a/include/cutlass/conv/kernel/default_conv2d_group_fprop.h +++ b/include/cutlass/conv/kernel/default_conv2d_group_fprop.h @@ -77,7 +77,7 @@ template < typename MathOperatorTag, conv::GroupMode GroupMode, conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized, - conv::StrideSupport StrideSupport = StrideSupport::kStrided, + conv::StrideSupport StrideSupport = StrideSupport::kUnity, /// Access granularity of A matrix in units of elements int AlignmentA = 128 / cutlass::sizeof_bits::value, /// Access granularity of B matrix in units of elements diff --git a/include/cutlass/conv/kernel/default_conv3d_fprop.h b/include/cutlass/conv/kernel/default_conv3d_fprop.h index 3ea1e11c74..41fdd64a5e 100644 --- a/include/cutlass/conv/kernel/default_conv3d_fprop.h +++ b/include/cutlass/conv/kernel/default_conv3d_fprop.h @@ -73,7 +73,7 @@ template < int Stages, typename MathOperatorTag, conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized, - conv::StrideSupport StrideSupport = StrideSupport::kStrided + conv::StrideSupport StrideSupport = StrideSupport::kUnity > struct DefaultConv3dFprop; ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -94,7 +94,8 @@ template < typename 
InstructionShape, typename EpilogueOutputOp, typename ThreadblockSwizzle, - typename MathOperatorTag + typename MathOperatorTag, + conv::StrideSupport StrideSupport > struct DefaultConv3dFprop < ElementA, @@ -113,7 +114,8 @@ struct DefaultConv3dFprop < ThreadblockSwizzle, 2, MathOperatorTag, - IteratorAlgorithm::kAnalytic + IteratorAlgorithm::kAnalytic, + StrideSupport > { // Define the core components from GEMM @@ -202,7 +204,8 @@ template < typename EpilogueOutputOp, typename ThreadblockSwizzle, int Stages, - typename MathOperatorTag + typename MathOperatorTag, + conv::StrideSupport StrideSupport > struct DefaultConv3dFprop < ElementA, @@ -221,7 +224,8 @@ struct DefaultConv3dFprop < ThreadblockSwizzle, Stages, MathOperatorTag, - IteratorAlgorithm::kAnalytic + IteratorAlgorithm::kAnalytic, + StrideSupport > { // Define the core components from GEMM @@ -306,7 +310,8 @@ template < typename InstructionShape, typename EpilogueOutputOp, typename ThreadblockSwizzle, - typename MathOperatorTag + typename MathOperatorTag, + conv::StrideSupport StrideSupport > struct DefaultConv3dFprop < ElementA, @@ -325,7 +330,8 @@ struct DefaultConv3dFprop < ThreadblockSwizzle, 2, MathOperatorTag, - IteratorAlgorithm::kOptimized + IteratorAlgorithm::kOptimized, + StrideSupport > { // Define the core components from GEMM @@ -416,7 +422,8 @@ template < typename EpilogueOutputOp, typename ThreadblockSwizzle, int Stages, - typename MathOperatorTag + typename MathOperatorTag, + conv::StrideSupport StrideSupport > struct DefaultConv3dFprop < ElementA, @@ -435,7 +442,8 @@ struct DefaultConv3dFprop < ThreadblockSwizzle, Stages, MathOperatorTag, - IteratorAlgorithm::kOptimized + IteratorAlgorithm::kOptimized, + StrideSupport > { // Define the core components from GEMM @@ -492,7 +500,11 @@ struct DefaultConv3dFprop < WarpMmaTensorOp, 1, EpilogueOutputOp, - EpilogueOutputOp::kCount + EpilogueOutputOp::kCount, + false, + layout::NoPermute, + StrideSupport, + 5 >::Epilogue; // Define the kernel @@ -526,7 +538,8 @@ template < typename EpilogueOutputOp, typename ThreadblockSwizzle, int Stages, - typename MathOperatorTag + typename MathOperatorTag, + conv::StrideSupport StrideSupport > struct DefaultConv3dFprop < ElementA, @@ -545,7 +558,8 @@ struct DefaultConv3dFprop < ThreadblockSwizzle, Stages, MathOperatorTag, - IteratorAlgorithm::kAnalytic + IteratorAlgorithm::kAnalytic, + StrideSupport > { // Define the core components from GEMM @@ -598,7 +612,11 @@ struct DefaultConv3dFprop < ThreadblockShape, WarpMmaSimtOp, EpilogueOutputOp, - EpilogueOutputOp::kCount + EpilogueOutputOp::kCount, + false, + layout::NoPermute, + StrideSupport, + 5 >::Epilogue; // Define the kernel @@ -632,7 +650,8 @@ template < typename EpilogueOutputOp, typename ThreadblockSwizzle, int Stages, - typename MathOperatorTag + typename MathOperatorTag, + conv::StrideSupport StrideSupport > struct DefaultConv3dFprop < ElementA, @@ -651,7 +670,8 @@ struct DefaultConv3dFprop < ThreadblockSwizzle, Stages, MathOperatorTag, - IteratorAlgorithm::kOptimized + IteratorAlgorithm::kOptimized, + StrideSupport > { // Define the core components from GEMM @@ -706,7 +726,11 @@ struct DefaultConv3dFprop < ThreadblockShape, WarpMmaSimtOp, EpilogueOutputOp, - EpilogueOutputOp::kCount + EpilogueOutputOp::kCount, + false, + layout::NoPermute, + StrideSupport, + 5 >::Epilogue; // Define the kernel @@ -738,7 +762,8 @@ template < typename InstructionShape, typename EpilogueOutputOp, typename ThreadblockSwizzle, - typename MathOperatorTag + typename MathOperatorTag, + 
conv::StrideSupport StrideSupport > struct DefaultConv3dFprop < ElementA, @@ -757,7 +782,8 @@ struct DefaultConv3dFprop < ThreadblockSwizzle, 2, MathOperatorTag, - IteratorAlgorithm::kAnalytic + IteratorAlgorithm::kAnalytic, + StrideSupport > { // Define the core components from GEMM @@ -813,7 +839,11 @@ struct DefaultConv3dFprop < ThreadblockShape, WarpMmaSimtOp, EpilogueOutputOp, - EpilogueOutputOp::kCount + EpilogueOutputOp::kCount, + false, + layout::NoPermute, + StrideSupport, + 5 >::Epilogue; // Define the kernel @@ -845,7 +875,8 @@ template < typename InstructionShape, typename EpilogueOutputOp, typename ThreadblockSwizzle, - typename MathOperatorTag + typename MathOperatorTag, + conv::StrideSupport StrideSupport > struct DefaultConv3dFprop < ElementA, @@ -864,7 +895,8 @@ struct DefaultConv3dFprop < ThreadblockSwizzle, 2, MathOperatorTag, - IteratorAlgorithm::kOptimized + IteratorAlgorithm::kOptimized, + StrideSupport > { // Define the core components from GEMM @@ -922,7 +954,11 @@ struct DefaultConv3dFprop < ThreadblockShape, WarpMmaSimtOp, EpilogueOutputOp, - EpilogueOutputOp::kCount + EpilogueOutputOp::kCount, + false, + layout::NoPermute, + StrideSupport, + 5 >::Epilogue; // Define the kernel @@ -933,10 +969,10 @@ struct DefaultConv3dFprop < conv::Operator::kFprop, Conv3dProblemSize >; - }; ///////////////////////////////////////////////////////////////////////////////////////////////// + } // namespace kernel } // namespace conv } // namespace cutlass diff --git a/include/cutlass/conv/kernel/default_conv3d_fprop_fusion.h b/include/cutlass/conv/kernel/default_conv3d_fprop_fusion.h index 3449771530..d0457d572e 100644 --- a/include/cutlass/conv/kernel/default_conv3d_fprop_fusion.h +++ b/include/cutlass/conv/kernel/default_conv3d_fprop_fusion.h @@ -77,7 +77,7 @@ template < int Stages, typename MathOperatorTag, conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized, - conv::StrideSupport StrideSupport = StrideSupport::kStrided + conv::StrideSupport StrideSupport = StrideSupport::kUnity > struct DefaultConv3dFpropFusion; ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/conv/kernel/default_conv3d_fprop_with_broadcast.h b/include/cutlass/conv/kernel/default_conv3d_fprop_with_broadcast.h index 1d70a29e73..38e4de5c26 100644 --- a/include/cutlass/conv/kernel/default_conv3d_fprop_with_broadcast.h +++ b/include/cutlass/conv/kernel/default_conv3d_fprop_with_broadcast.h @@ -31,7 +31,7 @@ /*! \file \brief - Defines a GEMM with Reduction based on an existing UniversalGemm kernel. + Defines a GEMM with Broadcast based on an existing UniversalGemm kernel. 
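The `DefaultConv*dFpropWithBroadcast` changes above and below replace a hard-coded `IteratorAlgorithm::kAnalytic` in the partial specializations with the deduced `IteratorAlgorithm` parameter, so callers that request the optimized iterators actually get them. A toy model of that fix pattern; `Algo`, `Builder`, and `BuilderWithBroadcast` are hypothetical stand-ins for the CUTLASS templates:

```cpp
enum class Algo { kAnalytic, kOptimized };

template <typename Element, Algo kAlgo>
struct Builder {
  static constexpr Algo value = kAlgo;
};

// Before the fix, the wrapper re-specified Algo::kAnalytic when instantiating
// the underlying builder, silently discarding the caller's choice. After the
// fix, the deduced parameter is forwarded unchanged.
template <typename Element, Algo kAlgo = Algo::kOptimized>
struct BuilderWithBroadcast {
  using Base = Builder<Element, kAlgo>;  // forwards kAlgo instead of pinning kAnalytic
};

static_assert(BuilderWithBroadcast<float, Algo::kOptimized>::Base::value == Algo::kOptimized,
              "the requested iterator algorithm is honored");

int main() { return 0; }
```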
*/ @@ -71,7 +71,7 @@ template < int Stages, typename MathOperatorTag, conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized, - conv::StrideSupport StrideSupport = StrideSupport::kStrided, + conv::StrideSupport StrideSupport = StrideSupport::kUnity, /// Access granularity of A matrix in units of elements int AlignmentA = 128 / cutlass::sizeof_bits::value, /// Access granularity of B matrix in units of elements @@ -142,6 +142,7 @@ template < typename ThreadblockSwizzle, int Stages, typename MathOperatorTag, + conv::IteratorAlgorithm IteratorAlgorithm, conv::StrideSupport StrideSupport, int AlignmentA, int AlignmentB @@ -163,7 +164,7 @@ struct DefaultConv3dFpropWithBroadcast < ThreadblockSwizzle, Stages, MathOperatorTag, - IteratorAlgorithm::kAnalytic, + IteratorAlgorithm, StrideSupport, AlignmentA, AlignmentB @@ -183,7 +184,7 @@ struct DefaultConv3dFpropWithBroadcast < ThreadblockSwizzle, Stages, MathOperatorTag, - IteratorAlgorithm::kAnalytic, + IteratorAlgorithm, StrideSupport >::Kernel; diff --git a/include/cutlass/conv/kernel/default_deconv2d.h b/include/cutlass/conv/kernel/default_deconv2d.h new file mode 100644 index 0000000000..ace21b92fa --- /dev/null +++ b/include/cutlass/conv/kernel/default_deconv2d.h @@ -0,0 +1,983 @@ +/*************************************************************************************************** + * Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief + Default kernel-level implicit GEMM convolution definitions combine threadblock-scoped + matrix multiply-add with the appropriate threadblock-scoped epilogue. 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/conv/kernel/default_conv2d.h" + +#include "cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_analytic.h" +#include "cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_optimized.h" +#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h" +#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h" +#include "cutlass/conv/threadblock/conv2d_tile_iterator.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Defines a kernel for Deconv2d +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename OperatorClass, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized, + conv::StrideSupport StrideSupport = StrideSupport::kStrided, + /// Access granularity of A matrix in units of elements + int AlignmentA = 128 / cutlass::sizeof_bits::value, + /// Access granularity of B matrix in units of elements + int AlignmentB = 128 / cutlass::sizeof_bits::value +> struct DefaultDeconv2d; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// OpClassSimt convolutions +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Defines a kernel for Deconv2d specialization for Analytic IteratorAlgorithm, +/// multi-stage pipeline, and FFMA-based mainloop for SM80 + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + int AlignmentA, + int AlignmentB +> +struct DefaultDeconv2d < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kAnalytic, + conv::StrideSupport::kUnity, + AlignmentA, + AlignmentB +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + Stages, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, + ThreadMapA, + conv::StrideSupport::kUnity + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + 
using IteratorB = + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, LayoutB, + ThreadMapB, + cutlass::AlignedArray, + conv::GroupMode::kNone, + true /*IsDeconv*/ + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + arch::CacheOperation::Always, + MmaPolicy, + Stages + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDeconv + >; + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + int AlignmentA, + int AlignmentB +> +struct DefaultDeconv2d < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kAnalytic, + conv::StrideSupport::kStrided, + AlignmentA, + AlignmentB +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + Stages, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, + ThreadMapA, + conv::StrideSupport::kStrided + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, LayoutB, + ThreadMapB, + cutlass::AlignedArray, + conv::GroupMode::kNone, + true /*IsDeconv*/ + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + arch::CacheOperation::Always, + MmaPolicy, + Stages + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimtStridedDgrad< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + 
EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionStridedDgrad< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDeconv + >; + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Deconv2d specialization for Optimized IteratorAlgorithm, +/// multi-stage pipeline, and FFMA-based mainloop for SM80 + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + int AlignmentA, + int AlignmentB +> +struct DefaultDeconv2d < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kOptimized, + StrideSupport::kUnity, + AlignmentA, + AlignmentB +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + Stages, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, + ThreadMapA, + StrideSupport::kUnity + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, LayoutB, + ThreadMapB, + cutlass::AlignedArray, + true /*IsDeconv*/ + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + arch::CacheOperation::Always, + MmaPolicy, + Stages + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDeconv + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + int AlignmentA, + int AlignmentB +> +struct 
DefaultDeconv2d < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kOptimized, + conv::StrideSupport::kStrided, + AlignmentA, + AlignmentB +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + Stages, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, + ThreadMapA, + conv::StrideSupport::kStrided + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, LayoutB, + ThreadMapB, + cutlass::AlignedArray, + true /*IsDeconv*/ + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + arch::CacheOperation::Always, + MmaPolicy, + Stages + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimtStridedDgrad< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionStridedDgrad< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDeconv + >; + +}; +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Deconv2d specialization for Analytic IteratorAlgorithm, +/// 2 stage pipeline, and FFMA-based mainloop for SM50 +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + typename MathOperatorTag, + int AlignmentA, + int AlignmentB +> +struct DefaultDeconv2d < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + 2, + MathOperatorTag, + IteratorAlgorithm::kAnalytic, + conv::StrideSupport::kUnity, + AlignmentA, + AlignmentB +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + 2, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using 
ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, + ThreadMapA, + conv::StrideSupport::kUnity + > + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, LayoutB, + ThreadMapB, + cutlass::AlignedArray, + conv::GroupMode::kNone, + true /*IsDeconv*/ + > + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmPipelined< + ThreadblockShape, + IteratorA, + SmemIteratorA, + IteratorB, + SmemIteratorB, + ElementC, + LayoutC, + MmaPolicy + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDeconv + >; + +}; +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + typename MathOperatorTag, + int AlignmentA, + int AlignmentB +> +struct DefaultDeconv2d < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + 2, + MathOperatorTag, + IteratorAlgorithm::kAnalytic, + conv::StrideSupport::kStrided, + AlignmentA, + AlignmentB +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + 2, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::TileIteratorStridedDgrad< + cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, + ThreadMapA, + conv::StrideSupport::kStrided + > + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::TileIteratorStridedDgrad< + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, LayoutB, + ThreadMapB, + cutlass::AlignedArray, + conv::GroupMode::kNone, + true /*IsDeconv*/ + > + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // 
Warp-level GEMM components
+  using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt;
+  using MmaPolicy = typename MmaCore::MmaPolicy;
+
+  // Define the Mma
+  using Mma = threadblock::ImplicitGemmPipelined<
+    ThreadblockShape,
+    IteratorA,
+    SmemIteratorA,
+    IteratorB,
+    SmemIteratorB,
+    ElementC,
+    LayoutC,
+    MmaPolicy
+  >;
+
+  // Define the epilogue
+  using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimtStridedDgrad<
+    ThreadblockShape,
+    WarpMmaSimtOp,
+    EpilogueOutputOp,
+    EpilogueOutputOp::kCount
+  >::Epilogue;
+
+  // Define the kernel
+  using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionStridedDgrad<
+    Mma,
+    Epilogue,
+    ThreadblockSwizzle,
+    conv::Operator::kDeconv
+  >;
+};
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// Defines a kernel for Deconv2d specialization for Optimized IteratorAlgorithm,
+/// 2 stage pipeline, and FFMA-based mainloop for SM50
+template <
+  typename ElementA,
+  typename LayoutA,
+  typename ElementB,
+  typename LayoutB,
+  typename ElementC,
+  typename LayoutC,
+  typename ElementAccumulator,
+  typename ArchTag,
+  typename ThreadblockShape,
+  typename WarpShape,
+  typename InstructionShape,
+  typename EpilogueOutputOp,
+  typename ThreadblockSwizzle,
+  typename MathOperatorTag,
+  int AlignmentA,
+  int AlignmentB
+>
+struct DefaultDeconv2d <
+  ElementA,
+  LayoutA,
+  ElementB,
+  LayoutB,
+  ElementC,
+  LayoutC,
+  ElementAccumulator,
+  arch::OpClassSimt,
+  ArchTag,
+  ThreadblockShape,
+  WarpShape,
+  InstructionShape,
+  EpilogueOutputOp,
+  ThreadblockSwizzle,
+  2,
+  MathOperatorTag,
+  IteratorAlgorithm::kOptimized,
+  StrideSupport::kUnity,
+  AlignmentA,
+  AlignmentB
+> {
+
+  // Define the core components from GEMM
+  using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore<
+    ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor,
+    ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt,
+    2, MathOperatorTag>;
+
+  // Define iterators over tiles from the A operand
+  using ThreadMapA = typename MmaCore::IteratorThreadMapA;
+  using IteratorA =
+    cutlass::conv::threadblock::TileIterator<
+      cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorOptimized<
+        cutlass::MatrixShape<ThreadblockShape::kM, ThreadblockShape::kK>,
+        ElementA,
+        ThreadMapA,
+        StrideSupport::kUnity
+      >
+    >;
+
+  using SmemIteratorA = typename MmaCore::SmemIteratorA;
+
+  // Define iterators over tiles from the B operand
+  using ThreadMapB = typename MmaCore::IteratorThreadMapB;
+  using IteratorB =
+    cutlass::conv::threadblock::TileIterator<
+      cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized<
+        cutlass::MatrixShape<ThreadblockShape::kK, ThreadblockShape::kN>,
+        ElementB, LayoutB,
+        ThreadMapB,
+        cutlass::AlignedArray<ElementB, AlignmentB>,
+        true /*IsDeconv*/
+      >
+    >;
+
+  using SmemIteratorB = typename MmaCore::SmemIteratorB;
+
+  // Warp-level GEMM components
+  using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt;
+  using MmaPolicy = typename MmaCore::MmaPolicy;
+
+  // Define the Mma
+  using Mma = threadblock::ImplicitGemmPipelined<
+    ThreadblockShape,
+    IteratorA,
+    SmemIteratorA,
+    IteratorB,
+    SmemIteratorB,
+    ElementC,
+    LayoutC,
+    MmaPolicy
+  >;
+
+  // Define the epilogue
+  using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt<
+    ThreadblockShape,
+    WarpMmaSimtOp,
+    EpilogueOutputOp,
+    EpilogueOutputOp::kCount
+  >::Epilogue;
+
+  // Define the kernel
+  using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution<
+    Mma,
+    Epilogue,
+    ThreadblockSwizzle,
+    conv::Operator::kDeconv
+  >;
+
+};
+
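As an illustrative sketch of how the new `DefaultDeconv2d` specializations above are meant to be consumed (this usage is not part of the patch itself), a SIMT FP32 deconvolution could be composed as follows. The element types, NHWC layouts, tile shapes, stage count, and unit alignments are assumptions chosen for the example; `device::ImplicitGemmConvolution` and `StridedDgradIdentityThreadblockSwizzle` are the existing CUTLASS 2.x components already used this way for strided dgrad kernels.

#include "cutlass/conv/device/implicit_gemm_convolution.h"
#include "cutlass/conv/kernel/default_deconv2d.h"
#include "cutlass/conv/threadblock/threadblock_swizzle.h"
#include "cutlass/epilogue/thread/linear_combination.h"

// Illustrative parameter choices only; they are not taken from this patch.
using Deconv2dKernel = typename cutlass::conv::kernel::DefaultDeconv2d<
  float, cutlass::layout::TensorNHWC,     // ElementA / LayoutA: input tensor (plays the dgrad output-gradient role)
  float, cutlass::layout::TensorNHWC,     // ElementB / LayoutB: filter
  float, cutlass::layout::TensorNHWC,     // ElementC / LayoutC: output tensor
  float,                                  // ElementAccumulator
  cutlass::arch::OpClassSimt,             // FFMA mainloop
  cutlass::arch::Sm80,
  cutlass::gemm::GemmShape<128, 128, 8>,  // ThreadblockShape
  cutlass::gemm::GemmShape<32, 64, 8>,    // WarpShape
  cutlass::gemm::GemmShape<1, 1, 1>,      // InstructionShape (SIMT)
  cutlass::epilogue::thread::LinearCombination<float, 1, float, float>,
  cutlass::conv::threadblock::StridedDgradIdentityThreadblockSwizzle<1>,
  3,                                      // Stages: selects the multistage SM80 specialization
  cutlass::arch::OpMultiplyAdd,
  cutlass::conv::IteratorAlgorithm::kOptimized,
  cutlass::conv::StrideSupport::kStrided,
  1, 1                                    // AlignmentA / AlignmentB
>::Kernel;

// Device-level wrapper; the resulting operator is configured and launched like
// any other CUTLASS 2.x implicit GEMM convolution.
using Deconv2d = cutlass::conv::device::ImplicitGemmConvolution<Deconv2dKernel>;

Choosing conv::StrideSupport::kStrided routes this composition through ImplicitGemmConvolutionStridedDgrad, as in the multistage specialization near the top of this file, which is why the strided-dgrad threadblock swizzle is paired with it; with kUnity, the plain ImplicitGemmConvolution kernel and gemm::threadblock::GemmIdentityThreadblockSwizzle<> would be the natural pairing.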
+///////////////////////////////////////////////////////////////////////////////////////////////// +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + typename MathOperatorTag, + int AlignmentA, + int AlignmentB +> +struct DefaultDeconv2d < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + 2, + MathOperatorTag, + IteratorAlgorithm::kOptimized, + conv::StrideSupport::kStrided, + AlignmentA, + AlignmentB +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + 2, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::TileIteratorStridedDgrad< + cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, + ThreadMapA, + conv::StrideSupport::kStrided + > + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::TileIteratorStridedDgrad< + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, LayoutB, + ThreadMapB, + cutlass::AlignedArray, + true /*IsDeconv*/ + > + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmPipelined< + ThreadblockShape, + IteratorA, + SmemIteratorA, + IteratorB, + SmemIteratorB, + ElementC, + LayoutC, + MmaPolicy + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimtStridedDgrad< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionStridedDgrad< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDeconv + >; + +}; + +} // namespace kernel +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/conv/kernel/default_deconv2d_with_broadcast.h b/include/cutlass/conv/kernel/default_deconv2d_with_broadcast.h new file mode 100644 index 0000000000..d11432ed39 --- /dev/null +++ b/include/cutlass/conv/kernel/default_deconv2d_with_broadcast.h @@ -0,0 +1,305 @@ +/*************************************************************************************************** + * Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief + Defines a GEMM with Broadcast based on an existing UniversalGemm kernel. + +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/conv/kernel/default_deconv2d.h" +#include "cutlass/conv/kernel/implicit_gemm_convolution_with_fused_epilogue.h" + +#include "cutlass/epilogue/threadblock/default_epilogue_with_broadcast.h" +#include "cutlass/epilogue/threadblock/epilogue_with_broadcast.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename OperatorClass, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized, + conv::StrideSupport StrideSupport = StrideSupport::kStrided, + /// Access granularity of A matrix in units of elements + int AlignmentA = 128 / cutlass::sizeof_bits::value, + /// Access granularity of B matrix in units of elements + int AlignmentB = 128 / cutlass::sizeof_bits::value +> +struct DefaultDeconv2dWithBroadcast { + + using ImplicitGemmBase = typename DefaultDeconv2d< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm, + StrideSupport, + AlignmentA, + AlignmentB + >::Kernel; + + // Define 
epilogue + using Epilogue = typename cutlass::conv::kernel::detail::DefaultConvEpilogueWithBroadcastTensorOp< + ArchTag, + typename ImplicitGemmBase::Epilogue::Shape, + typename ImplicitGemmBase::Epilogue::WarpMmaOperator, + ImplicitGemmBase::Epilogue::kPartitionsK, + ElementC, + typename EpilogueOutputOp::ElementT, + typename EpilogueOutputOp::ElementVector, + EpilogueOutputOp, + ImplicitGemmBase::Epilogue::kElementsPerAccess + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionWithFusedEpilogue< + typename ImplicitGemmBase::Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDeconv + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// OpClassSimt convolutions +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Defines a kernel for Deconv2d specialization, +/// multi-stage pipeline, and FFMA-based mainloop for SM80 + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::IteratorAlgorithm IteratorAlgorithm, + int AlignmentA, + int AlignmentB +> +struct DefaultDeconv2dWithBroadcast < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm, + conv::StrideSupport::kUnity, + AlignmentA, + AlignmentB +> { + + using ImplicitGemmBase = typename DefaultDeconv2d< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm, + conv::StrideSupport::kUnity, + AlignmentA, + AlignmentB + >::Kernel; + + // Define epilogue + using Epilogue = typename cutlass::conv::kernel::detail::DefaultConvEpilogueWithBroadcastSimt< + ArchTag, + typename ImplicitGemmBase::Epilogue::Shape, + typename ImplicitGemmBase::Epilogue::WarpMmaOperator, + ElementC, + typename EpilogueOutputOp::ElementT, + typename EpilogueOutputOp::ElementVector, + EpilogueOutputOp, + ImplicitGemmBase::Epilogue::kElementsPerAccess + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionWithFusedEpilogue< + typename ImplicitGemmBase::Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDeconv + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::IteratorAlgorithm IteratorAlgorithm, + int AlignmentA, + int AlignmentB +> +struct DefaultDeconv2dWithBroadcast < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + 
ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm, + conv::StrideSupport::kStrided, + AlignmentA, + AlignmentB +> { + + using ImplicitGemmBase = typename DefaultDeconv2d< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm, + conv::StrideSupport::kStrided, + AlignmentA, + AlignmentB + >::Kernel; + + // Define epilogue + using Epilogue = typename cutlass::conv::kernel::detail::DefaultConvEpilogueWithBroadcastSimtStridedDgrad< + ArchTag, + typename ImplicitGemmBase::Epilogue::Shape, + typename ImplicitGemmBase::Epilogue::WarpMmaOperator, + ElementC, + typename EpilogueOutputOp::ElementT, + typename EpilogueOutputOp::ElementVector, + EpilogueOutputOp, + ImplicitGemmBase::Epilogue::kElementsPerAccess + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionWithFusedEpilogue< + typename ImplicitGemmBase::Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDeconv + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +} // namespace kernel +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/conv/kernel/default_deconv3d.h b/include/cutlass/conv/kernel/default_deconv3d.h new file mode 100644 index 0000000000..e9eb4cc5b0 --- /dev/null +++ b/include/cutlass/conv/kernel/default_deconv3d.h @@ -0,0 +1,525 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +/*! \file + \brief + Default kernel-level implicit GEMM convolution definitions combine threadblock-scoped + matrix multiply-add with the appropriate threadblock-scoped epilogue. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/conv/kernel/default_conv2d.h" + +#include "cutlass/conv/threadblock/conv3d_dgrad_output_gradient_tile_access_iterator_optimized.h" +#include "cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_optimized.h" + +#include "cutlass/conv/threadblock/conv3d_dgrad_output_gradient_tile_access_iterator_analytic.h" +#include "cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_analytic.h" +#include "cutlass/conv/threadblock/conv2d_tile_iterator.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Defines a kernel for Deconv3d +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename OperatorClass, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized, + conv::StrideSupport StrideSupport = StrideSupport::kStrided +> struct DefaultDeconv3d; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// OpClassSimt convolutions +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag +> +struct DefaultDeconv3d < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kAnalytic, + conv::StrideSupport::kStrided +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + Stages, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::Conv3dDgradOutputGradientTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, + ThreadMapA, + conv::StrideSupport::kStrided + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::Conv3dFpropFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, + 
ThreadMapB, + true /*IsDeconv*/ + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + arch::CacheOperation::Always, + MmaPolicy, + Stages + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDeconv, + Conv3dProblemSize + >; + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Deconv3d specialization for Optimized IteratorAlgorithm, +/// multi-stage pipeline, and FFMA-based mainloop for SM80 + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag +> +struct DefaultDeconv3d < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kOptimized, + StrideSupport::kUnity +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + Stages, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::Conv3dDgradOutputGradientTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, + ThreadMapA, + StrideSupport::kUnity + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::Conv3dFpropFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, + LayoutB, + ThreadMapB, + true /*IsDeconv*/ + // ThreadMapB, + // StrideSupport::kUnity + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + arch::CacheOperation::Always, + MmaPolicy, + Stages + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + 
Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDeconv, + Conv3dProblemSize + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + typename MathOperatorTag +> +struct DefaultDeconv3d < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + 2, + MathOperatorTag, + IteratorAlgorithm::kAnalytic, + conv::StrideSupport::kStrided +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + 2, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + // cutlass::conv::threadblock::TileIteratorStridedDgrad< + cutlass::conv::threadblock::Conv3dDgradOutputGradientTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, + ThreadMapA, + conv::StrideSupport::kStrided + // > + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + // cutlass::conv::threadblock::TileIteratorStridedDgrad< + cutlass::conv::threadblock::Conv3dFpropFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, + ThreadMapB, + true /*IsDeconv*/ + // > + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmPipelined< + ThreadblockShape, + IteratorA, + SmemIteratorA, + IteratorB, + SmemIteratorB, + ElementC, + LayoutC, + MmaPolicy + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDeconv, + Conv3dProblemSize + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Deconv3d specialization for Optimized IteratorAlgorithm, +/// 2 stage pipeline, and FFMA-based mainloop for SM50 +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + typename MathOperatorTag +> +struct DefaultDeconv3d < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + 
EpilogueOutputOp, + ThreadblockSwizzle, + 2, + MathOperatorTag, + IteratorAlgorithm::kOptimized, + StrideSupport::kUnity +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + 2, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + // cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv3dDgradOutputGradientTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, + ThreadMapA, + StrideSupport::kUnity + // > + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + // cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv3dFpropFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, + LayoutB, + ThreadMapB, + true /*IsDeconv*/ + // ThreadMapB, + // StrideSupport::kUnity + // > + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmPipelined< + ThreadblockShape, + IteratorA, + SmemIteratorA, + IteratorB, + SmemIteratorB, + ElementC, + LayoutC, + MmaPolicy + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDeconv, + Conv3dProblemSize + >; + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/include/cutlass/conv/kernel/default_deconv3d_with_broadcast.h b/include/cutlass/conv/kernel/default_deconv3d_with_broadcast.h new file mode 100644 index 0000000000..5c50c766d9 --- /dev/null +++ b/include/cutlass/conv/kernel/default_deconv3d_with_broadcast.h @@ -0,0 +1,303 @@ +/*************************************************************************************************** + * Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief + Defines a GEMM with Broadcast based on an existing UniversalGemm kernel. + +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/conv/kernel/default_deconv3d.h" +#include "cutlass/conv/kernel/implicit_gemm_convolution_with_fused_epilogue.h" + +#include "cutlass/epilogue/threadblock/default_epilogue_with_broadcast.h" +#include "cutlass/epilogue/threadblock/epilogue_with_broadcast.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename OperatorClass, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized, + conv::StrideSupport StrideSupport = StrideSupport::kStrided, + /// Access granularity of A matrix in units of elements + int AlignmentA = 128 / cutlass::sizeof_bits::value, + /// Access granularity of B matrix in units of elements + int AlignmentB = 128 / cutlass::sizeof_bits::value +> +struct DefaultDeconv3dWithBroadcast { + + using ImplicitGemmBase = typename DefaultDeconv3d< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm, + StrideSupport + >::Kernel; + + // Define epilogue + using Epilogue = typename cutlass::conv::kernel::detail::DefaultConvEpilogueWithBroadcastTensorOp< + ArchTag, + typename ImplicitGemmBase::Epilogue::Shape, + typename ImplicitGemmBase::Epilogue::WarpMmaOperator, + ImplicitGemmBase::Epilogue::kPartitionsK, + ElementC, + typename EpilogueOutputOp::ElementT, + typename EpilogueOutputOp::ElementVector, + EpilogueOutputOp, + ImplicitGemmBase::Epilogue::kElementsPerAccess + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionWithFusedEpilogue< + typename ImplicitGemmBase::Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDeconv, + Conv3dProblemSize + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// OpClassSimt 
convolutions +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Defines a kernel for Deconv3d specialization for Analytic IteratorAlgorithm, +/// multi-stage pipeline, and FFMA-based mainloop for SM80 + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::IteratorAlgorithm IteratorAlgorithm, + int AlignmentA, + int AlignmentB +> +struct DefaultDeconv3dWithBroadcast < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm, + conv::StrideSupport::kUnity, + AlignmentA, + AlignmentB +> { + + using ImplicitGemmBase = typename DefaultDeconv3d< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm, + conv::StrideSupport::kUnity + >::Kernel; + + // Define epilogue + using Epilogue = typename cutlass::conv::kernel::detail::DefaultConvEpilogueWithBroadcastSimt< + ArchTag, + typename ImplicitGemmBase::Epilogue::Shape, + typename ImplicitGemmBase::Epilogue::WarpMmaOperator, + ElementC, + typename EpilogueOutputOp::ElementT, + typename EpilogueOutputOp::ElementVector, + EpilogueOutputOp, + ImplicitGemmBase::Epilogue::kElementsPerAccess + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionWithFusedEpilogue< + typename ImplicitGemmBase::Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDeconv, + Conv3dProblemSize + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::IteratorAlgorithm IteratorAlgorithm, + int AlignmentA, + int AlignmentB +> +struct DefaultDeconv3dWithBroadcast < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm, + conv::StrideSupport::kStrided, + AlignmentA, + AlignmentB +> { + + using ImplicitGemmBase = typename DefaultDeconv3d< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm, + conv::StrideSupport::kStrided + >::Kernel; + + // Define epilogue + using Epilogue = typename cutlass::conv::kernel::detail::DefaultConvEpilogueWithBroadcastSimtStridedDgrad< + ArchTag, + typename 
ImplicitGemmBase::Epilogue::Shape, + typename ImplicitGemmBase::Epilogue::WarpMmaOperator, + ElementC, + typename EpilogueOutputOp::ElementT, + typename EpilogueOutputOp::ElementVector, + EpilogueOutputOp, + ImplicitGemmBase::Epilogue::kElementsPerAccess + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionWithFusedEpilogue< + typename ImplicitGemmBase::Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDeconv, + Conv3dProblemSize + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/conv/kernel/default_depthwise_fprop.h b/include/cutlass/conv/kernel/default_depthwise_fprop.h index cbe84b1e78..aa4f2c359c 100644 --- a/include/cutlass/conv/kernel/default_depthwise_fprop.h +++ b/include/cutlass/conv/kernel/default_depthwise_fprop.h @@ -80,7 +80,7 @@ template < int Stages, typename MathOperatorTag, conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kAnalytic, - conv::StrideSupport StrideSupport = StrideSupport::kStrided, + conv::StrideSupport StrideSupport = StrideSupport::kUnity, /// Access granularity of A matrix in units of elements int AlignmentA = 128 / cutlass::sizeof_bits::value, /// Access granularity of B matrix in units of elements @@ -109,7 +109,7 @@ template < int Stages, typename MathOperatorTag, conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kAnalytic, - conv::StrideSupport StrideSupport = StrideSupport::kStrided, + conv::StrideSupport StrideSupport = StrideSupport::kUnity, // MatrixShape typename StrideShape = cutlass::MatrixShape<-1, -1>, // MatrixShape< Height, Width> diff --git a/include/cutlass/conv/kernel/direct_convolution.h b/include/cutlass/conv/kernel/direct_convolution.h index a3468bda6e..5e4299564f 100644 --- a/include/cutlass/conv/kernel/direct_convolution.h +++ b/include/cutlass/conv/kernel/direct_convolution.h @@ -155,7 +155,7 @@ struct DirectConvolutionParams { swizzle_log_tile = threadblock_swizzle.get_log_tile(grid_tiled_shape); // Dynamic SMEM usage because stride and dilation are runtime params. - smem_size_ = (iterator_A.activation_size * kStages + iterator_B.filter_size); + smem_size_ = (max(iterator_A.activation_size, int(sizeof(typename Epilogue::SharedStorage))) * kStages + iterator_B.filter_size); } CUTLASS_HOST_DEVICE diff --git a/include/cutlass/conv/kernel/implicit_gemm_convolution.h b/include/cutlass/conv/kernel/implicit_gemm_convolution.h index f8cee81073..c4de265e8d 100644 --- a/include/cutlass/conv/kernel/implicit_gemm_convolution.h +++ b/include/cutlass/conv/kernel/implicit_gemm_convolution.h @@ -61,7 +61,7 @@ template < typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate typename Epilogue_, ///! Epilogue typename ThreadblockSwizzle_, ///! Threadblock swizzling function - conv::Operator ConvOperator, ///! Convolutional operator (Fprop, Dgrad, Wgrad) + conv::Operator ConvOperator, ///! Convolutional operator (Fprop, Dgrad, Wgrad, Deconv) typename ConvProblemSize_ = Conv2dProblemSize, ///! Convolutional operator on 2D or 3D problem conv::GroupMode GroupMode_ = conv::GroupMode::kNone ///! 
Group mode > @@ -233,9 +233,9 @@ struct ImplicitGemmConvolution { ptr_A(args.ref_A.data()), iterator_B(args.problem_size, args.ref_B.layout()), ptr_B(args.ref_B.data()), - iterator_C(ConvOutputIteratorParameter::layout(args.ref_C)), + iterator_C(ConvOutputIteratorParameter::layout(args.ref_C), args.problem_size), ptr_C(args.ref_C.data()), - iterator_D(ConvOutputIteratorParameter::layout(args.ref_D)), + iterator_D(ConvOutputIteratorParameter::layout(args.ref_D), args.problem_size), ptr_D(args.ref_D.data()), output_op(args.output_op), semaphore(semaphore), @@ -397,7 +397,6 @@ struct ImplicitGemmConvolution { threadblock_offset ); - // Construct the epilogue Epilogue epilogue( shared_storage.epilogue, diff --git a/include/cutlass/conv/kernel/implicit_gemm_convolution_with_fused_epilogue.h b/include/cutlass/conv/kernel/implicit_gemm_convolution_with_fused_epilogue.h index ded39ffa84..c768a2966e 100644 --- a/include/cutlass/conv/kernel/implicit_gemm_convolution_with_fused_epilogue.h +++ b/include/cutlass/conv/kernel/implicit_gemm_convolution_with_fused_epilogue.h @@ -61,7 +61,7 @@ template < typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate typename Epilogue_, ///! Epilogue typename ThreadblockSwizzle_, ///! Threadblock swizzling function - conv::Operator ConvOperator, ///! Convolutional operator (Fprop, Dgrad, Wgrad) + conv::Operator ConvOperator, ///! Convolutional operator (Fprop, Dgrad, Wgrad, Deconv) typename ConvProblemSize_ = Conv2dProblemSize ///! Convolutional operator on 2D or 3D problem > struct ImplicitGemmConvolutionWithFusedEpilogue { diff --git a/include/cutlass/conv/kernel/sm90_implicit_gemm_tma_warpspecialized.hpp b/include/cutlass/conv/kernel/sm90_implicit_gemm_tma_warpspecialized.hpp index c6996f15be..43c6d5959b 100644 --- a/include/cutlass/conv/kernel/sm90_implicit_gemm_tma_warpspecialized.hpp +++ b/include/cutlass/conv/kernel/sm90_implicit_gemm_tma_warpspecialized.hpp @@ -81,7 +81,6 @@ class ConvUniversal< using MainloopParams = typename CollectiveMainloop::Params; static constexpr int NumSpatialDimensions = CollectiveMainloop::NumSpatialDimensions; static_assert(ArchTag::kMinComputeCapability >= 90); - // Epilogue derived types using CollectiveEpilogue = CollectiveEpilogue_; using ElementC = typename CollectiveEpilogue::ElementC; diff --git a/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h b/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h index 968c91e28e..1725db5af5 100644 --- a/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h +++ b/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h @@ -67,7 +67,8 @@ template < typename Layout_, typename ThreadMap_, typename AccessType_ = cutlass::AlignedArray, - conv::GroupMode GroupMode_ = conv::GroupMode::kNone + conv::GroupMode GroupMode_ = conv::GroupMode::kNone, + bool IsDeconv_ = false > class Conv2dFpropFilterTileAccessIteratorAnalytic { public: @@ -85,6 +86,7 @@ class Conv2dFpropFilterTileAccessIteratorAnalytic { using TensorCoord = typename Layout::TensorCoord; using Index = typename Layout::Index; using LongIndex = typename Layout::LongIndex; + static bool const IsDeconv = IsDeconv_; static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kAnalytic; static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; static int const kConvDim = 2; @@ -95,7 +97,7 @@ class Conv2dFpropFilterTileAccessIteratorAnalytic { 
static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), "Vectors implied by the thread map must be divisible by the access type."); - + // // Simplifying assertions // @@ -152,13 +154,16 @@ class Conv2dFpropFilterTileAccessIteratorAnalytic { filter_c_ = threadblock_offset.row() + thread_coord.contiguous(); + auto input_channels = (IsDeconv ? problem_size_.K : problem_size_.C); + auto output_channels = (IsDeconv ? problem_size_.C : problem_size_.K); + if (kGroupMode != conv::GroupMode::kNone) { filter_c_init_ = filter_c_; if (kGroupMode == conv::GroupMode::kDepthwise){ channels_per_group_ = 1; crs_per_group_ = problem_size_.S * problem_size_.R; } else { - channels_per_group_ = problem_size_.C / problem_size_.groups; + channels_per_group_ = input_channels / problem_size_.groups; crs_per_group_ = problem_size_.S * problem_size_.R * ((channels_per_group_ + Shape::kRow - 1) / Shape::kRow); } } @@ -167,7 +172,7 @@ class Conv2dFpropFilterTileAccessIteratorAnalytic { for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { offset_k_[s] = threadblock_offset.column() + thread_coord.strided() + s * ThreadMap::Delta::kStrided; if (kGroupMode != conv::GroupMode::kNone && kGroupMode != conv::GroupMode::kDepthwise) { - group_idx_offset_k_[s] = (thread_coord.strided() + s * ThreadMap::Delta::kStrided) / (problem_size_.K / problem_size_.groups); + group_idx_offset_k_[s] = (thread_coord.strided() + s * ThreadMap::Delta::kStrided) / (output_channels / problem_size_.groups); } } @@ -241,12 +246,15 @@ class Conv2dFpropFilterTileAccessIteratorAnalytic { TensorCoord coord = at(); + auto input_channels = (IsDeconv ? problem_size_.K : problem_size_.C); + auto output_channels = (IsDeconv ? problem_size_.C : problem_size_.K); + if (kGroupMode == conv::GroupMode::kNone) { - return coord.n() < problem_size_.K && coord.c() < problem_size_.C; + return coord.n() < output_channels && coord.c() < input_channels; } else if (kGroupMode == conv::GroupMode::kDepthwise) { - return coord.n() < problem_size_.K && coord.c() < 1; // channels_per_group_ is always equal to ONE. + return coord.n() < output_channels && coord.c() < 1; // channels_per_group_ is always equal to ONE. } else { - return coord.n() < problem_size_.K && coord.c() < channels_per_group_ && + return coord.n() < output_channels && coord.c() < channels_per_group_ && group_idx_offset_c_ == group_idx_offset_k_[iteration_strided_]; } } @@ -289,19 +297,22 @@ class Conv2dFpropFilterTileAccessIteratorAnalytic { CUTLASS_HOST_DEVICE static Status can_implement(Conv2dProblemSize const &problem_size) { + auto input_channels = (IsDeconv ? problem_size.K : problem_size.C); + auto output_channels = (IsDeconv ? 
problem_size.C : problem_size.K); + // check alignment constraint on iterator's contiguous dimension - if ((problem_size.C / problem_size.groups) % AccessType::kElements) { + if ((input_channels / problem_size.groups) % AccessType::kElements) { return Status::kErrorInvalidProblem; } if (platform::is_same>::value) { - if (problem_size.K % 32) { + if (output_channels % 32) { return Status::kErrorInvalidProblem; } } if (platform::is_same>::value) { - if (problem_size.K % 64) { + if (output_channels % 64) { return Status::kErrorInvalidProblem; } } diff --git a/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h b/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h index 3fc640b1be..4c2343c32c 100644 --- a/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h +++ b/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h @@ -67,7 +67,8 @@ template < typename Element_, typename Layout_, typename ThreadMap_, - typename AccessType_ = cutlass::AlignedArray + typename AccessType_ = cutlass::AlignedArray, + bool IsDeconv_ = false > class Conv2dFpropFilterTileAccessIteratorOptimized{ public: @@ -85,6 +86,7 @@ class Conv2dFpropFilterTileAccessIteratorOptimized{ using TensorCoord = typename Layout::TensorCoord; using Index = typename Layout::Index; using LongIndex = typename Layout::LongIndex; + static bool const IsDeconv = IsDeconv_; static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized; static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; static int const kConvDim = 2; @@ -176,11 +178,11 @@ class Conv2dFpropFilterTileAccessIteratorOptimized{ filter_c_ = threadblock_offset.row() + thread_coord.contiguous(); Index column = threadblock_offset.column() + thread_coord.strided(); - channels_per_group_ = problem_size_.C / problem_size_.groups; + channels_per_group_ = (IsDeconv ? problem_size_.K : problem_size_.C) / problem_size_.groups; CUTLASS_PRAGMA_UNROLL for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { - uint32_t pred = ((column + s * ThreadMap::Delta::kStrided < problem_size_.K) ? 1u : 0); + uint32_t pred = ((column + s * ThreadMap::Delta::kStrided < (IsDeconv ? problem_size_.C : problem_size_.K)) ? 1u : 0); CUTLASS_PRAGMA_UNROLL for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) { @@ -287,19 +289,22 @@ class Conv2dFpropFilterTileAccessIteratorOptimized{ CUTLASS_HOST_DEVICE static Status can_implement(Conv2dProblemSize const &problem_size) { + auto input_channels = (IsDeconv ? problem_size.K : problem_size.C); + auto output_channels = (IsDeconv ? 
problem_size.C : problem_size.K); + // check alignment constraint on iterator's contiguous dimension - if ((problem_size.C / problem_size.groups) % AccessType::kElements) { + if ((input_channels / problem_size.groups) % AccessType::kElements) { return Status::kErrorInvalidProblem; } if (platform::is_same>::value) { - if (problem_size.K % 32) { + if (output_channels % 32) { return Status::kErrorInvalidProblem; } } if (platform::is_same>::value) { - if (problem_size.K % 64) { + if (output_channels % 64) { return Status::kErrorInvalidProblem; } } diff --git a/include/cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_analytic.h b/include/cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_analytic.h index 5ef1ab5f05..85dd37ffdb 100644 --- a/include/cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_analytic.h +++ b/include/cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_analytic.h @@ -64,7 +64,8 @@ namespace threadblock { template < typename Shape_, typename Element_, - typename ThreadMap_ + typename ThreadMap_, + bool IsDeconv_ = false > class Conv3dFpropFilterTileAccessIteratorAnalytic { public: @@ -82,6 +83,7 @@ class Conv3dFpropFilterTileAccessIteratorAnalytic { using TensorCoord = typename Layout::TensorCoord; using Index = typename Layout::Index; using LongIndex = typename Layout::LongIndex; + static bool const IsDeconv = IsDeconv_; static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kAnalytic; static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; static int const kConvDim = 3; @@ -198,8 +200,11 @@ class Conv3dFpropFilterTileAccessIteratorAnalytic { TensorCoord coord = at(); - return coord.n() < problem_size_.K && - coord.c() < problem_size_.C; + auto input_channels = (IsDeconv ? problem_size_.K : problem_size_.C); + auto output_channels = (IsDeconv ? problem_size_.C : problem_size_.K); + + return coord.n() < output_channels && + coord.c() < input_channels; } /// Returns a pointer to the vector starting at the current coordinate @@ -234,8 +239,10 @@ class Conv3dFpropFilterTileAccessIteratorAnalytic { CUTLASS_HOST_DEVICE static Status can_implement(ConvProblemSize const &problem_size) { + auto input_channels = (IsDeconv ? problem_size.K : problem_size.C); + auto output_channels = (IsDeconv ? 
problem_size.C : problem_size.K); // check alignment constraint on iterator's contiguous dimension - if (problem_size.K % (128/sizeof_bits::value)) { + if (input_channels % (128/sizeof_bits::value)) { return Status::kErrorInvalidProblem; } return Status::kSuccess; diff --git a/include/cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_optimized.h b/include/cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_optimized.h index eb51ceb650..ac49cf0781 100644 --- a/include/cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_optimized.h +++ b/include/cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_optimized.h @@ -66,7 +66,8 @@ template < typename Shape_, typename Element_, typename Layout_, - typename ThreadMap_ + typename ThreadMap_, + bool IsDeconv_ = false > class Conv3dFpropFilterTileAccessIteratorOptimized{ public: @@ -84,6 +85,7 @@ class Conv3dFpropFilterTileAccessIteratorOptimized{ using TensorCoord = typename Layout::TensorCoord; using Index = typename Layout::Index; using LongIndex = typename Layout::LongIndex; + static bool const IsDeconv = IsDeconv_; static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized; static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; static int const kConvDim = 3; @@ -172,11 +174,11 @@ class Conv3dFpropFilterTileAccessIteratorOptimized{ CUTLASS_PRAGMA_UNROLL for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { - uint32_t pred = ((column + s * ThreadMap::Delta::kStrided < problem_size_.K) ? 1u : 0); + uint32_t pred = ((column + s * ThreadMap::Delta::kStrided < (IsDeconv ? problem_size_.C : problem_size_.K)) ? 1u : 0); predicates_ |= (pred << s); } - if (filter_c_ >= problem_size.C) { + if (filter_c_ >= (IsDeconv ? problem_size_.K : problem_size_.C)) { predicates_ = 0u; } @@ -214,7 +216,7 @@ class Conv3dFpropFilterTileAccessIteratorOptimized{ filter_c_ += params_.filter_c_delta; } - if (filter_c_ >= problem_size_.C) { + if (filter_c_ >= (IsDeconv ? problem_size_.K : problem_size_.C)) { predicates_ = 0; } @@ -259,8 +261,10 @@ class Conv3dFpropFilterTileAccessIteratorOptimized{ CUTLASS_HOST_DEVICE static Status can_implement(Conv3dProblemSize const &problem_size) { + auto input_channels = (IsDeconv ? problem_size.K : problem_size.C); + // check alignment constraint on iterator's contiguous dimension - if (problem_size.C % (128/sizeof_bits::value)) { + if (input_channels % (128/sizeof_bits::value)) { return Status::kErrorInvalidProblem; } diff --git a/include/cutlass/coord.h b/include/cutlass/coord.h index d8c1d41a79..d778046c2f 100644 --- a/include/cutlass/coord.h +++ b/include/cutlass/coord.h @@ -32,16 +32,6 @@ \brief A Coord is a coordinate of arbitrary rank into a tensor or matrix */ -/* - Note: CUTLASS 3x increases the host compiler requirements to C++17. However, certain - existing integrations of CUTLASS require C++11 host compilers. - - Until this requirement can be lifted, certain headers with this annotation are required - to be remain consistent with C++11 syntax. - - C++11 compatibility is enforced by `cutlass_test_unit_core_cpp11`. -*/ - #pragma once #if defined(__CUDACC_RTC__) diff --git a/include/cutlass/core_io.h b/include/cutlass/core_io.h index e7c96d05b3..40ae22246a 100644 --- a/include/cutlass/core_io.h +++ b/include/cutlass/core_io.h @@ -31,15 +31,6 @@ /*! \file \brief Helpers for printing cutlass/core objects */ -/* - Note: CUTLASS 3x increases the host compiler requirements to C++17. 
However, certain - existing integrations of CUTLASS require C++11 host compilers. - - Until this requirement can be lifted, certain headers with this annotation are required - to be remain consistent with C++11 syntax. - - C++11 compatibility is enforced by `cutlass_test_unit_core_cpp11`. -*/ #pragma once #include diff --git a/include/cutlass/cutlass.h b/include/cutlass/cutlass.h index 88d718e013..f396528307 100644 --- a/include/cutlass/cutlass.h +++ b/include/cutlass/cutlass.h @@ -33,16 +33,6 @@ \brief Basic include for CUTLASS. */ -/* - Note: CUTLASS 3x increases the host compiler requirements to C++17. However, certain - existing integrations of CUTLASS require C++11 host compilers. - - Until this requirement can be lifted, certain headers with this annotation are required - to be remain consistent with C++11 syntax. - - C++11 compatibility is enforced by `cutlass_test_unit_core_cpp11`. -*/ - #pragma once #include "cutlass/detail/helper_macros.hpp" diff --git a/include/cutlass/epilogue/collective/collective_builder.hpp b/include/cutlass/epilogue/collective/collective_builder.hpp index 10aad81de0..a14696b2f8 100644 --- a/include/cutlass/epilogue/collective/collective_builder.hpp +++ b/include/cutlass/epilogue/collective/collective_builder.hpp @@ -61,7 +61,7 @@ template < class ElementD, class GmemLayoutTagD, int AlignmentD, - class Schedule, + class EpilogueScheduleType, class FusionOpOrCallbacks = cutlass::epilogue::fusion::LinearCombination, class Enable = void > diff --git a/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized.hpp b/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized.hpp index f23a74e01f..9eb4c4b123 100644 --- a/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized.hpp @@ -723,7 +723,8 @@ class CollectiveEpilogue< } // Vectorized fragment loop with visitor callback entry point - int r2s_v = epi_n * size(tRS_rD_frg); + int epi_n_in_mma = epi_n % (mma_tile_n / epi_tile_n); + int r2s_v = epi_n_in_mma * size(tRS_rD_frg); CUTLASS_PRAGMA_UNROLL for (int epi_v = 0; epi_v < size(tRS_rD_frg); ++epi_v) { tRS_rD_frg(epi_v) = cst_callbacks.visit(tRS_rAcc_frg_mn(r2s_v + epi_v), epi_v, epi_m, epi_n); diff --git a/include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp b/include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp index e3160fa132..b8cac85634 100644 --- a/include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp @@ -634,17 +634,19 @@ struct Sm90AuxLoad< return EmptyProducerLoadCallbacks{}; } - template + template struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { CUTLASS_DEVICE - ConsumerStoreCallbacks(RTensor&& tC_rAux_, GTensor&& tC_gAux_, ResidueMN residue_mn_, Params const& params_) + ConsumerStoreCallbacks(RTensor&& tC_rAux_, GTensor&& tC_gAux_, CTensor tC_cAux_, ResidueMN residue_mn_, Params const& params_) : tC_rAux(cute::forward(tC_rAux_)), tC_gAux(cute::forward(tC_gAux_)), + tC_cAux(tC_cAux_), residue_mn(residue_mn_), params(params_) {} RTensor tC_rAux; // (CPY,CPY_M,CPY_N,{EPI_M,EPI_N}) GTensor tC_gAux; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + CTensor tC_cAux; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) ResidueMN residue_mn; Params const& params; @@ -657,8 +659,18 @@ struct Sm90AuxLoad< } } - if (elem_less(repeat_like(residue_mn, _0{}), residue_mn)) { // (partially) in-bounds CTA tile - 
copy_aligned(tC_gAux, tC_rAux); + constexpr int V = cute::min(Alignment, decltype(max_common_vector(tC_rAux, tC_gAux))::value); + if constexpr (V > 0) { + using VecType = uint_bit_t; + Tensor tC_gAux_vec = recast(tC_gAux); + Tensor tC_rAux_vec = recast(tC_rAux); + Tensor tC_cAux_vec = tC_cAux.compose(make_layout(Int{}, Int{})); // only works if vector is logically sequential + auto predicate_fn = [&] (auto&&... coords) { return elem_less(tC_cAux_vec(coords...), residue_mn); }; + copy_if(FunctionPredTensor(predicate_fn), tC_gAux_vec, tC_rAux_vec); + } + else { + auto predicate_fn = [&] (auto&&... coords) { return elem_less(tC_cAux(coords...), residue_mn); }; + copy_if(FunctionPredTensor(predicate_fn), tC_gAux, tC_rAux); } } } @@ -672,9 +684,8 @@ struct Sm90AuxLoad< } } - if (elem_less(repeat_like(residue_mn, _0{}), residue_mn)) { - copy_aligned(tC_gAux(_,_,_,epi_m,epi_n), tC_rAux); - } + auto predicate_fn = [&] (auto&&... coords) { return elem_less(tC_cAux(_,_,_,epi_m,epi_n)(coords...), residue_mn); }; + copy_if(FunctionPredTensor(predicate_fn), tC_gAux(_,_,_,epi_m,epi_n), tC_rAux); } } @@ -723,8 +734,8 @@ struct Sm90AuxLoad< } } - return ConsumerStoreCallbacks( - cute::move(tC_rAux), cute::move(tC_gAux), args.residue_mn, params); + return ConsumerStoreCallbacks( + cute::move(tC_rAux), cute::move(tC_gAux), args.tCcD, args.residue_mn, params); } }; diff --git a/include/cutlass/epilogue/thread/activation.h b/include/cutlass/epilogue/thread/activation.h index 1a226a7575..c37f2b9ab5 100644 --- a/include/cutlass/epilogue/thread/activation.h +++ b/include/cutlass/epilogue/thread/activation.h @@ -172,7 +172,7 @@ struct ReLu> { template struct Clamp { struct Arguments { - T lower_bound = CUTLASS_STL_NAMESPACE::numeric_limits::min(); + T lower_bound = CUTLASS_STL_NAMESPACE::numeric_limits::lowest(); T upper_bound = CUTLASS_STL_NAMESPACE::numeric_limits::max(); }; diff --git a/include/cutlass/epilogue/thread/linear_combination_bias_elementwise.h b/include/cutlass/epilogue/thread/linear_combination_bias_elementwise.h index 192bc6d157..7456ae8df4 100644 --- a/include/cutlass/epilogue/thread/linear_combination_bias_elementwise.h +++ b/include/cutlass/epilogue/thread/linear_combination_bias_elementwise.h @@ -63,8 +63,52 @@ template struct kIsHeavy_member_or_false::type> { static constexpr bool value = Op::kIsHeavy; }; + } // namespace (anonymous) +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +struct EmptyArguments {}; + +template +struct ElementwiseOpDispatcher { + using Arguments = EmptyArguments; + + T op; + + CUTLASS_HOST_DEVICE + ElementwiseOpDispatcher(Arguments) {} + + template + CUTLASS_HOST_DEVICE + ValueType operator()(ValueType value) { + return op(value); + } +}; + +template +struct ElementwiseOpDispatcher> { + using Arguments = typename T::Arguments; + + Arguments args; + T op; + + CUTLASS_HOST_DEVICE + ElementwiseOpDispatcher(Arguments args_):args(args_) {} + + template + CUTLASS_HOST_DEVICE + ValueType operator()(ValueType value) { + return op(value, args); + } +}; + +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + /// This base class is meant to define the concept required of the /// EpilogueWithBroadcast::OutputOp template < @@ -95,9 +139,13 @@ class LinearCombinationBiasElementwise { using ElementwiseOp = ElementwiseOp_; using BinaryOp = BinaryOp_; + using ElementwiseOpDispatcher = detail::ElementwiseOpDispatcher; + using ElementwiseArguments = typename 
ElementwiseOpDispatcher::Arguments; + // Indicates that this epilogue applies only one binary operation static bool const kIsSingleSource = true; + using FragmentAccumulator = Array; using FragmentCompute = Array; using FragmentC = Array; @@ -127,6 +175,7 @@ class LinearCombinationBiasElementwise { ElementCompute beta; ///< scales source tensor ElementCompute const *alpha_ptr; ///< pointer to accumulator scalar - if not null, loads it from memory ElementCompute const *beta_ptr; ///< pointer to source scalar - if not null, loads it from memory + ElementwiseArguments elementwise; ///< Arguments for elementwise operation // // Methods @@ -142,8 +191,9 @@ class LinearCombinationBiasElementwise { CUTLASS_HOST_DEVICE Params( ElementCompute alpha, - ElementCompute beta - ): alpha(alpha), beta(beta), alpha_ptr(nullptr), beta_ptr(nullptr) { + ElementCompute beta, + ElementwiseArguments elementwise_ = ElementwiseArguments{} + ): alpha(alpha), beta(beta), alpha_ptr(nullptr), beta_ptr(nullptr), elementwise(elementwise_) { } @@ -157,8 +207,9 @@ class LinearCombinationBiasElementwise { CUTLASS_HOST_DEVICE Params( ElementCompute const *alpha_ptr, - ElementCompute const *beta_ptr - ): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr) { + ElementCompute const *beta_ptr, + ElementwiseArguments elementwise_ = ElementwiseArguments{} + ): alpha(0), beta(0), alpha_ptr(alpha_ptr), beta_ptr(beta_ptr), elementwise(elementwise_) { } @@ -178,6 +229,7 @@ class LinearCombinationBiasElementwise { ElementCompute alpha_; ElementCompute beta_; + ElementwiseArguments const &elementwise_; bool skip_elementwise_; public: @@ -188,7 +240,7 @@ class LinearCombinationBiasElementwise { /// Constructor from Params CUTLASS_HOST_DEVICE - LinearCombinationBiasElementwise(Params const ¶ms) { + LinearCombinationBiasElementwise(Params const ¶ms): elementwise_(params.elementwise) { alpha_ = (params.alpha_ptr ? *params.alpha_ptr : params.alpha); beta_ = (params.beta_ptr ? 
*params.beta_ptr : params.beta); @@ -290,7 +342,7 @@ class LinearCombinationBiasElementwise { FragmentC const &frag_C, FragmentCompute const &V) const { - ElementwiseOp elementwise_op; + ElementwiseOpDispatcher elementwise_op(elementwise_); BinaryOp binary_op; FragmentCompute tmp_Accum = NumericArrayConverter()(AB); @@ -322,7 +374,7 @@ class LinearCombinationBiasElementwise { FragmentAccumulator const &AB, FragmentCompute const &V) const { - ElementwiseOp elementwise_op; + ElementwiseOpDispatcher elementwise_op(elementwise_); BinaryOp binary_op; FragmentCompute tmp_Accum = NumericArrayConverter()(AB); diff --git a/include/cutlass/epilogue/thread/linear_combination_clamp.h b/include/cutlass/epilogue/thread/linear_combination_clamp.h index fac1a4af85..5e1c847d22 100644 --- a/include/cutlass/epilogue/thread/linear_combination_clamp.h +++ b/include/cutlass/epilogue/thread/linear_combination_clamp.h @@ -432,17 +432,12 @@ class LinearCombinationClamp { intermediate = mul_add_accumulator(alpha_, converted_accumulator, intermediate); // D = alpha * Accum + X } - // Convert floats back to INT - FragmentAccumulator scaled_accumulator; - - NumericArrayConverter compute_converter; - - scaled_accumulator = compute_converter(intermediate); - - // Convert to destination numeric type - NumericArrayConverter destination_converter; + // + // Convert float => ElementOutput_ with clamping + // + NumericArrayConverter destination_converter; - return destination_converter(scaled_accumulator); + return destination_converter(intermediate); } /// Computes linear scaling: D = alpha * accumulator @@ -466,17 +461,12 @@ class LinearCombinationClamp { intermediate = mul_add_accumulator(alpha_, converted_accumulator); // D = alpha * Accum } - // Convert floats back to INT - FragmentAccumulator scaled_accumulator; - - NumericArrayConverter compute_converter; - - scaled_accumulator = compute_converter(intermediate); - - // Convert to destination numeric type - NumericArrayConverter destination_converter; + // + // Convert float => ElementOutput_ with clamping + // + NumericArrayConverter destination_converter; - return destination_converter(scaled_accumulator); + return destination_converter(intermediate); } }; diff --git a/include/cutlass/epilogue/thread/linear_combination_planar_complex.h b/include/cutlass/epilogue/thread/linear_combination_planar_complex.h index a6e2a3dcf4..ff32f13b0b 100644 --- a/include/cutlass/epilogue/thread/linear_combination_planar_complex.h +++ b/include/cutlass/epilogue/thread/linear_combination_planar_complex.h @@ -156,23 +156,24 @@ class LinearCombinationPlanarComplex { NumericArrayConverter source_converter; NumericArrayConverter accumulator_converter; - ComputeFragment converted_source( + ComputeFragment converted_source{ source_converter(source.real), - source_converter(source.imag)); + source_converter(source.imag)}; - ComputeFragment converted_accumulator( + ComputeFragment converted_accumulator{ accumulator_converter(accumulator.real), - accumulator_converter(accumulator.imag)); - - // Perform binary operations - ComputeFragment intermediate; + accumulator_converter(accumulator.imag)}; multiplies > mul_op; multiply_add > mul_add_op; + // Perform binary operations + // complex multiply: I = beta * C - intermediate.real = mul_op(beta_.real(), converted_source.real); - intermediate.imag = mul_op(beta_.real(), converted_source.imag); + ComputeFragment intermediate { + mul_op(beta_.real(), converted_source.real), + mul_op(beta_.real(), converted_source.imag) + }; intermediate.real = 
mul_add_op(-beta_.imag(), converted_source.imag, intermediate.real); intermediate.imag = mul_add_op( beta_.imag(), converted_source.real, intermediate.imag); @@ -187,9 +188,9 @@ class LinearCombinationPlanarComplex { // Convert to destination numeric type NumericArrayConverter destination_converter; - return FragmentOutput( + return FragmentOutput{ destination_converter(intermediate.real), - destination_converter(intermediate.imag)); + destination_converter(intermediate.imag)}; } /// Computes linear scaling: D = alpha * accumulator + beta * source @@ -200,19 +201,19 @@ class LinearCombinationPlanarComplex { // Convert source to interal compute numeric type NumericArrayConverter accumulator_converter; - ComputeFragment converted_accumulator( + ComputeFragment converted_accumulator{ accumulator_converter(accumulator.real), - accumulator_converter(accumulator.imag)); + accumulator_converter(accumulator.imag)}; // Perform binary operations - ComputeFragment intermediate; - multiplies > mul_op; multiply_add > mul_add_op; // complex multiply-add: I = alpha * AB + I - intermediate.real = mul_op(alpha_.real(), converted_accumulator.real); - intermediate.imag = mul_op(alpha_.real(), converted_accumulator.imag); + ComputeFragment intermediate { + mul_op(alpha_.real(), converted_accumulator.real), + mul_op(alpha_.real(), converted_accumulator.imag) + }; intermediate.real = mul_add_op(-alpha_.imag(), converted_accumulator.imag, intermediate.real); intermediate.imag = mul_add_op( alpha_.imag(), converted_accumulator.real, intermediate.imag); @@ -220,9 +221,9 @@ class LinearCombinationPlanarComplex { // Convert to destination numeric type NumericArrayConverter destination_converter; - return FragmentOutput( + return FragmentOutput{ destination_converter(intermediate.real), - destination_converter(intermediate.imag)); + destination_converter(intermediate.imag)}; } }; diff --git a/include/cutlass/epilogue/threadblock/default_epilogue_simt.h b/include/cutlass/epilogue/threadblock/default_epilogue_simt.h index 61892fd2f9..f3119fa407 100644 --- a/include/cutlass/epilogue/threadblock/default_epilogue_simt.h +++ b/include/cutlass/epilogue/threadblock/default_epilogue_simt.h @@ -64,11 +64,12 @@ #include "cutlass/transform/pitch_linear_thread_map.h" #include "cutlass/epilogue/threadblock/predicated_tile_iterator.h" +#include "cutlass/epilogue/threadblock/predicated_tile_iterator_conv.h" #include "cutlass/epilogue/threadblock/predicated_tile_iterator_strided_dgrad.h" #include "cutlass/epilogue/threadblock/predicated_tile_iterator_affine.h" #include "cutlass/epilogue/threadblock/predicated_tile_iterator_direct_conv.h" #include "cutlass/epilogue/threadblock/shared_load_iterator.h" -#include "cutlass/epilogue/threadblock/shared_load_iterator_pitch_liner.h" +#include "cutlass/epilogue/threadblock/shared_load_iterator_pitch_linear.h" #include "cutlass/epilogue/threadblock/epilogue.h" #include "cutlass/epilogue/threadblock/epilogue_depthwise.h" @@ -89,7 +90,9 @@ template < typename OutputOp_, int ElementsPerAccess, bool ScatterD = false, - typename PermuteDLayout = layout::NoPermute + typename PermuteDLayout = layout::NoPermute, + conv::StrideSupport StrideSupport = conv::StrideSupport::kUnity, + int Rank = 4 > struct DefaultEpilogueSimt { @@ -102,6 +105,8 @@ struct DefaultEpilogueSimt { using ElementOutput = typename OutputOp::ElementOutput; using LayoutC = typename WarpMmaSimt::LayoutC; using ElementAccumulator = typename WarpMmaSimt::ElementC; + static conv::StrideSupport const kStrideSupport = StrideSupport; + 
static int const kRank = Rank; // // Thread map @@ -116,13 +121,29 @@ struct DefaultEpilogueSimt { kElementsPerAccess >::Type; - using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator< + static bool const UseCUDAStore = platform::is_same::value; + + using PackedOutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator< OutputTileThreadMap, ElementOutput, ScatterD, - PermuteDLayout + PermuteDLayout, + UseCUDAStore >; + using StridedOutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorConv< + OutputTileThreadMap, + ElementOutput, + ScatterD, + PermuteDLayout, + UseCUDAStore, + kRank + >; + + using OutputTileIterator = typename platform::conditional::type; + using AccumulatorFragmentIterator = cutlass::epilogue::warp::FragmentIteratorSimt< typename WarpMmaSimt::Shape, typename WarpMmaSimt::ThreadMma, @@ -389,7 +410,7 @@ struct DefaultDirectConvEpilogueSimt { typename WarpMmaSimt::Policy >; - using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIteratorPitchLiner< + using SharedLoadIterator = cutlass::epilogue::threadblock::SharedLoadIteratorPitchLinear< OutputTileThreadMap, ElementAccumulator >; diff --git a/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op.h b/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op.h index f3b006a16a..1692cc3093 100644 --- a/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op.h +++ b/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op.h @@ -66,6 +66,7 @@ #include "cutlass/epilogue/warp/tile_iterator_tensor_op_mixed.h" #include "cutlass/epilogue/threadblock/default_thread_map_tensor_op.h" #include "cutlass/epilogue/threadblock/predicated_tile_iterator.h" +#include "cutlass/epilogue/threadblock/predicated_tile_iterator_conv.h" #include "cutlass/epilogue/threadblock/predicated_tile_iterator_strided_dgrad.h" #include "cutlass/epilogue/threadblock/predicated_tile_iterator_affine.h" #include "cutlass/epilogue/threadblock/shared_load_iterator.h" @@ -480,7 +481,9 @@ template < typename OutputOp_, int ElementsPerAccess, bool ScatterD = false, - typename PermuteDLayout = layout::NoPermute + typename PermuteDLayout = layout::NoPermute, + conv::StrideSupport StrideSupport = conv::StrideSupport::kUnity, + int Rank = 4 > struct DefaultEpilogueTensorOp { @@ -493,6 +496,8 @@ struct DefaultEpilogueTensorOp { using ElementOutput = typename OutputOp::ElementOutput; using LayoutC = typename WarpMmaTensorOp::LayoutC; using ElementAccumulator = typename WarpMmaTensorOp::ElementC; + static conv::StrideSupport const kStrideSupport = StrideSupport; + static int const kRank = Rank; // // Thread map @@ -508,7 +513,7 @@ struct DefaultEpilogueTensorOp { static bool const UseCUDAStore = platform::is_same::value; - using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator< + using PackedOutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator< OutputTileThreadMap, ElementOutput, ScatterD, @@ -516,6 +521,19 @@ struct DefaultEpilogueTensorOp { UseCUDAStore >; + using StridedOutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorConv< + OutputTileThreadMap, + ElementOutput, + ScatterD, + PermuteDLayout, + UseCUDAStore, + kRank + >; + + using OutputTileIterator = typename platform::conditional::type; + using AccumulatorFragmentIterator = typename platform::conditional::value, cutlass::epilogue::warp::FragmentIteratorComplexTensorOp< typename WarpMmaTensorOp::Shape, diff --git 
a/include/cutlass/epilogue/threadblock/default_epilogue_with_broadcast.h b/include/cutlass/epilogue/threadblock/default_epilogue_with_broadcast.h index 550354b3d6..d21382b41f 100644 --- a/include/cutlass/epilogue/threadblock/default_epilogue_with_broadcast.h +++ b/include/cutlass/epilogue/threadblock/default_epilogue_with_broadcast.h @@ -114,7 +114,61 @@ struct DefaultEpilogueWithBroadcastSimt { typename Base::Padding >; }; +//////////////////////////////////////////////////////////////////////////////// + +/// Defines sensible defaults for strided dgrad epilogues for SimtOps. +template < + typename Shape, + typename WarpMmaSimt, + typename ElementOutput, + typename ElementTensor, + typename ElementVector, + typename OutputOp, + int ElementsPerAccess, + bool ScatterD = false, + typename PermuteDLayout = layout::NoPermute +> +struct DefaultEpilogueWithBroadcastSimtStridedDgrad { + /// Use defaults related to the existing epilogue + using Base = DefaultEpilogueSimtStridedDgrad< + Shape, + WarpMmaSimt, + OutputOp, + ElementsPerAccess + >; + + // + // Stores the result z = (y = GEMM(A, B, C), broadcast) + // + using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorStridedDgrad< + typename Base::OutputTileThreadMap, + ElementOutput + >; + + // + // Additional tensor tile iterator - stores t = Elementwise(z) + // + using TensorTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorStridedDgrad< + typename Base::OutputTileThreadMap, + ElementTensor + >; + + /// Define the epilogue + using Epilogue = EpilogueWithBroadcast< + Shape, + WarpMmaSimt, + Base::kPartitionsK, + OutputTileIterator, + TensorTileIterator, + ElementVector, + typename Base::AccumulatorFragmentIterator, + typename Base::WarpTileIterator, + typename Base::SharedLoadIterator, + OutputOp, + typename Base::Padding + >; +}; //////////////////////////////////////////////////////////////////////////////// /// Defines sensible defaults for epilogues for TensorOps. diff --git a/include/cutlass/epilogue/threadblock/output_iterator_parameter.h b/include/cutlass/epilogue/threadblock/output_iterator_parameter.h index 5780623ed1..0f417485e2 100644 --- a/include/cutlass/epilogue/threadblock/output_iterator_parameter.h +++ b/include/cutlass/epilogue/threadblock/output_iterator_parameter.h @@ -70,7 +70,6 @@ struct ConvOutputIteratorParameter { static int const kTensorStrideIdx = (kConvolutionalOperator == conv::Operator::kWgrad ? kWgradStrideIdx : 0); - CUTLASS_HOST_DEVICE static OutputIteratorLayout layout(const TensorRef & ref) { return ref.stride(kTensorStrideIdx); @@ -80,10 +79,59 @@ struct ConvOutputIteratorParameter { static OutputTensorCoord extent(ConvProblemSize problem_size) { return conv::implicit_gemm_problem_size(kConvolutionalOperator, problem_size).mn(); } +}; + +template< + typename TensorRef_, ///! Input tensor to epilogue output iterator + typename ConvProblemSize_ ///! 
Convolutional operator on 2D or 3D problem +> +struct ConvOutputIteratorParameter { + + using TensorLayout = layout::TensorNHWC; + using OutputIteratorLayout = layout::TensorNHWC; + using MappedLayout = layout::RowMajor; + using OutputTensorCoord = typename OutputIteratorLayout::TensorCoord; + using MappedTensorCoord = typename MappedLayout::TensorCoord; + using TensorRef = TensorRef_; + static conv::Operator const kConvolutionalOperator = conv::Operator::kFprop; + using ConvProblemSize = ConvProblemSize_; + + CUTLASS_HOST_DEVICE + static OutputIteratorLayout layout(const TensorRef & ref) { + return ref.stride(); + } + CUTLASS_HOST_DEVICE + static MappedTensorCoord extent(ConvProblemSize problem_size) { + return conv::implicit_gemm_problem_size(kConvolutionalOperator, problem_size).mn(); + } }; +template< + typename TensorRef_, ///! Input tensor to epilogue output iterator + typename ConvProblemSize_ ///! Convolutional operator on 2D or 3D problem +> +struct ConvOutputIteratorParameter { + + using TensorLayout = layout::TensorNDHWC; + using OutputIteratorLayout = layout::TensorNDHWC; + using MappedLayout = layout::RowMajor; + using OutputTensorCoord = typename OutputIteratorLayout::TensorCoord; + using MappedTensorCoord = typename MappedLayout::TensorCoord; + using TensorRef = TensorRef_; + static conv::Operator const kConvolutionalOperator = conv::Operator::kFprop; + using ConvProblemSize = ConvProblemSize_; + + CUTLASS_HOST_DEVICE + static OutputIteratorLayout layout(const TensorRef & ref) { + return ref.stride(); + } + CUTLASS_HOST_DEVICE + static MappedTensorCoord extent(ConvProblemSize problem_size) { + return conv::implicit_gemm_problem_size(kConvolutionalOperator, problem_size).mn(); + } +}; template < int InterleavedK, diff --git a/include/cutlass/epilogue/threadblock/predicated_tile_iterator.h b/include/cutlass/epilogue/threadblock/predicated_tile_iterator.h index e3330de8a2..14a854476e 100644 --- a/include/cutlass/epilogue/threadblock/predicated_tile_iterator.h +++ b/include/cutlass/epilogue/threadblock/predicated_tile_iterator.h @@ -51,6 +51,8 @@ #include "cutlass/arch/arch.h" #include "cutlass/arch/memory.h" #include "cutlass/epilogue/threadblock/predicated_tile_iterator_params.h" +#include "cutlass/conv/conv2d_problem_size.h" +#include "cutlass/conv/conv3d_problem_size.h" //////////////////////////////////////////////////////////////////////////////// @@ -102,10 +104,10 @@ class PredicatedTileIterator { /// Fragment object using Fragment = Array< - Element, - ThreadMap::Iterations::kColumn * - ThreadMap::Iterations::kRow * - ThreadMap::Iterations::kGroup * + Element, + ThreadMap::Iterations::kColumn * + ThreadMap::Iterations::kRow * + ThreadMap::Iterations::kGroup * ThreadMap::Iterations::kCluster * ThreadMap::kElementsPerAccess>; /// Memory access size @@ -123,13 +125,27 @@ class PredicatedTileIterator { Params() { } CUTLASS_HOST_DEVICE - Params(Layout const &layout): + Params(Layout const &layout): PredicatedTileIteratorParams( layout.stride(0) * int(sizeof(AccessType)) / kElementsPerAccess, make_OutputTileThreadMapDesc() ) { } + CUTLASS_HOST_DEVICE + Params(Layout const &layout, + // Not needed. Added to be compatible with strided conv epilogue. + conv::Conv2dProblemSize const &problem_size): + Params(layout) + { } + + CUTLASS_HOST_DEVICE + Params(Layout const &layout, + // Not needed. Added to be compatible with strided conv epilogue. 
+ conv::Conv3dProblemSize const &problem_size): + Params(layout) + { } + CUTLASS_HOST_DEVICE Params(Base const &base) : Base(base) { } @@ -202,7 +218,7 @@ class PredicatedTileIterator { int state_[3]; /// Scatter indices - int const *indices_; + int const *indices_; /// PermuteDLayout PermuteDLayout permute_layout_; @@ -253,7 +269,7 @@ class PredicatedTileIterator { CUTLASS_PRAGMA_UNROLL for (int c = 0; c < ThreadMap::Iterations::kColumn; ++c) { - mask_.predicates[c] = ((thread_offset.column() + mask_.predicates[c] = ((thread_offset.column() + ThreadMap::Delta::kColumn * c) < extent.column()); } @@ -267,8 +283,8 @@ class PredicatedTileIterator { } // Initialize byte_pointer_ - byte_pointer_ = reinterpret_cast(pointer) + - LongIndex(thread_offset.row()) * LongIndex(params_.stride) + + byte_pointer_ = reinterpret_cast(pointer) + + LongIndex(thread_offset.row()) * LongIndex(params_.stride) + LongIndex(thread_offset.column()) * sizeof(AccessType) / kElementsPerAccess; if (ScatterD) { @@ -306,7 +322,7 @@ class PredicatedTileIterator { CUTLASS_PRAGMA_UNROLL for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { - int frag_row_idx = + int frag_row_idx = (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); int row_offset = row * ThreadMap::Delta::kRow @@ -330,7 +346,7 @@ class PredicatedTileIterator { bool guard = row_guard && mask_.predicates[column]; cutlass::arch::global_load< - AccessType, + AccessType, sizeof(AccessType) >( frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + @@ -380,11 +396,11 @@ class PredicatedTileIterator { CUTLASS_PRAGMA_UNROLL for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { - int frag_row_idx = + int frag_row_idx = (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); - int row_offset = row * ThreadMap::Delta::kRow - + group * ThreadMap::Delta::kGroup + int row_offset = row * ThreadMap::Delta::kRow + + group * ThreadMap::Delta::kGroup + cluster * ThreadMap::Delta::kCluster; bool row_guard = ((row_offset + thread_start_row_) < extent_row_); @@ -426,7 +442,7 @@ class PredicatedTileIterator { (void *)&memory_pointer[0], guard); } - + if (!PermuteD) { memory_pointer += (ThreadMap::Delta::kColumn / kElementsPerAccess); } @@ -649,7 +665,7 @@ class PredicatedTileIterator { } thread_start_row_ += ThreadMap::Shape::kRow; - + if (state_[0] == ThreadMap::Count::kRow) { state_[0] = 0; @@ -663,7 +679,7 @@ class PredicatedTileIterator { store_byte_pointer_ += params_.advance_group; } - thread_start_row_ += (ThreadMap::Shape::kGroup - 1) * + thread_start_row_ += (ThreadMap::Shape::kGroup - 1) * ThreadMap::Shape::kRow * ThreadMap::Count::kRow; if (state_[1] == ThreadMap::Count::kGroup) { @@ -679,7 +695,7 @@ class PredicatedTileIterator { store_byte_pointer_ += params_.advance_cluster; } - thread_start_row_ += ThreadMap::Count::kGroup * + thread_start_row_ += ThreadMap::Count::kGroup * ThreadMap::Shape::kGroup * ThreadMap::Count::kRow * ThreadMap::Shape::kRow; if (state_[2] == ThreadMap::Count::kCluster) { @@ -1121,6 +1137,14 @@ class InterleavedConvPredicatedTileIterator { initialize(layout.stride()); } + + CUTLASS_HOST_DEVICE + Params(Layout const &layout, + // Not needed. Added to be compatible with strided conv epilogue. 
+ conv::Conv2dProblemSize const &problem_size): + Params(layout) + { } + }; /// Mask object diff --git a/include/cutlass/epilogue/threadblock/predicated_tile_iterator_conv.h b/include/cutlass/epilogue/threadblock/predicated_tile_iterator_conv.h new file mode 100644 index 0000000000..c3c722bc4d --- /dev/null +++ b/include/cutlass/epilogue/threadblock/predicated_tile_iterator_conv.h @@ -0,0 +1,562 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. + + The epilogue rearranges the result of a matrix product through shared memory to match canonical + tensor layouts in global memory. Epilogues support conversion and reduction operations. 
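+
+    In contrast to PredicatedTileIterator, which addresses the output as a row-major matrix with a
+    single leading stride, this iterator decomposes each linear output row into per-dimension
+    (N, Z, P, Q) coordinates with FastDivmod and applies explicit NHWC/NDHWC strides, so conv
+    epilogues can write output tensors whose strides are not those of a packed layout.
+
+    A rough sketch of the addressing scheme (simplified to the 2-D case; divmod_q, divmod_p and the
+    stride names are illustrative placeholders, not identifiers from this file):
+
+      int q, p, n, residual;
+      divmod_q(residual, q, row);   // q = row % Q, residual = row / Q
+      divmod_p(n, p, residual);     // p = residual % P, n = residual / P
+      int64_t offset = n * stride_n + p * stride_p + q * stride_q;   // offset of the (n, p, q) row;
+                                                                     // the channel offset is added per column access
+
+    The implementation below performs this decomposition with CoordinateDecompositionLittleEndian
+    and a dot product against Params::tensor_stride.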
+ +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/array.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/layout/permute.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/transform/pitch_linear_thread_map.h" +#include "cutlass/epilogue/threadblock/output_tile_thread_map.h" +#include "cutlass/arch/arch.h" +#include "cutlass/arch/memory.h" +#include "cutlass/epilogue/threadblock/predicated_tile_iterator_params.h" +#include "cutlass/conv/conv2d_problem_size.h" +#include "cutlass/conv/conv3d_problem_size.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { + +//////////////////////////////////////////////////////////////////////////////// + +namespace epilogue { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Tile iterator used to load and store output tile from global memory in epilogue. +/// +/// Satisfies: ReadableTileIterator | PredicatedTileIteratorConv | ForwardTileIterator +/// +template < + typename ThreadMap_, ///< Thread map (concept: OutputTileThreadMap) + typename Element_, ///< Element data type + bool ScatterD = false, ///< Scatter D operand or not + typename PermuteDLayout = layout::NoPermute, ///< Permute D operand or not + bool UseCUDAStore = false, + int Rank = 4 +> +class PredicatedTileIteratorConv { +public: + using ThreadMap = ThreadMap_; + using Shape = typename ThreadMap::Shape; + + using Element = Element_; + + static int const kRank = Rank; + using Layout = typename platform::conditional<kRank == 4, layout::TensorNHWC, layout::TensorNDHWC>::type; + + using Stride = typename Layout::Stride; + static int const kStrideRank = Layout::kStrideRank; + + using TensorRef = TensorRef<Element, Layout>; + using ConstTensorRef = typename TensorRef::ConstTensorRef; + + using MappedLayout = layout::RowMajor; + using Index = typename MappedLayout::Index; + using LongIndex = typename MappedLayout::LongIndex; + using TensorCoord = typename MappedLayout::TensorCoord; + + static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; + static int const kThreads = ThreadMap::kThreads; + static int const kIterations = ThreadMap::Count::kTile; + + static bool constexpr PermuteD = !layout::is_trivial_permute<PermuteDLayout>; + + static_assert( ThreadMap::Iterations::kRow > 0,"ThreadMap::Iterations::kRow must be > 0"); + static_assert( ThreadMap::Iterations::kGroup > 0,"ThreadMap::Iterations::kGroup must be > 0"); + static_assert( ThreadMap::Iterations::kCluster > 0,"ThreadMap::Iterations::kCluster must be > 0"); + static_assert( ThreadMap::Iterations::kColumn > 0,"ThreadMap::Iterations::kColumn must be > 0"); + + /// Fragment object + using Fragment = Array< + Element, + ThreadMap::Iterations::kColumn * + ThreadMap::Iterations::kRow * + ThreadMap::Iterations::kGroup * + ThreadMap::Iterations::kCluster * ThreadMap::kElementsPerAccess>; + + /// Memory access size + using AccessType = AlignedArray<Element, ThreadMap::kElementsPerAccess>; + + // + // Parameters struct + // + + /// Uses a non-template class + struct Params : PredicatedTileIteratorParams { + using Base = PredicatedTileIteratorParams; + + /// Fast divmod objects divided by tensor extents + FastDivmod divmod[kStrideRank - 1]; + Stride tensor_stride; + + CUTLASS_HOST_DEVICE + Params() { } + + CUTLASS_HOST_DEVICE + Params(Layout const &layout, conv::Conv2dProblemSize const &problem_size): + PredicatedTileIteratorParams( + layout.stride()[0] * int(sizeof(AccessType)) / kElementsPerAccess, +
make_OutputTileThreadMapDesc() + ) { + divmod[0] = FastDivmod(problem_size.Q); + divmod[1] = FastDivmod(problem_size.P); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kStrideRank; ++i) { + tensor_stride[i] = layout.stride()[i]; + } + } + + CUTLASS_HOST_DEVICE + Params(Layout const &layout, conv::Conv3dProblemSize const &problem_size): + PredicatedTileIteratorParams( + layout.stride()[0] * int(sizeof(AccessType)) / kElementsPerAccess, + make_OutputTileThreadMapDesc() + ) { + divmod[0] = FastDivmod(problem_size.Q); + divmod[1] = FastDivmod(problem_size.P); + divmod[2] = FastDivmod(problem_size.Z); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kStrideRank; ++i) { + tensor_stride[i] = layout.stride()[i]; + } + } + + CUTLASS_HOST_DEVICE + Params(Base const &base) : + Base(base) { } + }; + + /// Mask object + struct Mask { + + static int const kCount = ThreadMap::Iterations::kColumn; + + /// Predicate state + bool predicates[kCount]; + + // + // Mask + // + CUTLASS_HOST_DEVICE + Mask() { + enable(); + } + + ///< Efficiently disables all accesses guarded by mask + CUTLASS_HOST_DEVICE void clear() { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kCount; ++i) { + predicates[i] = false; + } + } + + ///< CUTLASS_HOST_DEVICE enables all accesses guarded by mask + CUTLASS_DEVICE void enable() { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kCount; ++i) { + predicates[i] = true; + } + } + }; + +private: + + // + // Data members + // + + /// Parameters structure containing reference and precomputed state. + Params params_; + + /// Byte-level pointer. This pointer is usually for both load() and store(), unless PermuteD is performed. When having PermuteD, byte_pointer_ is only for load(). + uint8_t *byte_pointer_; + + /// Array of boolean values to contain steady-state predicates + Mask mask_; + + /// Extent of the matrix tile in rows + Index extent_row_; + + /// Extent of the matrix tile in rows + Index extent_column_; + + /// A thread's starting row position (assuming steady-state predicates have been computed) + Index thread_start_row_; + + /// A thread's starting column + Index thread_start_column_; + + /// Internal state counter + int state_[3]; + + // + // Static asserts about internal strides + // + + static_assert(sizeof(extent_row_) == 4, "Expected 32b extents"); + static_assert(sizeof(thread_start_row_) == 4, "Expected 32b extents"); + static_assert(sizeof(PredicatedTileIteratorParams::stride) == 8, "Expected 64b strides"); + +private: + + // + // Methods + // + +public: + + // + // Methods + // + + /// Constructor + CUTLASS_DEVICE + PredicatedTileIteratorConv( + Params const & params, + Element *pointer, + TensorCoord extent, + int thread_idx, + TensorCoord threadblock_offset = TensorCoord() + ): + params_(params) + { + + TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx) + threadblock_offset; + + extent_row_ = extent.row(); + extent_column_ = extent.column(); + + thread_start_row_ = thread_offset.row(); + thread_start_column_ = thread_offset.column(); + + // Initialize predicates + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kColumn; ++c) { + + mask_.predicates[c] = ((thread_offset.column() + + ThreadMap::Delta::kColumn * c) < extent.column()); + } + + // Null pointer performs no accesses + if (!pointer) { + mask_.clear(); + } + + // Initialize byte_pointer_ + byte_pointer_ = reinterpret_cast(pointer) + + LongIndex(thread_offset.column()) * sizeof(AccessType) / kElementsPerAccess; + + // Initialize internal state counter + state_[0] = state_[1] = 
state_[2] = 0; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + byte_pointer_ += pointer_offset * sizeof_bits::value / 8; + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_byte_offset(Fragment &frag, int64_t byte_offset) const { + + uint8_t *byte_pointer = byte_pointer_; + AccessType *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { + + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + + int frag_row_idx = + (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); + + int row_offset = row * ThreadMap::Delta::kRow + + group * ThreadMap::Delta::kGroup + + cluster * ThreadMap::Delta::kCluster; + + bool row_guard = ((row_offset + thread_start_row_) < extent_row_); + + AccessType *memory_pointer = reinterpret_cast(byte_pointer + byte_offset); + + Stride tensor_coord = CoordinateDecompositionLittleEndian(row_offset + thread_start_row_, params_.divmod); + + LongIndex tensor_offset = dot(tensor_coord, params_.tensor_stride); + + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { + + bool guard = row_guard && mask_.predicates[column]; + + cutlass::arch::global_load< + AccessType, + sizeof(AccessType) + >( + frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + + column], + (void *)&memory_pointer[column * ThreadMap::Delta::kColumn / + kElementsPerAccess + tensor_offset / kElementsPerAccess], + guard); + } + } + } + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment &frag) const { + + load_with_byte_offset(frag, 0); + } + + /// Stores a fragment to memory + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const &frag, int64_t byte_offset) const { + uint8_t *byte_pointer = byte_pointer_; + AccessType const *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { + + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + + int frag_row_idx = + (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); + + int row_offset = row * ThreadMap::Delta::kRow + + group * ThreadMap::Delta::kGroup + + cluster * ThreadMap::Delta::kCluster; + + bool row_guard = ((row_offset + thread_start_row_) < extent_row_); + + Stride tensor_coord = CoordinateDecompositionLittleEndian((row_offset + thread_start_row_), params_.divmod); + + LongIndex tensor_offset = dot(tensor_coord, params_.tensor_stride); + + AccessType *memory_pointer = reinterpret_cast(byte_pointer + byte_offset); + + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { + + bool guard = row_guard && mask_.predicates[column]; + + if (UseCUDAStore) { + if (guard) { + memory_pointer[tensor_offset / kElementsPerAccess] = + frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column]; + } + } else { + cutlass::arch::global_store( + frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column], + (void *)&memory_pointer[tensor_offset / kElementsPerAccess], + guard); + } + + memory_pointer += (ThreadMap::Delta::kColumn / 
kElementsPerAccess); + } + } + } + } + } + + /// Stores a fragment to memory + CUTLASS_DEVICE + void store(Fragment const &frag) const { + + store_with_byte_offset(frag, 0); + } + + CUTLASS_DEVICE + MatrixCoord thread_start() const { + return MatrixCoord(thread_start_row_, thread_start_column_); + } + + /// Need to get the thread start row from the tile iterator + CUTLASS_DEVICE + int32_t thread_start_row() const { + return thread_start_row_; + } + + /// Need to get the thread start row from the tile iterator + CUTLASS_DEVICE + int32_t thread_start_column() const { + return thread_start_column_; + } + + /// Extent of the matrix in rows + CUTLASS_DEVICE + Index extent_row() const { + return extent_row_; + } + + /// Extent of the matrix in columns + CUTLASS_DEVICE + Index extent_column() const { + return extent_column_; + } + + /// Advances to the next position to load or store + CUTLASS_HOST_DEVICE + PredicatedTileIteratorConv &operator++() { + + ++state_[0]; + + thread_start_row_ += ThreadMap::Shape::kRow; + + if (state_[0] == ThreadMap::Count::kRow) { + + state_[0] = 0; + ++state_[1]; + + thread_start_row_ += (ThreadMap::Shape::kGroup - 1) * + ThreadMap::Shape::kRow * ThreadMap::Count::kRow; + + if (state_[1] == ThreadMap::Count::kGroup) { + + state_[1] = 0; + ++state_[2]; + + thread_start_row_ += ThreadMap::Count::kGroup * + ThreadMap::Shape::kGroup * ThreadMap::Count::kRow * ThreadMap::Shape::kRow; + + if (state_[2] == ThreadMap::Count::kCluster) { + state_[2] = 0; + + thread_start_row_ += ThreadMap::Shape::kGroup * ThreadMap::Shape::kRow + * ThreadMap::Shape::kCluster * ThreadMap::Shape::kTile; + } + } + } + + return *this; + } + + /// Advances a number of positions to load or store + CUTLASS_HOST_DEVICE + PredicatedTileIteratorConv &operator+=(int increment) + { + // Row + state_[0] += increment; + int increment_row = state_[0] / ThreadMap::Count::kRow; + state_[0] = state_[0] % ThreadMap::Count::kRow; + + thread_start_row_ += (ThreadMap::Shape::kRow * increment); + + // Group + state_[1] += increment_row; + int increment_group = state_[1] / ThreadMap::Count::kGroup; + state_[1] = state_[1] % ThreadMap::Count::kGroup; + + thread_start_row_ += + (ThreadMap::Shape::kGroup - 1) * + ThreadMap::Shape::kRow * + ThreadMap::Count::kRow * + increment_row; + + // Cluster + state_[2] += increment_group; + int increment_cluster = state_[2] / ThreadMap::Count::kCluster; + state_[2] = state_[2] % ThreadMap::Count::kCluster; + + thread_start_row_ += + ThreadMap::Count::kGroup * + ThreadMap::Shape::kGroup * + ThreadMap::Count::kRow * + ThreadMap::Shape::kRow * + increment_group; + + // Tile + thread_start_row_ += + ThreadMap::Shape::kGroup * + ThreadMap::Shape::kRow * + ThreadMap::Shape::kCluster * + ThreadMap::Shape::kTile * + increment_cluster; + + return *this; + } + + ///< Efficiently disables all accesses guarded by mask + CUTLASS_DEVICE void clear_mask() { + mask_.clear(); + } + + ///< Efficiently enables all accesses guarded by mask + CUTLASS_DEVICE void enable_mask() { + mask_.enable(); + } + + ///< Sets the mask + CUTLASS_DEVICE void get_mask(Mask &mask) const { + mask = mask_; + } + + ///< Sets the mask + CUTLASS_DEVICE void set_mask(Mask const &mask) { + mask_ = mask; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git 
a/include/cutlass/epilogue/threadblock/predicated_tile_iterator_params.h b/include/cutlass/epilogue/threadblock/predicated_tile_iterator_params.h index 2b412bf12e..5e9aa22bdb 100644 --- a/include/cutlass/epilogue/threadblock/predicated_tile_iterator_params.h +++ b/include/cutlass/epilogue/threadblock/predicated_tile_iterator_params.h @@ -32,16 +32,6 @@ \brief */ -/* - Note: CUTLASS 3x increases the host compiler requirements to C++17. However, certain - existing integrations of CUTLASS require C++11 host compilers. - - Until this requirement can be lifted, certain headers with this annotation are required - to be remain consistent with C++11 syntax. - - C++11 compatibility is enforced by this unit test: `cutlass_test_unit_core_cpp11`. -*/ - #pragma once #include "cutlass/cutlass.h" @@ -257,8 +247,6 @@ struct PredicatedTileIteratorParams { } }; - - /////////////////////////////////////////////////////////////////////////////// // diff --git a/include/cutlass/epilogue/threadblock/shared_load_iterator_pitch_liner.h b/include/cutlass/epilogue/threadblock/shared_load_iterator_pitch_linear.h similarity index 98% rename from include/cutlass/epilogue/threadblock/shared_load_iterator_pitch_liner.h rename to include/cutlass/epilogue/threadblock/shared_load_iterator_pitch_linear.h index 79a91f7518..5af6997ed3 100644 --- a/include/cutlass/epilogue/threadblock/shared_load_iterator_pitch_liner.h +++ b/include/cutlass/epilogue/threadblock/shared_load_iterator_pitch_linear.h @@ -66,7 +66,7 @@ namespace threadblock { template ::value / 8> -class SharedLoadIteratorPitchLiner { +class SharedLoadIteratorPitchLinear { public: using ThreadMap = ThreadMap_; using Element = Element_; @@ -123,7 +123,7 @@ class SharedLoadIteratorPitchLiner { /// Constructor CUTLASS_DEVICE - SharedLoadIteratorPitchLiner(TensorRef ref, int thread_idx) + SharedLoadIteratorPitchLinear(TensorRef ref, int thread_idx) : byte_pointer_(reinterpret_cast(ref.data())), stride_((ref.stride(0) * sizeof_bits::value) / 8), base_smem_address_(0) { diff --git a/include/cutlass/fast_math.h b/include/cutlass/fast_math.h index 3e842d6a89..84fb06def2 100644 --- a/include/cutlass/fast_math.h +++ b/include/cutlass/fast_math.h @@ -28,15 +28,6 @@ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ -/* - Note: CUTLASS 3x increases the host compiler requirements to C++17. However, certain - existing integrations of CUTLASS require C++11 host compilers. - - Until this requirement can be lifted, certain headers with this annotation are required - to be remain consistent with C++11 syntax. - - C++11 compatibility is enforced by `cutlass_test_unit_core_cpp11`. -*/ #pragma once diff --git a/include/cutlass/float8.h b/include/cutlass/float8.h index 0d925f268f..a2d062a04b 100644 --- a/include/cutlass/float8.h +++ b/include/cutlass/float8.h @@ -33,15 +33,6 @@ \brief Defines a class for using IEEE half-precision floating-point types in host or device code. */ -/* - Note: CUTLASS 3x increases the host compiler requirements to C++17. However, certain - existing integrations of CUTLASS require C++11 host compilers. - - Until this requirement can be lifted, certain headers with this annotation are required - to be remain consistent with C++11 syntax. - - C++11 compatibility is enforced by `cutlass_test_unit_core_cpp11`. 
-*/ #pragma once diff --git a/include/cutlass/functional.h b/include/cutlass/functional.h index 27c8de2d89..964d2ff35f 100644 --- a/include/cutlass/functional.h +++ b/include/cutlass/functional.h @@ -33,20 +33,13 @@ This is inspired by the Standard Library's header. */ -/* - Note: CUTLASS 3x increases the host compiler requirements to C++17. However, certain - existing integrations of CUTLASS require C++11 host compilers. - - Until this requirement can be lifted, certain headers with this annotation are required - to be remain consistent with C++11 syntax. - - C++11 compatibility is enforced by `cutlass_test_unit_core_cpp11`. -*/ #pragma once #include "cutlass/cutlass.h" #include "cutlass/numeric_types.h" +#include + #if defined(CUTLASS_ARCH_WMMA_ENABLED) #include #endif // defined(CUTLASS_ARCH_WMMA_ENABLED) @@ -216,6 +209,35 @@ struct magnitude_squared_difference { } }; +// Computes the reciprocal square root +template +struct inverse_square_root; + +template <> +struct inverse_square_root { + CUTLASS_HOST_DEVICE + float operator()(float const &lhs) const { +#if defined(__CUDA_ARCH__) + return rsqrtf(lhs); +#else + return 1.f / std::sqrt(lhs); +#endif + } +}; + +template <> +struct inverse_square_root { + CUTLASS_HOST_DEVICE + half_t operator()(half_t const &lhs) const { +#if defined(__CUDA_ARCH__) + auto result = hrsqrt(reinterpret_cast<__half const &>(lhs)); + return reinterpret_cast(result); +#else + return half_t(1.f / std::sqrt(half_t::convert(lhs))); +#endif + } +}; + /// Divides template struct divides { @@ -546,8 +568,6 @@ struct bit_xor { } }; - - ////////////////////////////////////////////////////////////////////////////////////////////////// /// Atomic reductions diff --git a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp index 29ee9605c5..4613f7bf65 100644 --- a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp +++ b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp @@ -1093,7 +1093,7 @@ struct CollectiveMma< else if constexpr (ModeHasScales) { Tensor sS = make_tensor(make_smem_ptr(shared_tensors.smem_scale.begin()), SmemLayoutScale{}); // (BLK_M,BLK_SCALE_K,PIPE) Tensor tCsS = thread_mma.partition_A(sS); - Tensor tCrS = make_fragment_like(thread_mma.partition_fragment_A(sS(_,_,Int<0>{}))); + Tensor tCrS = make_tensor(thread_mma.partition_fragment_A(sS(_,_,Int<0>{})).shape()); if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { return cute::make_tuple(tCsS, tCrS); @@ -1101,7 +1101,7 @@ struct CollectiveMma< else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { Tensor sZ = make_tensor(make_smem_ptr(shared_tensors.smem_zero.begin()), SmemLayoutScale{}); // (BLK_M,BLK_SCALE_K,PIPE) Tensor tCsZ = thread_mma.partition_A(sZ); - Tensor tCrZ = make_fragment_like(thread_mma.partition_fragment_A(sZ(_,_,Int<0>{}))); + Tensor tCrZ = make_tensor(thread_mma.partition_fragment_A(sZ(_,_,Int<0>{})).shape()); return cute::make_tuple(tCsS, tCrS, tCsZ, tCrZ); } else { diff --git a/include/cutlass/gemm/device/gemm_universal_adapter.h b/include/cutlass/gemm/device/gemm_universal_adapter.h index 750b16a66b..8d045d6efa 100644 --- a/include/cutlass/gemm/device/gemm_universal_adapter.h +++ b/include/cutlass/gemm/device/gemm_universal_adapter.h @@ -96,7 +96,7 @@ class GemmUniversalAdapter< using ElementB = typename GemmKernel::ElementB; using ElementC = 
typename GemmKernel::ElementC; using ElementD = typename GemmKernel::ElementD; - using ElementAccumulator = typename GemmKernel::TiledMma::ValTypeC; + using ElementAccumulator = typename GemmKernel::ElementAccumulator; using DispatchPolicy = typename GemmKernel::DispatchPolicy; using CollectiveMainloop = typename GemmKernel::CollectiveMainloop; using CollectiveEpilogue = typename GemmKernel::CollectiveEpilogue; @@ -361,9 +361,13 @@ class GemmUniversalAdapter< CUTLASS_ASSERT(cuda_adapter); if (cuda_adapter) { - launch_result = cuda_adapter->launch( - grid, cluster, block, smem_size, stream, kernel_params, 0 - ); + launch_result = cuda_adapter->launch(grid, + cluster, + block, + smem_size, + stream, + kernel_params, + 0); } else { return Status::kErrorInternal; diff --git a/include/cutlass/gemm/gemm_enumerated_types.h b/include/cutlass/gemm/gemm_enumerated_types.h index efc93e551c..66aae898d7 100644 --- a/include/cutlass/gemm/gemm_enumerated_types.h +++ b/include/cutlass/gemm/gemm_enumerated_types.h @@ -32,16 +32,6 @@ \brief Defines common types used for all GEMM-like operators. */ -/* - Note: CUTLASS 3x increases the host compiler requirements to C++17. However, certain - existing integrations of CUTLASS require C++11 host compilers. - - Until this requirement can be lifted, certain headers with this annotation are required - to be remain consistent with C++11 syntax. - - C++11 compatibility is enforced by `cutlass_test_unit_core_cpp11`. -*/ - #pragma once #include "cutlass/cutlass.h" diff --git a/include/cutlass/gemm/kernel/gemm_universal.h b/include/cutlass/gemm/kernel/gemm_universal.h index f3f781a6d9..08b30c74cf 100644 --- a/include/cutlass/gemm/kernel/gemm_universal.h +++ b/include/cutlass/gemm/kernel/gemm_universal.h @@ -30,7 +30,7 @@ **************************************************************************************************/ /*! \file - \brief + \brief */ #pragma once @@ -177,8 +177,8 @@ class GemmUniversal< int const *ptr_scatter_D_indices = nullptr) : UniversalArgumentsBase(mode, problem_size, batch_count, batch_stride_D), - epilogue(epilogue), - ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D), + epilogue(epilogue), + ptr_A(ptr_A), ptr_B(ptr_B), ptr_C(ptr_C), ptr_D(ptr_D), batch_stride_A(batch_stride_A), batch_stride_B(batch_stride_B), batch_stride_C(batch_stride_C), stride_a(stride_a), stride_b(stride_b), stride_c(stride_c), stride_d(stride_d), ptr_gather_A_indices(ptr_gather_A_indices), ptr_gather_B_indices(ptr_gather_B_indices), @@ -486,18 +486,18 @@ class GemmUniversal< int offset_k = 0; int problem_size_k = params.problem_size.k(); - ElementA *ptr_A = static_cast(params.ptr_A); + ElementA *ptr_A = static_cast(params.ptr_A); ElementB *ptr_B = static_cast(params.ptr_B); // // Fetch pointers based on mode. 
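    // In kGemm and kGemmSplitKParallel modes, each threadblock restricts the mainloop below to the
    // K-range [offset_k, problem_size_k) derived from its K tile index and params.gemm_k_size
    // (a no-op when the grid has a single K tile); other modes keep the full K extent set above.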
// - if (params.mode == GemmUniversalMode::kGemm || + if (params.mode == GemmUniversalMode::kGemm || params.mode == GemmUniversalMode::kGemmSplitKParallel) { if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k()) { - problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size; + problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size; } offset_k = threadblock_tile_offset.k() * params.gemm_k_size; @@ -566,10 +566,10 @@ class GemmUniversal< // Compute threadblock-scoped matrix multiply-add mma( - gemm_k_iterations, - accumulators, - iterator_A, - iterator_B, + gemm_k_iterations, + accumulators, + iterator_A, + iterator_B, accumulators); // @@ -592,13 +592,13 @@ class GemmUniversal< int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); - ElementC *ptr_C = static_cast(params.ptr_C); + ElementC *ptr_C = static_cast(params.ptr_C); ElementC *ptr_D = static_cast(params.ptr_D); // // Fetch pointers based on mode. // - + // Construct the semaphore. Semaphore semaphore(params.semaphore + block_idx, thread_idx); @@ -606,7 +606,7 @@ class GemmUniversal< // If performing a reduction via split-K, fetch the initial synchronization if (params.grid_tiled_shape.k() > 1) { - + // Fetch the synchronization lock initially but do not block. semaphore.fetch(); @@ -647,14 +647,14 @@ class GemmUniversal< ); Epilogue epilogue( - shared_storage.epilogue, - thread_idx, - warp_idx, + shared_storage.epilogue, + thread_idx, + warp_idx, lane_idx); // Wait on the semaphore - this latency may have been covered by iterator construction if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) { - + // For subsequent threadblocks, the source matrix is held in the 'D' tensor. if (threadblock_tile_offset.k()) { iterator_C = iterator_D; @@ -666,11 +666,11 @@ class GemmUniversal< // Execute the epilogue operator to update the destination tensor. 
epilogue( - output_op, - iterator_D, - accumulators, - iterator_C); - + output_op, + iterator_D, + accumulators, + iterator_C); + // // Release the semaphore // @@ -687,7 +687,7 @@ class GemmUniversal< // Otherwise, the semaphore is incremented lock = threadblock_tile_offset.k() + 1; } - + semaphore.release(lock); } } diff --git a/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized_cooperative.hpp b/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized_cooperative.hpp index 9229feee88..877d2c1ddf 100644 --- a/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized_cooperative.hpp +++ b/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized_cooperative.hpp @@ -69,7 +69,6 @@ class GemmUniversal< using ProblemShape = ProblemShape_; static_assert(cute::rank(ProblemShape{}) == 3 or cute::rank(ProblemShape{}) == 4, "ProblemShape{} should be or "); - // Mainloop derived types using CollectiveMainloop = CollectiveMainloop_; using TileShape = typename CollectiveMainloop::TileShape; diff --git a/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized_pingpong.hpp b/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized_pingpong.hpp index 1da5b6d3da..abf79e842a 100644 --- a/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized_pingpong.hpp +++ b/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized_pingpong.hpp @@ -70,7 +70,6 @@ class GemmUniversal< using ProblemShape = ProblemShape_; static_assert(cute::rank(ProblemShape{}) == 3 or cute::rank(ProblemShape{}) == 4, "ProblemShape{} should be or "); - // Mainloop derived types using CollectiveMainloop = CollectiveMainloop_; using TileShape = typename CollectiveMainloop::TileShape; diff --git a/include/cutlass/gemm/kernel/tile_scheduler_params.h b/include/cutlass/gemm/kernel/tile_scheduler_params.h index 0cc3ffc890..1630583f6c 100644 --- a/include/cutlass/gemm/kernel/tile_scheduler_params.h +++ b/include/cutlass/gemm/kernel/tile_scheduler_params.h @@ -35,16 +35,6 @@ \brief Parameters structures for persistent tile schedulers */ -/* - Note: CUTLASS 3x increases the host compiler requirements to C++17. However, certain - existing integrations of CUTLASS require C++11 host compilers. - - Until this requirement can be lifted, certain headers with this annotation are required - to be remain consistent with C++11 syntax. - - C++11 compatibility is enforced by this unit test: `cutlass_test_unit_core_cpp11`. 
-*/ - #include "cutlass/coord.h" #include "cutlass/kernel_hardware_info.h" #include "cutlass/workspace.h" diff --git a/include/cutlass/gemm/thread/mma_sm60.h b/include/cutlass/gemm/thread/mma_sm60.h index 35cf3fb769..5e2178982c 100644 --- a/include/cutlass/gemm/thread/mma_sm60.h +++ b/include/cutlass/gemm/thread/mma_sm60.h @@ -147,9 +147,7 @@ struct Mma_HFMA2 < CUTLASS_PRAGMA_UNROLL for(auto n=0; n < Shape::kN / Mma::Shape::kN; n++){ - Array tmp; - Array *ptr_tmp = &tmp; - ptr_tmp[0] = ptr_D[n*Shape::kM/2 + m]; + Array tmp { ptr_D[n*Shape::kM/2 + m] }; mma( tmp, @@ -157,7 +155,7 @@ struct Mma_HFMA2 < ptr_B[n*Shape::kK + k], tmp); - ptr_D[n*Shape::kM/2 + m] = ptr_tmp[0]; + ptr_D[n*Shape::kM/2 + m] = tmp; } } } @@ -239,9 +237,7 @@ struct Mma_HFMA2< CUTLASS_PRAGMA_UNROLL for(auto m=0; m < Shape::kM / Mma::Shape::kM; m++){ - Array tmp; - Array *ptr_tmp = &tmp; - ptr_tmp[0] = ptr_D[m*Shape::kN/2 + n]; + Array tmp { ptr_D[m*Shape::kN/2 + n] }; Array tmp_B; tmp_B[0] = ptr_B->at(2*n*Shape::kK + k); @@ -253,7 +249,7 @@ struct Mma_HFMA2< tmp_B, tmp); - ptr_D[m*Shape::kN/2 + n] = ptr_tmp[0]; + ptr_D[m*Shape::kN/2 + n] = tmp; } } } @@ -335,10 +331,7 @@ struct Mma_HFMA2 < CUTLASS_PRAGMA_UNROLL for (int n = 0; n < Shape::kN / Mma::Shape::kN; ++n) { - Array tmp; - Array *ptr_tmp = &tmp; - - ptr_tmp[0] = ptr_D[m + n * Shape::kM/2]; + Array tmp { ptr_D[m + n * Shape::kM/2] }; mma( tmp, @@ -346,7 +339,7 @@ struct Mma_HFMA2 < ptr_B[k * Shape::kN + n], tmp); - ptr_D[m + n * Shape::kM/2] = ptr_tmp[0]; + ptr_D[m + n * Shape::kM/2] = tmp; } } } @@ -428,9 +421,7 @@ struct Mma_HFMA2< CUTLASS_PRAGMA_UNROLL for(auto m=0; m < Shape::kM / Mma::Shape::kM; m++){ - Array tmp; - Array *ptr_tmp = &tmp; - ptr_tmp[0] = ptr_D[m*Shape::kN/2 + n]; + Array tmp { ptr_D[m*Shape::kN/2 + n] }; mma( tmp, @@ -438,7 +429,7 @@ struct Mma_HFMA2< ptr_B[k*Shape::kN/2 + n], tmp); - ptr_D[m*Shape::kN/2 + n] = ptr_tmp[0]; + ptr_D[m*Shape::kN/2 + n] = tmp; } } } @@ -521,9 +512,7 @@ struct Mma_HFMA2 < CUTLASS_PRAGMA_UNROLL for(auto n=0; n < Shape::kN / Mma::Shape::kN; n++){ - Array tmp; - Array *ptr_tmp = &tmp; - ptr_tmp[0] = ptr_D[n*Shape::kM/2 + m]; + Array tmp { ptr_D[n*Shape::kM/2 + m] }; Array tmp_A; tmp_A[0] = ptr_A->at(2*m*Shape::kK + k); @@ -535,7 +524,7 @@ struct Mma_HFMA2 < ptr_B[n*Shape::kK + k], tmp); - ptr_D[n*Shape::kM/2 + m] = ptr_tmp[0]; + ptr_D[n*Shape::kM/2 + m] = tmp; } } } @@ -617,9 +606,7 @@ struct Mma_HFMA2 < CUTLASS_PRAGMA_UNROLL for(auto m=0; m < Shape::kM / Mma::Shape::kM; m++){ - Array tmp; - Array *ptr_tmp = &tmp; - ptr_tmp[0] = ptr_D[m*Shape::kN/2 + n]; + Array tmp { ptr_D[m*Shape::kN/2 + n] }; Array tmp_B; tmp_B[0] = ptr_B->at(2*n*Shape::kK + k); @@ -631,7 +618,7 @@ struct Mma_HFMA2 < tmp_B, tmp); - ptr_D[m*Shape::kN/2 + n] = ptr_tmp[0]; + ptr_D[m*Shape::kN/2 + n] = tmp; } } } @@ -713,9 +700,7 @@ struct Mma_HFMA2 < CUTLASS_PRAGMA_UNROLL for(auto n=0; n < Shape::kN / Mma::Shape::kN; n++){ - Array tmp; - Array *ptr_tmp = &tmp; - ptr_tmp[0] = ptr_D[n*Shape::kM/2 + m]; + Array tmp { ptr_D[n*Shape::kM/2 + m] }; Array tmp_A; tmp_A[0] = ptr_A->at(2*m*Shape::kK + k); @@ -727,7 +712,7 @@ struct Mma_HFMA2 < ptr_B[k*Shape::kN + n], tmp); - ptr_D[n*Shape::kM/2 + m] = ptr_tmp[0]; + ptr_D[n*Shape::kM/2 + m] = tmp; } } } @@ -810,9 +795,7 @@ struct Mma_HFMA2< CUTLASS_PRAGMA_UNROLL for(auto m=0; m < Shape::kM / Mma::Shape::kM; m++){ - Array tmp; - Array *ptr_tmp = &tmp; - ptr_tmp[0] = ptr_D[m*Shape::kN/2 + n]; + Array tmp { ptr_D[m*Shape::kN/2 + n] }; mma( tmp, @@ -820,7 +803,7 @@ struct Mma_HFMA2< ptr_B[k*Shape::kN/2 + n], tmp); - 
ptr_D[m*Shape::kN/2 + n] = ptr_tmp[0]; + ptr_D[m*Shape::kN/2 + n] = tmp; } } } diff --git a/include/cutlass/gemm/threadblock/threadblock_swizzle_streamk.h b/include/cutlass/gemm/threadblock/threadblock_swizzle_streamk.h index 596fe5a403..b79e587d7c 100644 --- a/include/cutlass/gemm/threadblock/threadblock_swizzle_streamk.h +++ b/include/cutlass/gemm/threadblock/threadblock_swizzle_streamk.h @@ -32,16 +32,6 @@ \brief Implements streamk threadblock mapping blockIdx to GEMM problems. */ -/* - Note: CUTLASS 3x increases the host compiler requirements to C++17. However, certain - existing integrations of CUTLASS require C++11 host compilers. - - Until this requirement can be lifted, certain headers with this annotation are required - to be remain consistent with C++11 syntax. - - C++11 compatibility is enforced by `cutlass_test_unit_core_cpp11`. -*/ - #pragma once #include "cutlass/cutlass.h" diff --git a/include/cutlass/half.h b/include/cutlass/half.h index e22c8be36e..c203e6cb07 100644 --- a/include/cutlass/half.h +++ b/include/cutlass/half.h @@ -34,16 +34,6 @@ device code. */ -/* - Note: CUTLASS 3x increases the host compiler requirements to C++17. However, certain - existing integrations of CUTLASS require C++11 host compilers. - - Until this requirement can be lifted, certain headers with this annotation are required - to be remain consistent with C++11 syntax. - - C++11 compatibility is enforced by `cutlass_test_unit_core_cpp11`. -*/ - #pragma once #ifndef CUTLASS_ENABLE_F16C diff --git a/include/cutlass/integer_subbyte.h b/include/cutlass/integer_subbyte.h index e06a0491ce..1a9728e7ab 100644 --- a/include/cutlass/integer_subbyte.h +++ b/include/cutlass/integer_subbyte.h @@ -34,16 +34,6 @@ device code. */ -/* - Note: CUTLASS 3x increases the host compiler requirements to C++17. However, certain - existing integrations of CUTLASS require C++11 host compilers. - - Until this requirement can be lifted, certain headers with this annotation are required - to be remain consistent with C++11 syntax. - - C++11 compatibility is enforced by `cutlass_test_unit_core_cpp11`. -*/ - #pragma once #if defined(__CUDACC_RTC__) diff --git a/include/cutlass/kernel_hardware_info.h b/include/cutlass/kernel_hardware_info.h index b69399ff0a..62dcb8b451 100644 --- a/include/cutlass/kernel_hardware_info.h +++ b/include/cutlass/kernel_hardware_info.h @@ -30,16 +30,6 @@ **************************************************************************************************/ #pragma once -/* - Note: CUTLASS 3x increases the host compiler requirements to C++17. However, certain - existing integrations of CUTLASS require C++11 host compilers. - - Until this requirement can be lifted, certain headers with this annotation are required - to be remain consistent with C++11 syntax. - - C++11 compatibility is enforced by this unit test: `cutlass_test_unit_core_cpp11`. -*/ - #if !defined(__CUDACC_RTC__) #include "cuda_runtime.h" diff --git a/include/cutlass/layout/matrix.h b/include/cutlass/layout/matrix.h index 8ece43784d..32aa17a5df 100644 --- a/include/cutlass/layout/matrix.h +++ b/include/cutlass/layout/matrix.h @@ -38,16 +38,6 @@ defined in cutlass/tensor_ref.h. */ -/* - Note: CUTLASS 3x increases the host compiler requirements to C++17. However, certain - existing integrations of CUTLASS require C++11 host compilers. - - Until this requirement can be lifted, certain headers with this annotation are required - to be remain consistent with C++11 syntax. 
- - C++11 compatibility is enforced by this unit test: `cutlass_test_unit_core_cpp11`. -*/ - #pragma once #include "cutlass/cutlass.h" diff --git a/include/cutlass/layout/pitch_linear.h b/include/cutlass/layout/pitch_linear.h index 4063a988ef..8c9540f408 100644 --- a/include/cutlass/layout/pitch_linear.h +++ b/include/cutlass/layout/pitch_linear.h @@ -32,16 +32,6 @@ \brief Defines layout functions used by TensorRef and derived classes for pitch-linear memory. */ -/* - Note: CUTLASS 3x increases the host compiler requirements to C++17. However, certain - existing integrations of CUTLASS require C++11 host compilers. - - Until this requirement can be lifted, certain headers with this annotation are required - to be remain consistent with C++11 syntax. - - C++11 compatibility is enforced by this unit test: `cutlass_test_unit_core_cpp11`. -*/ - #pragma once #include "cutlass/cutlass.h" diff --git a/include/cutlass/numeric_conversion.h b/include/cutlass/numeric_conversion.h index 4cef83337b..2a3a09549c 100644 --- a/include/cutlass/numeric_conversion.h +++ b/include/cutlass/numeric_conversion.h @@ -33,16 +33,6 @@ \brief Boost-like numeric conversion operator for CUTLASS numeric types */ -/* - Note: CUTLASS 3x increases the host compiler requirements to C++17. However, certain - existing integrations of CUTLASS require C++11 host compilers. - - Until this requirement can be lifted, certain headers with this annotation are required - to be remain consistent with C++11 syntax. - - C++11 compatibility is enforced by this unit test: `cutlass_test_unit_core_cpp11`. -*/ - #pragma once #if !defined(__CUDACC_RTC__) @@ -848,18 +838,21 @@ struct NumericArrayConverter result; - #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) + Array result; reinterpret_cast<__half2 &>(result) = __float22half2_rn(reinterpret_cast(source)); + return result; #else NumericConverter convert_; + // NOTE: cutlass::Array is NOT an aggregate type and + // below `{}` does NOT conduct zero initialization. Below `{}` will + // conduct default initialization (calling default ctr). We use this syntax + // to resolve compiler warning on uninitialized member variable. + Array result{}; result[0] = convert_(source[0]); result[1] = convert_(source[1]); + return result; #endif - - return result; } CUTLASS_HOST_DEVICE @@ -879,17 +872,19 @@ struct NumericArrayConverter { CUTLASS_HOST_DEVICE static result_type convert(source_type const & source) { - Array result; - #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) - reinterpret_cast(result) = __half22float2(reinterpret_cast<__half2 const &>(source)); + float2 result2 = __half22float2(reinterpret_cast<__half2 const &>(source)); + return { + float{result2.x}, + float{result2.y} + }; #else NumericConverter convert_; - result[0] = convert_(source[0]); - result[1] = convert_(source[1]); + return { + convert_(source[0]), + convert_(source[1]) + }; #endif - - return result; } CUTLASS_HOST_DEVICE @@ -1482,7 +1477,7 @@ struct NumericArrayConverterPacked4Element { for (int i = 0; i < 4; ++i) { if (platform::is_same::value) { result[i] = convert_(s[i]); - } + } else { // conjugate result[i] = conj(convert_(s[i])); } @@ -2306,34 +2301,59 @@ template < struct NumericArrayConverter : public PackedNumericArrayConverter {}; - - ///////////////////////////////////////////////////////////////////////////////////////////////// - /// Partial specialization for Array <= Array /// Conversion is performed with saturation regardless of setting of /// the `Round` template parameter. 
template < + FloatRoundStyle Round +> +struct NumericArrayConverter { + + using result_type = Array; + using source_type = Array; + static FloatRoundStyle const round_style = Round; + + CUTLASS_HOST_DEVICE + static result_type convert(source_type const & source) { + // Convert to int to int8_t + NumericConverter destination_converter; + result_type result; + result[0] = destination_converter(source[0]); + return result; + } + + CUTLASS_HOST_DEVICE + result_type operator()(source_type const &s) const { + return convert(s); + } +}; + +// To convert a FP32 to Int that has less than 32 bits, we need to convert it to int32 first. +template < + typename T, int N, FloatRoundStyle Round > -struct NumericArrayConverter { +struct NumericArrayFP32ToIntConverter { - using result_type = Array; + using result_type = Array; using source_type = Array; static FloatRoundStyle const round_style = Round; + static_assert(platform::numeric_limits::is_integer, "the dest type has to be int."); + CUTLASS_HOST_DEVICE static result_type convert(source_type const & source) { // Convert float to int Array temporary; - NumericArrayConverter compute_converter; + NumericArrayConverter compute_converter; temporary = compute_converter(source); // Convert to int to int8_t - NumericArrayConverter destination_converter; + NumericArrayConverter destination_converter; return destination_converter(temporary); } @@ -2343,6 +2363,91 @@ struct NumericArrayConverter { } }; + +template < + int N, + FloatRoundStyle Round +> +struct NumericArrayConverter { + + using result_type = Array; + using source_type = Array; + + CUTLASS_HOST_DEVICE + static result_type convert(source_type const & source) { + NumericArrayFP32ToIntConverter converter; + return converter(source); + } + + CUTLASS_HOST_DEVICE + result_type operator()(source_type const &s) const { + return convert(s); + } +}; + +template < + int N, + FloatRoundStyle Round +> +struct NumericArrayConverter { + + using result_type = Array; + using source_type = Array; + + CUTLASS_HOST_DEVICE + static result_type convert(source_type const & source) { + NumericArrayFP32ToIntConverter converter; + return converter(source); + } + + CUTLASS_HOST_DEVICE + result_type operator()(source_type const &s) const { + return convert(s); + } +}; + +template < + int N, + FloatRoundStyle Round +> +struct NumericArrayConverter { + + using result_type = Array; + using source_type = Array; + + CUTLASS_HOST_DEVICE + static result_type convert(source_type const & source) { + NumericArrayFP32ToIntConverter converter; + return converter(source); + } + + CUTLASS_HOST_DEVICE + result_type operator()(source_type const &s) const { + return convert(s); + } +}; + +template < + int N, + FloatRoundStyle Round +> +struct NumericArrayConverter { + + using result_type = Array; + using source_type = Array; + + CUTLASS_HOST_DEVICE + static result_type convert(source_type const & source) { + NumericArrayFP32ToIntConverter converter; + return converter(source); + } + + CUTLASS_HOST_DEVICE + result_type operator()(source_type const &s) const { + return convert(s); + } +}; + ///////////////////////////////////////////////////////////////////////////////////////////////// #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750) && \ @@ -2508,7 +2613,7 @@ namespace detail { template CUTLASS_DEVICE static void convert_helper( - typename ArrayConverter::result_type& result, + typename ArrayConverter::result_type& result, typename ArrayConverter::source_type const& source) { using ElementRes = typename 
ArrayConverter::result_type::Element; @@ -2530,14 +2635,14 @@ namespace detail { static void convert_helper(typename ArrayConverter::result_type& result, typename ArrayConverter::source_type const& source) { static_assert(sizeof...(OtherVectorArrays) % 2 == 0, "Vector converters must come in {dst, src} pairs"); static_assert(ResultVectorArray::kElements == SourceVectorArray::kElements, "Vector converters must have the same vector width"); - static_assert(cutlass::platform::is_same::value, + static_assert(cutlass::platform::is_same::value, "ResultVectorArray must have the same type ArrayConverter::result_type"); - static_assert(cutlass::platform::is_same::value, + static_assert(cutlass::platform::is_same::value, "SourceVectorArray must have the same type ArrayConverter::result_type"); static_assert(Offset >= 0 && Offset <= ArrayConverter::result_type::kElements, "Offset must be between 0 and N"); static_assert(ParentWidth == 0 || ParentWidth > ResultVectorArray::kElements, "Vector arrays must be given in decreasing order of width"); - + constexpr int vector_width = ResultVectorArray::kElements; static_assert(ispow2(vector_width), "Vector width must be a power of 2"); @@ -2569,8 +2674,8 @@ namespace detail { public: /* - A method to convert vectors of elements using the packed_convert method of the converter. - + A method to convert vectors of elements using the packed_convert method of the converter. + Converters using this class must implement packed convert and support 1 or more vector conversions. */ template @@ -2651,7 +2756,7 @@ struct NumericArrayConverter uint32_t final_prmt_idx = final_prmt_base | sign; // This uses a look up table to convert packed int4s to packed fp8s, using the int4 value - // as the index to prmt. + // as the index to prmt. // It first select both the positive and negative candidates, then uses the sign bit to // select the correct candidate. asm volatile( @@ -2675,8 +2780,8 @@ struct NumericArrayConverter static result_type convert(source_type const &source) { result_type result; using ConverterType = NumericArrayConverter; - detail::VectorizedConverter::convert(result, source); return result; @@ -2684,7 +2789,7 @@ struct NumericArrayConverter CUTLASS_DEVICE - result_type operator()(source_type const &s) const { + result_type operator()(source_type const &s) const { return convert(s); } }; @@ -2771,7 +2876,7 @@ struct NumericArrayConverter { (platform::is_same::value && platform::is_same::value), "Invalid PackedSrcType/PackedResultType must be 1, 2, 4 or 8 to use private convert dispatch."); - + // Hold output FP16s in reg. 
We need 1 reg for every 2 elements PackedResultType r; @@ -2788,23 +2893,23 @@ struct NumericArrayConverter { return r; } - friend class detail::VectorizedConverter; + friend class detail::VectorizedConverter; public: CUTLASS_DEVICE static result_type convert(source_type const &source) { result_type result; using ConverterType = NumericArrayConverter; - detail::VectorizedConverter::convert(result, source); return result; } CUTLASS_DEVICE - result_type operator()(source_type const &s) const { + result_type operator()(source_type const &s) const { return convert(s); } }; @@ -2844,7 +2949,7 @@ struct NumericArrayConverter { (platform::is_same::value && platform::is_same::value), "Invalid PackedSrcType/PackedResultType must be 2 or 4 to use private convert dispatch."); - + PackedResultType r; // View the input as reg uint32_t src_reg = to_reg(source); @@ -2875,15 +2980,15 @@ struct NumericArrayConverter { result_type result; using ConverterType = NumericArrayConverter; - detail::VectorizedConverter::convert(result, source); return result; } CUTLASS_DEVICE - result_type operator()(source_type const &s) const { + result_type operator()(source_type const &s) const { return convert(s); } }; @@ -2923,12 +3028,12 @@ struct NumericArrayConverter { (platform::is_same::value && platform::is_same::value), "Invalid PackedSrcType/PackedResultType must be 2 or 4 to use private convert dispatch."); - + PackedResultType r; // View the input as reg uint32_t src_reg = to_reg(source); - // __byte_perm simulates the add.u32 0x4B000000 to every u8 element of u8x4 source and stores + // __byte_perm simulates the add.u32 0x4B000000 to every u8 element of u8x4 source and stores // the result in r (without introducing extra cvt.u32.u8 instruction) uint32_t const prmt_indices[4] = {0x7650, 0x7651, 0x7652, 0x7653}; uint32_t* result_as_int = reinterpret_cast(&r); @@ -2948,15 +3053,15 @@ struct NumericArrayConverter { static result_type convert(source_type const &source) { result_type result; using ConverterType = NumericArrayConverter; - detail::VectorizedConverter::convert(result, source); return result; } CUTLASS_DEVICE - result_type operator()(source_type const &s) const { + result_type operator()(source_type const &s) const { return convert(s); } }; @@ -3010,10 +3115,10 @@ struct NumericArrayConverter { (platform::is_same::value && platform::is_same::value), "Invalid PackedSrcType/PackedResultType must be 2, 4 or 8 to use private convert dispatch."); - + // Hold output FP16s in reg. 
We need 1 reg for every 2 elements using RegArray = cutlass::AlignedArray; - RegArray r; + RegArray r; // View the input as reg uint32_t src_reg = to_reg(source); @@ -3034,7 +3139,7 @@ struct NumericArrayConverter { " prmt.b32 %0, %1, %2, %3;\n" "}\n" : "=r"(r[ii]) - : "r"(src_reg), "n"(0), "r"(prmt_indices[ii])); + : "r"(src_reg), "n"(0), "r"(prmt_indices[ii])); } // The below XOR does the following: @@ -3057,7 +3162,7 @@ struct NumericArrayConverter { " lop3.b32 %0, %0, %1, %2, %3;\n" "}\n" : "+r"(r[ii]) - : "n"(and_mask), "n"(xor_mask), "n"(immLut)); + : "n"(and_mask), "n"(xor_mask), "n"(immLut)); } // We will issue 2 hfmas that do the following: @@ -3087,23 +3192,23 @@ struct NumericArrayConverter { return reinterpret_cast(r); } - friend class detail::VectorizedConverter; + friend class detail::VectorizedConverter; public: CUTLASS_DEVICE static result_type convert(source_type const &source) { result_type result; using ConverterType = NumericArrayConverter; - detail::VectorizedConverter::convert(result, source); return result; } CUTLASS_DEVICE - result_type operator()(source_type const &s) const { + result_type operator()(source_type const &s) const { return convert(s); } }; @@ -3145,10 +3250,10 @@ struct NumericArrayConverter { (platform::is_same::value && platform::is_same::value), "Invalid PackedSrcType/PackedResultType must be 2 or 4 to use private convert dispatch."); - + // Hold output FP16s in reg. We need 1 reg for every 2 elements using RegArray = cutlass::AlignedArray; - RegArray r; + RegArray r; #if 0 // Scalar conversion (Please keep this code for reference for vectorized version below) auto result = reinterpret_cast(r); @@ -3176,18 +3281,18 @@ struct NumericArrayConverter { // In the absense of add.s16x2 instruction, use bit-wise operation to execute signed addition with magic numbers to achieve // the same result as add.s16x2 instruction. // (See https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-lop3) - // For a logical operation F(a, b, c) the value of kImmLut can be computed by applying the same operation to + // For a logical operation F(a, b, c) the value of kImmLut can be computed by applying the same operation to // three predefined constant values as follows: // ta = 0xF0; // tb = 0xCC; // tc = 0xAA; // kImmLut = F(ta, tb, tc); - // If we want F = ((a & b) ^ c) then set kImmLut = (0xF0 & 0xCC) ^ 0xAA - static constexpr uint32_t kImmLut = (0xF0 & 0xCC) ^ 0xAA; + // If we want F = ((a & b) ^ c) then set kImmLut = (0xF0 & 0xCC) ^ 0xAA + static constexpr uint32_t kImmLut = (0xF0 & 0xCC) ^ 0xAA; for (int ii = 0; ii < RegArray::kElements; ++ii) { // The bit-wise operation executed below is `r[ii] = (r[ii] & 0x03FF03FF) ^ 0x66006600;` - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" : + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(r[ii]) : "r"(r[ii]), "n"(0x03FF03FF), "n"(0x66006600), "n"(kImmLut)); } @@ -3209,14 +3314,14 @@ struct NumericArrayConverter { result_type result; using ConverterType = NumericArrayConverter; - detail::VectorizedConverter::convert(result, source); return result; } CUTLASS_DEVICE - result_type operator()(source_type const &s) const { + result_type operator()(source_type const &s) const { return convert(s); } }; @@ -3256,11 +3361,11 @@ struct NumericArrayConverter { (platform::is_same::value && platform::is_same::value), "Invalid PackedSrcType/PackedResultType must be 2 or 4 to use private convert dispatch."); - + // Hold output FP16s in reg. 
We need 1 reg for every 2 elements using RegArray = cutlass::AlignedArray; - RegArray r; - + RegArray r; + // View the input as reg uint32_t src_reg = to_reg(source); uint32_t const prmt_indices[2] = {0x5150, 0x5352}; @@ -3289,15 +3394,15 @@ struct NumericArrayConverter { result_type result; using ConverterType = NumericArrayConverter; - detail::VectorizedConverter::convert(result, source); return result; } CUTLASS_DEVICE - result_type operator()(source_type const &s) const { + result_type operator()(source_type const &s) const { return convert(s); } }; @@ -3352,10 +3457,10 @@ struct NumericArrayConverter { (platform::is_same::value && platform::is_same::value), "Invalid PackedSrcType/PackedResultType must be 2, 4 or 8 to use private convert dispatch."); - + // Hold output FP16s in reg. We need 1 reg for every 2 elements using RegArray = cutlass::AlignedArray; - RegArray r; + RegArray r; // View the input as reg uint32_t src_reg = to_reg(source); @@ -3371,7 +3476,7 @@ struct NumericArrayConverter { " prmt.b32 %0, %1, %2, %3;\n" "}\n" : "=r"(r[ii]) - : "r"(src_reg), "r"(src_reg_shifted), "r"(prmt_indices[ii])); + : "r"(src_reg), "r"(src_reg_shifted), "r"(prmt_indices[ii])); } // The below XOR does the following: @@ -3390,7 +3495,7 @@ struct NumericArrayConverter { " lop3.b32 %0, %0, %1, %2, %3;\n" "}\n" : "+r"(r[ii]) - : "n"(and_mask), "n"(xor_mask), "n"(immLut)); + : "n"(and_mask), "n"(xor_mask), "n"(immLut)); } // We will issue 2 bfmas that do the following: @@ -3400,7 +3505,7 @@ struct NumericArrayConverter { // This is the BF16 {136, 136} represented as an integer. static constexpr uint32_t bias_rep = 0x43084308; const __nv_bfloat162& bias = reinterpret_cast(bias_rep); - + CUTLASS_PRAGMA_UNROLL for (int ii = 0; ii < RegArray::kElements; ++ii) { __nv_bfloat162& bf16x2_val = reinterpret_cast<__nv_bfloat162&>(r[ii]); @@ -3410,23 +3515,23 @@ struct NumericArrayConverter { return reinterpret_cast(r); } - friend class detail::VectorizedConverter; + friend class detail::VectorizedConverter; public: CUTLASS_DEVICE static result_type convert(source_type const &source) { result_type result; using ConverterType = NumericArrayConverter; - detail::VectorizedConverter::convert(result, source); return result; } CUTLASS_DEVICE - result_type operator()(source_type const &s) const { + result_type operator()(source_type const &s) const { return convert(s); } }; @@ -3466,7 +3571,7 @@ struct NumericArrayConverter { (platform::is_same::value && platform::is_same::value), "Invalid PackedSrcType/PackedResultType must be 2 or 4 to use private convert dispatch."); - + NumericArrayConverter convert_int8_to_f32; Array tmp = convert_int8_to_f32(source); NumericArrayConverter convert_f32_to_bf16; @@ -3481,15 +3586,15 @@ struct NumericArrayConverter { result_type result; using ConverterType = NumericArrayConverter; - detail::VectorizedConverter::convert(result, source); return result; } CUTLASS_DEVICE - result_type operator()(source_type const &s) const { + result_type operator()(source_type const &s) const { return convert(s); } }; @@ -3529,7 +3634,7 @@ struct NumericArrayConverter { (platform::is_same::value && platform::is_same::value), "Invalid PackedSrcType/PackedResultType must be 2 or 4 to use private convert dispatch."); - + NumericArrayConverter convert_uint8_to_f32; Array tmp = convert_uint8_to_f32(source); NumericArrayConverter convert_f32_to_bf16_; @@ -3543,15 +3648,15 @@ struct NumericArrayConverter { static result_type convert(source_type const &source) { result_type result; using ConverterType = 
NumericArrayConverter; - detail::VectorizedConverter::convert(result, source); return result; } CUTLASS_DEVICE - result_type operator()(source_type const &s) const { + result_type operator()(source_type const &s) const { return convert(s); } }; @@ -3705,7 +3810,7 @@ struct PackPredicates { int word_idx = (i / kWordSize); int bit_idx = (i % kWordSize); - uint8_t mask = ((predicates[i] ? 1u : 0u) << bit_idx); + uint8_t mask = static_cast((predicates[i] ? 1u : 0u) << bit_idx); bytes[word_idx] = (bytes[word_idx] | mask); } return packed; diff --git a/include/cutlass/numeric_size.h b/include/cutlass/numeric_size.h index 46f343aa0f..42bc418a40 100644 --- a/include/cutlass/numeric_size.h +++ b/include/cutlass/numeric_size.h @@ -33,16 +33,6 @@ \brief Top-level include for all CUTLASS numeric types. */ -/* - Note: CUTLASS 3x increases the host compiler requirements to C++17. However, certain - existing integrations of CUTLASS require C++11 host compilers. - - Until this requirement can be lifted, certain headers with this annotation are required - to be remain consistent with C++11 syntax. - - C++11 compatibility is enforced by `cutlass_test_unit_core_cpp11`. -*/ - #pragma once #include "cutlass/cutlass.h" diff --git a/include/cutlass/numeric_types.h b/include/cutlass/numeric_types.h index cb7c20872c..5519fbe7c9 100644 --- a/include/cutlass/numeric_types.h +++ b/include/cutlass/numeric_types.h @@ -32,15 +32,6 @@ \file \brief Top-level include for all CUTLASS numeric types. */ -/* - Note: CUTLASS 3x increases the host compiler requirements to C++17. However, certain - existing integrations of CUTLASS require C++11 host compilers. - - Until this requirement can be lifted, certain headers with this annotation are required - to be remain consistent with C++11 syntax. - - C++11 compatibility is enforced by `cutlass_test_unit_core_cpp11`. -*/ #pragma once #include "cutlass/cutlass.h" diff --git a/include/cutlass/pipeline/sm90_pipeline.hpp b/include/cutlass/pipeline/sm90_pipeline.hpp index b7d9e09325..2ab7ae0455 100644 --- a/include/cutlass/pipeline/sm90_pipeline.hpp +++ b/include/cutlass/pipeline/sm90_pipeline.hpp @@ -264,7 +264,6 @@ public : dim3 block_id = cute::block_id_in_cluster(); auto cluster_size = cute::size(cluster_shape); static constexpr int MaxClusterSize = 16; - static_assert(cluster_size <= MaxClusterSize, "ERROR : Cluster size too large !" ); // STEP 1 : Use Cute Layout function to generate an optimal dst block-id (0-15) if (params_.num_consumers % NumThreadsPerWarpGroup == 0) { diff --git a/include/cutlass/platform/platform.h b/include/cutlass/platform/platform.h index 927f80cbd9..ba74ae723b 100644 --- a/include/cutlass/platform/platform.h +++ b/include/cutlass/platform/platform.h @@ -95,16 +95,6 @@ * counterparts (or trivially find-and-replace their occurrences in code text). */ -/* - Note: CUTLASS 3x increases the host compiler requirements to C++17. However, certain - existing integrations of CUTLASS require C++11 host compilers. - - Until this requirement can be lifted, certain headers with this annotation are required - to be remain consistent with C++11 syntax. - - C++11 compatibility is enforced by `cutlass_test_unit_core_cpp11`. 
-*/ - //----------------------------------------------------------------------------- // Dependencies //----------------------------------------------------------------------------- @@ -159,7 +149,7 @@ /// builtin_unreachable #if !defined(CUTLASS_GCC_UNREACHABLE) -# if defined(__clang__) || defined(__GNUC__) +# if defined(__GNUC__) # define CUTLASS_GCC_UNREACHABLE __builtin_unreachable() # else # define CUTLASS_GCC_UNREACHABLE @@ -950,7 +940,6 @@ struct numeric_limits { static constexpr bool is_integer = true; }; -#if !defined(__CUDACC_RTC__) template <> struct numeric_limits { CUTLASS_HOST_DEVICE @@ -958,7 +947,6 @@ struct numeric_limits { static constexpr bool is_integer = false; static constexpr bool has_infinity = true; }; -#endif /// std::float_round_style using CUTLASS_STL_NAMESPACE::float_round_style; diff --git a/include/cutlass/transform/threadblock/predicated_tile_access_iterator_params.h b/include/cutlass/transform/threadblock/predicated_tile_access_iterator_params.h index 22699a288a..c67af387e5 100755 --- a/include/cutlass/transform/threadblock/predicated_tile_access_iterator_params.h +++ b/include/cutlass/transform/threadblock/predicated_tile_access_iterator_params.h @@ -32,16 +32,6 @@ \brief */ -/* - Note: CUTLASS 3x increases the host compiler requirements to C++17. However, certain - existing integrations of CUTLASS require C++11 host compilers. - - Until this requirement can be lifted, certain headers with this annotation are required - to be remain consistent with C++11 syntax. - - C++11 compatibility is enforced by this unit test: `cutlass_test_unit_core_cpp11`. -*/ - #pragma once #include "cutlass/cutlass.h" diff --git a/include/cutlass/uint128.h b/include/cutlass/uint128.h index d1e6e64642..0a41e95cd3 100644 --- a/include/cutlass/uint128.h +++ b/include/cutlass/uint128.h @@ -32,15 +32,6 @@ \file \brief Defines an unsigned 128b integer with several operators to support 64-bit integer division. */ -/* - Note: CUTLASS 3x increases the host compiler requirements to C++17. However, certain - existing integrations of CUTLASS require C++11 host compilers. - - Until this requirement can be lifted, certain headers with this annotation are required - to be remain consistent with C++11 syntax. - - C++11 compatibility is enforced by `cutlass_test_unit_core_cpp11`. -*/ #pragma once #if defined(__CUDACC_RTC__) diff --git a/include/cutlass/workspace.h b/include/cutlass/workspace.h index 82ff77c1fc..6dc0141cfb 100644 --- a/include/cutlass/workspace.h +++ b/include/cutlass/workspace.h @@ -32,16 +32,6 @@ \brief Utilities for initializing workspaces */ -/* - Note: CUTLASS 3x increases the host compiler requirements to C++17. However, certain - existing integrations of CUTLASS require C++11 host compilers. - - Until this requirement can be lifted, certain headers with this annotation are required - to be remain consistent with C++11 syntax. - - C++11 compatibility is enforced by this unit test: `cutlass_test_unit_core_cpp11`. -*/ - #pragma once #if !defined(__CUDACC_RTC__) diff --git a/media/docs/build/building_with_clang_as_host_compiler.md b/media/docs/build/building_with_clang_as_host_compiler.md index d14eb20e78..54b2c78e1f 100644 --- a/media/docs/build/building_with_clang_as_host_compiler.md +++ b/media/docs/build/building_with_clang_as_host_compiler.md @@ -36,17 +36,17 @@ is the following error when attempting to use clang: ## Required CMake options The Clang build requires specifying the following CMake options. 
-Replace `` with the path to your `clang++` executable, -and replace `` with the path to your `clang` executable -(which must have the same version as your `clang++` executable). -You may use `clang++` resp. `clang` directly if they are in your `PATH`. +Replace `` with the path to your `clang++` executable. +You may use `clang++` directly if it is in your `PATH`. * `CMAKE_CXX_COMPILER=` * `CMAKE_CUDA_HOST_COMPILER=` -* `CMAKE_C_COMPILER=` -Please note that both `CMAKE_CXX_COMPILER` and `CMAKE_C_COMPILER` -must be set, even though CUTLASS is a C++ project, not a C project. +One must set both! It's not enough just to set the `CXX` environment +variable, for example. Symptoms of only setting `CMAKE_CXX_COMPILER` +(or only setting the `CXX` environment variable) include `cc1plus` +(GCC's compiler executable) reporting build errors due to it not +understanding Clang's command-line options. Users can also specify a particular CUDA Toolkit version by setting the CMake option `CMAKE_CUDA_COMPILER` diff --git a/media/docs/cute/02_layout_algebra.md b/media/docs/cute/02_layout_algebra.md index f9c9d2fb68..3b70252b1d 100644 --- a/media/docs/cute/02_layout_algebra.md +++ b/media/docs/cute/02_layout_algebra.md @@ -317,23 +317,25 @@ The `complement` of a layout attempts to find another layout that represents the You can find many examples and checked post-conditions in [the `complement` unit test](../../../test/unit/cute/core/complement.cpp). The post-conditions include ```cpp -// @post cosize(make_layout(@a layout_a, @a result))) >= @a cosize_hi -// @post cosize(@a result) >= round_up(@a cosize_hi, cosize(@a layout_a)) +// @post cosize(make_layout(@a layout_a, @a result))) >= size(@a cotarget) +// @post cosize(@a result) >= round_up(size(@a cotarget), cosize(@a layout_a)) // @post for all i, 1 <= i < size(@a result), // @a result(i-1) < @a result(i) // @post for all i, 1 <= i < size(@a result), // for all j, 0 <= j < size(@a layout_a), // @a result(i) != @a layout_a(j) -Layout complement(LayoutA const& layout_a, Integral const& cosize_hi) +Layout complement(LayoutA const& layout_a, Shape const& cotarget) ``` -That is, the complement `R` of a layout `A` with respect to an integer `M` satisfies the following properties. -1. The size (and cosize) of `R` is *bounded* by `M`. +That is, the complement `R` of a layout `A` with respect to a Shape (IntTuple) `M` satisfies the following properties. +1. The size (and cosize) of `R` is *bounded* by `size(M)`. 2. `R` is *ordered*. That is, the strides of `R` are positive and increasing. This means that `R` is unique. 3. `A` and `R` have *disjoint* codomains. `R` attempts to "complete" the codomain of `A`. +The `cotarget` parameter above is most commonly an integer -- you can see we only use `size(cotarget)` above. However, sometimes it is useful to specify an integer that has static properties. For example, `28` is a dynamic integer and `(_4,7)` is a shape with size `28` that is statically known to be divisible by `_4`. Both will produce the same `complement` mathematically, but the extra information can used by `complement` to preserve the staticness of the result as much as possible. + ### Complement Examples -`complement` is most effective on static shapes and strides, so consider all integers below to be static. Similar examples for dynamic shapes and strides can be found in the unit test. +`complement` is most effective on static shapes and strides, so consider all integers below to be static. 
Similar examples for dynamic shapes and strides as well as IntTuple `cotarget` can be found in [the unit test](../../../test/unit/cute/core/complement.cpp). * `complement(4:1, 24)` is `6:4`. Note that `(4,6):(1,4)` has cosize `24`. The layout `4:1` is effectively repeated 6 times with `6:4`. @@ -425,9 +427,9 @@ Layout Shape : (M, N, L, ...) Tiler Shape : logical_divide : ((TileM,RestM), (TileN,RestN), L, ...) -zipped_divide : ((TileM,TileN,...), (RestM,RestN,L,...)) -tiled_divide : ((TileM,TileN,...), RestM, RestN, L, ...) -flat_divide : (TileM, TileN, ..., RestM, RestN, L, ...) +zipped_divide : ((TileM,TileN), (RestM,RestN,L,...)) +tiled_divide : ((TileM,TileN), RestM, RestN, L, ...) +flat_divide : (TileM, TileN, RestM, RestN, L, ...) ``` For example, the `zipped_divide` function applies `logical_divide`, and then gathers the "subtiles" into a single mode and the "rest" into a single mode. diff --git a/media/docs/fundamental_types.md b/media/docs/fundamental_types.md index 8e17ad57b6..8bef0702f6 100644 --- a/media/docs/fundamental_types.md +++ b/media/docs/fundamental_types.md @@ -63,13 +63,12 @@ template < typename T, // element type int N // number of elements > -class Array; +struct Array; ``` `Array` defines a statically sized array of elements of type _T_ and size _N_. This class is similar to -[`std::array<>`](https://en.cppreference.com/w/cpp/container/array) in the Standard Library with two notable exceptions: -* constructors for each element may not be called -* partial specializations exist to pack or unpack elements smaller than one byte. +[`std::array<>`](https://en.cppreference.com/w/cpp/container/array) in the Standard Library with one notable exception: +partial specializations exist to pack or unpack elements smaller than one byte. `Array<>` is intended to be a convenient and uniform container class to store arrays of numeric elements regardless of data type or vector length. The storage needed is expected to be the minimum necessary given the logical size of each numeric type in bits (numeric types smaller than one byte are densely packed). Nevertheless, the size reported by `sizeof(Array)` is always an integer multiple of bytes. 
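
The `Array<>` documentation hunk above describes dense packing of sub-byte element types while `sizeof` stays a whole number of bytes. A minimal host-only sketch (not part of this patch) may help illustrate that behavior; the exact byte counts below are assumptions that follow from the documented dense packing for these particular configurations, and the element-access style simply mirrors `std::array`.

```cpp
// Sketch only: illustrates the packing behavior documented for cutlass::Array<>.
// The static_assert sizes are assumptions derived from the described dense
// sub-byte packing, not values taken from this patch.
#include <cstdio>

#include "cutlass/array.h"
#include "cutlass/numeric_types.h"

int main() {
  // Eight 4-bit integers should pack into 32 bits = 4 bytes,
  // while eight half_t elements occupy 2 bytes each = 16 bytes.
  static_assert(sizeof(cutlass::Array<cutlass::int4b_t, 8>) == 4,
                "int4b_t elements are expected to pack densely");
  static_assert(sizeof(cutlass::Array<cutlass::half_t, 8>) == 16,
                "half_t elements occupy two bytes each");

  // Element access mirrors std::array.
  cutlass::Array<cutlass::half_t, 4> a;
  for (int i = 0; i < 4; ++i) {
    a[i] = cutlass::half_t(float(i));
  }
  std::printf("a[3] = %f\n", float(a[3]));
  return 0;
}
```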
diff --git a/media/docs/profiler.md b/media/docs/profiler.md index d318e3d037..34282925db 100644 --- a/media/docs/profiler.md +++ b/media/docs/profiler.md @@ -210,7 +210,6 @@ GEMM [int] --inst_k,--instruction-shape::k Math instruction shape in the K dimension [int] --min_cc,--minimum-compute-capability Minimum device compute capability [int] --max_cc,--maximum-compute-capability Maximum device compute capability - Examples: Profile a particular problem size: diff --git a/python/cutlass/backend/gemm_operation.py b/python/cutlass/backend/gemm_operation.py index 55f4e2f1ef..62ac6c272d 100644 --- a/python/cutlass/backend/gemm_operation.py +++ b/python/cutlass/backend/gemm_operation.py @@ -1654,7 +1654,7 @@ def extended_name_3x(self): extended_name = "{core_name}_{element_a}_{element_b}_{element_acc}_{element_c}_{element_d}".format( element_a=DataTypeNames[self.A.element], element_b=DataTypeNames[self.B.element], - element_acc=DataTypeNames[self.tile_description.math_instruction.element_accumulator], + element_acc=DataTypeNames[self.accumulator_type()], element_c=DataTypeNames[self.C.element], element_d=DataTypeNames[self.epilogue_functor.element_output], core_name=self.core_name()) diff --git a/python/cutlass/emit/common.py b/python/cutlass/emit/common.py index ef724c04d7..87025eead0 100644 --- a/python/cutlass/emit/common.py +++ b/python/cutlass/emit/common.py @@ -118,16 +118,18 @@ typename DeviceKernel::Arguments arguments{ cutlass::gemm::GemmUniversalMode::kGemm, {M, N, K, L}, // problem size - A, // ptrA - cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)), // stride A - B, // ptrB - cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)), // stride B { + A, // ptrA + cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)), // stride A + B, // ptrB + cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)), // stride B + }, + { + {alpha, beta}, C, // ptrC cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)), // stride C D, // ptrD cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)), // stride D - {alpha, beta}, }, hw_info }; diff --git a/python/cutlass/emit/pytorch.py b/python/cutlass/emit/pytorch.py index 613311ac90..ac13e866fa 100644 --- a/python/cutlass/emit/pytorch.py +++ b/python/cutlass/emit/pytorch.py @@ -232,7 +232,7 @@ #include "cutlass/gemm/collective/collective_builder.hpp" #include "cutlass/gemm/device/gemm_universal_adapter.h" #include "cutlass/gemm/kernel/gemm_universal.hpp" -#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" #include "cutlass/util/packed_stride.hpp" """, } @@ -583,7 +583,11 @@ '${name}_kernel.cu', ], include_dirs=['${cutlass_path}/include', '${cutlass_path}/tools/util/include'], - extra_compile_args=['-std=c++17'] + extra_compile_args={ + 'cxx': ['-std=c++17'], + 'nvcc': ['-std=c++17', ${extra_compile_args}], + }, + libraries=['cuda'] ), ], cmdclass={ @@ -593,7 +597,7 @@ """ -def _generate_setup(name: str, sourcedir: str): +def _generate_setup(name: str, sourcedir: str, extra_compile_args: str=""): """ Generates a setup.py file for the extension @@ -601,10 +605,12 @@ def _generate_setup(name: str, sourcedir: str): :type name: str :param sourcedir: directory to which generated source files should be written :type sourcedir: str + :param extra_compile_args: additional arguments to pass to setup.py + :type extra_args: str """ setup_py_file = os.path.join(sourcedir, "setup.py") setup_source = 
SubstituteTemplate( - _PYTORCH_SETUP_PY, {"name": name, "cutlass_path": CUTLASS_PATH} + _PYTORCH_SETUP_PY, {"name": name, "cutlass_path": CUTLASS_PATH, "extra_compile_args": extra_compile_args} ) with open(setup_py_file, "w") as outfile: outfile.write(setup_source) @@ -696,6 +702,7 @@ def _jit(name: str, cc: int, cpp_file: str, cuda_file: str): os.path.join(CUTLASS_PATH, "include"), os.path.join(CUTLASS_PATH, "tools/util/include"), ], + extra_ldflags=["-lcuda"], verbose=(logger.level == logging.DEBUG) ) return jitmodule @@ -759,7 +766,10 @@ def _pytorch_gemm(op, name: str, cc: int, jit: bool = False, sourcedir: str = "" with open(cpp_file, "w") as outfile: outfile.write(cpp_source) - _generate_setup(name, sourcedir) + extra_compile_args = "" + if cc == 90: + extra_compile_args = "'--generate-code=arch=compute_90a,code=[sm_90a]'" + _generate_setup(name, sourcedir, extra_compile_args) if jit: return _jit(name, cc, cpp_file, cuda_file) diff --git a/python/cutlass/library_defaults.py b/python/cutlass/library_defaults.py index 3ef021f72f..7c16cc6855 100644 --- a/python/cutlass/library_defaults.py +++ b/python/cutlass/library_defaults.py @@ -137,9 +137,9 @@ def operations(self, alignment_A: int, alignment_B: int, alignment_C: int, math_ # Finally, go through all available alignment combinations and find # one for which all values are less than those passed in. key = None - alignments = sorted([(int(x) for x in k.split(" ")) for k in self.kernels_by_alignment.keys()], reverse=True) + alignments = sorted([tuple(int(x) for x in k.split(" ")) for k in self.kernels_by_alignment.keys()], reverse=True) for align_A, align_B, align_C in alignments: - if align_A <= alignment_A and align_B <= alignment_B and align_C <= alignment_C: + if alignment_A % align_A == 0 and alignment_B % align_B == 0 and alignment_C % align_C == 0: key = f"{align_A} {align_B} {align_C}" break diff --git a/python/cutlass/op/gemm.py b/python/cutlass/op/gemm.py index e486a43195..e74c40786f 100644 --- a/python/cutlass/op/gemm.py +++ b/python/cutlass/op/gemm.py @@ -712,4 +712,4 @@ def run(self, A=None, B=None, C=None, D=None, if sync: arguments.sync() - return arguments \ No newline at end of file + return arguments diff --git a/python/cutlass_library/gemm_operation.py b/python/cutlass_library/gemm_operation.py index 9850b68af7..f739c15ab8 100644 --- a/python/cutlass_library/gemm_operation.py +++ b/python/cutlass_library/gemm_operation.py @@ -205,7 +205,7 @@ def extended_name_3x(self): extended_name = "{core_name}_{element_a}_{element_b}_{element_acc}_{element_c}_{element_d}".format( element_a = DataTypeNames[self.A.element], element_b = DataTypeNames[self.B.element], - element_acc = DataTypeNames[self.tile_description.math_instruction.element_accumulator], + element_acc = DataTypeNames[self.accumulator_type()], element_c = DataTypeNames[self.C.element], element_d = DataTypeNames[self.D.element], core_name = self.core_name()) @@ -216,7 +216,7 @@ def datatype_name_3x(self): datatype_name = "{element_a}_{element_b}_{element_acc}_{element_c}_{element_d}".format( element_a = DataTypeNames[self.A.element], element_b = DataTypeNames[self.B.element], - element_acc = DataTypeNames[self.tile_description.math_instruction.element_accumulator], + element_acc = DataTypeNames[self.accumulator_type()], element_c = DataTypeNames[self.C.element], element_d = DataTypeNames[self.D.element]) return datatype_name @@ -744,7 +744,7 @@ def __init__(self, operation_suffix = ''): cute::Shape, cute::Shape<${cluster_shape_m}, ${cluster_shape_n}, 
${cluster_shape_k}>, ${stages}, - ${kernel_schedule} + ${kernel_schedule} >::CollectiveOp; // Gemm operator ${operation_name} @@ -817,8 +817,9 @@ def emit(self, operation): else: epilogue_functor = self.epilogue_functor.emit_declaration() # - element_a = DataTypeTag[operation.A.element] - element_b = DataTypeTag[operation.B.element] + # Cutlass3x complex kernels' ElementA(B) is a tuple in collective mainloop builder, e.g. cute::tuple, Transform : cute::identity / cute::conjugate. + element_a = DataTypeTag[operation.A.element] if not operation.is_complex() else f"cute::tuple<{str(DataTypeTag[operation.A.element])},{str(ComplexTransformTag3x[operation.A.complex_transform])}>" + element_b = DataTypeTag[operation.B.element] if not operation.is_complex() else f"cute::tuple<{str(DataTypeTag[operation.B.element])},{str(ComplexTransformTag3x[operation.B.complex_transform])}>" epilogue_schedule_type = EpilogueScheduleTag[operation.epilogue_schedule] values = { 'operation_name': operation.procedural_name(), diff --git a/python/cutlass_library/generator.py b/python/cutlass_library/generator.py index 9f1045f389..0ac604e74c 100644 --- a/python/cutlass_library/generator.py +++ b/python/cutlass_library/generator.py @@ -967,6 +967,7 @@ def extended_name(self): def configuration_name(self): prefix = 'cutlass3x' + arch = self.arch opcode_class_name = OpcodeClassNames[self.tile_description.math_instruction.opcode_class] tbm = self.tile_description.tile_shape[0] tbn = self.tile_description.tile_shape[1] @@ -979,7 +980,7 @@ def configuration_name(self): kernel_schedule = KernelScheduleSuffixes[self.kernel_schedule] epilogue_schedule = EpilogueScheduleSuffixes[self.epilogue_schedule] - return f"{prefix}_{opcode_class_name}_{self.extended_name()}_{tbm}x{tbn}x{tbk}_{cm}x{cn}x{ck}_{self.tile_description.stages}_align{alignment}{tile_scheduler}{kernel_schedule}{epilogue_schedule}" + return f"{prefix}_sm{arch}_{opcode_class_name}_{self.extended_name()}_{tbm}x{tbn}x{tbk}_{cm}x{cn}x{ck}_{self.tile_description.stages}_align{alignment}{tile_scheduler}{kernel_schedule}{epilogue_schedule}" def procedural_name(self): return self.configuration_name() diff --git a/python/cutlass_library/library.py b/python/cutlass_library/library.py index a347091c47..710cad31ca 100644 --- a/python/cutlass_library/library.py +++ b/python/cutlass_library/library.py @@ -250,6 +250,12 @@ class ComplexTransform(enum.Enum): ComplexTransform.conj: 'cutlass::ComplexTransform::kConjugate', } +# Used for cutlass3x complex kernel collective mainloop builder instantiation +ComplexTransformTag3x = { + ComplexTransform.none: 'cute::identity', + ComplexTransform.conj: 'cute::conjugate', +} + # RealComplexBijection = [ (DataType.f16, DataType.cf16), diff --git a/test/python/cutlass/emit/pytorch.py b/test/python/cutlass/emit/pytorch.py index ac75dbb565..18388a76f9 100644 --- a/test/python/cutlass/emit/pytorch.py +++ b/test/python/cutlass/emit/pytorch.py @@ -124,7 +124,6 @@ def test_gemm(self): dtype = torch.float16 plan = cutlass.op.Gemm(element=dtype, layout=cutlass.LayoutType.RowMajor) - plan.activation = cutlass.epilogue.relu op = plan.construct() with tempfile.TemporaryDirectory() as tmpdir: @@ -132,7 +131,7 @@ def test_gemm(self): A, B, C, _ = _initialize(dtype, 1024, 256, 512) - D_ref = torch.nn.functional.relu(A @ B) + D_ref = A @ B D = mod.run(A, B) assert torch.allclose(D, D_ref) @@ -147,7 +146,7 @@ def test_gemm(self): alpha = 2.0 beta = -1.0 - D_ref = torch.nn.functional.relu((A @ B) * alpha + (beta * C)) + D_ref = (A @ B) * alpha + (beta * C) D = 
mod.run(A, B, C, alpha, beta) assert torch.allclose(D, D_ref) diff --git a/test/unit/conv/cache_testbed_output.h b/test/unit/conv/cache_testbed_output.h index 8c44302246..4f3981e83b 100644 --- a/test/unit/conv/cache_testbed_output.h +++ b/test/unit/conv/cache_testbed_output.h @@ -122,19 +122,15 @@ inline std::ostream &operator<<(std::ostream &out, CachedTestKey const &result) struct CachedTestResult { uint32_t D; - uint32_t sum; - uint32_t sum_of_square; - uint32_t second_sum_of_square; // // Methods // - CachedTestResult(): D(), sum(), sum_of_square(), second_sum_of_square() { } + CachedTestResult(): D() + { } - CachedTestResult(uint32_t D): D(D), sum(), sum_of_square(), second_sum_of_square() { } - - CachedTestResult(uint32_t D, uint32_t sum, uint32_t sum_of_square, uint32_t second_sum_of_square): - D(D), sum(sum), sum_of_square(sum_of_square), second_sum_of_square(second_sum_of_square) { } + CachedTestResult(uint32_t D): D(D) + { } operator bool() const { return bool(D); @@ -262,6 +258,7 @@ inline char const *EncodeOperator(cutlass::conv::Operator conv_op) { case cutlass::conv::Operator::kFprop: return "fprop"; case cutlass::conv::Operator::kDgrad: return "dgrad"; case cutlass::conv::Operator::kWgrad: return "wgrad"; + case cutlass::conv::Operator::kDeconv: return "deconv"; } return "conv_unknown"; } diff --git a/test/unit/conv/device/CMakeLists.txt b/test/unit/conv/device/CMakeLists.txt index 0ac101ece5..d3a6782f44 100644 --- a/test/unit/conv/device/CMakeLists.txt +++ b/test/unit/conv/device/CMakeLists.txt @@ -140,14 +140,19 @@ if (CUTLASS_NVCC_MAX_ARCH GREATER_EQUAL 80) conv2d_fprop_implicit_gemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32_sm80.cu conv2d_dgrad_implicit_gemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32_sm80.cu conv2d_wgrad_implicit_gemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32_sm80.cu + deconv2d_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.cu conv2d_fprop_with_broadcast_simt_sm80.cu + deconv2d_with_broadcast_simt_sm80.cu conv3d_fprop_implicit_gemm_f32ndhwc_f32ndhwc_f32ndhwc_simt_f32_sm80.cu conv3d_dgrad_implicit_gemm_f32ndhwc_f32ndhwc_f32ndhwc_simt_f32_sm80.cu conv3d_wgrad_implicit_gemm_f32ndhwc_f32ndhwc_f32ndhwc_simt_f32_sm80.cu + deconv3d_implicit_gemm_f32ndhwc_f32ndhwc_f32ndhwc_simt_f32_sm80.cu conv3d_fprop_with_broadcast_simt_sm80.cu + deconv3d_with_broadcast_simt_sm80.cu + ) endif() @@ -176,6 +181,7 @@ cutlass_test_unit_add_executable( conv2d_fprop_with_broadcast_sm75.cu conv2d_fprop_with_reduction_sm75.cu + conv3d_fprop_implicit_gemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_f32_sm75.cu conv3d_wgrad_implicit_gemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_f32_sm75.cu ) @@ -209,6 +215,7 @@ if (CUTLASS_NVCC_MAX_ARCH GREATER_EQUAL 80) conv2d_strided_dgrad_implicit_gemm_swizzling4_sm80.cu # Conv3d + conv3d_fprop_implicit_gemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_f32_sm80.cu conv3d_wgrad_implicit_gemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_f32_sm80.cu # Group Conv2d diff --git a/test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.cu b/test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.cu index 240c579ae3..c75ebbee3b 100644 --- a/test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.cu +++ b/test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.cu @@ -85,7 +85,7 @@ TEST(SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tens } //////////////////////////////////////////////////////////////////////////////// -#if 0 + 
TEST(SM80_Device_Conv2d_Fprop_Precomputed_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32, 128x128_64x3_64x64x64) { @@ -116,7 +116,8 @@ TEST(SM80_Device_Conv2d_Fprop_Precomputed_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_t cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 3, cutlass::arch::OpMultiplyAdd, - cutlass::conv::IteratorAlgorithm::kOptimized + cutlass::conv::IteratorAlgorithm::kOptimized, + cutlass::conv::StrideSupport::kStrided >::Kernel; using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; @@ -124,7 +125,6 @@ TEST(SM80_Device_Conv2d_Fprop_Precomputed_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_t /// Run all unit test sizes with device-level Conv2d instance EXPECT_TRUE(test::conv::device::TestAllConv2d()); } -#endif //////////////////////////////////////////////////////////////////////////////// #endif // CUTLASS_ARCH_MMA_SM80_SUPPORTED diff --git a/test/unit/conv/device/conv2d_fprop_with_broadcast_simt_sm80.cu b/test/unit/conv/device/conv2d_fprop_with_broadcast_simt_sm80.cu index 4d0cb2f20d..944af8b6ee 100644 --- a/test/unit/conv/device/conv2d_fprop_with_broadcast_simt_sm80.cu +++ b/test/unit/conv/device/conv2d_fprop_with_broadcast_simt_sm80.cu @@ -81,7 +81,8 @@ TEST(SM80_Device_Conv2d_Fprop_With_Broadcast_Analytic_ImplicitGemm_f32nhwc_f32nh cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4, cutlass::arch::OpMultiplyAdd, - cutlass::conv::IteratorAlgorithm::kAnalytic + cutlass::conv::IteratorAlgorithm::kAnalytic, + cutlass::conv::StrideSupport::kStrided >::Kernel; using Conv2dFprop = cutlass::conv::device::ImplicitGemmConvolution; @@ -103,7 +104,7 @@ template < template class UnaryOp, bool TestSplitK = true > -void TestResidaulBlock() { +static void Conv2dFpropSM80TestResidaulBlock() { using ElementA = float; using ElementB = float; using ElementC = float; @@ -161,7 +162,7 @@ void TestResidaulBlock() { TEST(SM80_Device_Conv2d_Fprop_With_Residual_Block_Plus_Analytic_ImplicitGemm_f32nhwc_f32nhwc_f32nhwc_simt_f32, 128x128_8x4_32x64x8) { // Resnet - TestResidaulBlock(); + Conv2dFpropSM80TestResidaulBlock(); } //////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/conv/device/conv2d_testbed.h b/test/unit/conv/device/conv2d_testbed.h index 2ace470b5e..d957beb03d 100644 --- a/test/unit/conv/device/conv2d_testbed.h +++ b/test/unit/conv/device/conv2d_testbed.h @@ -153,7 +153,6 @@ class TestbedConv2d { else if (dist_kind == cutlass::Distribution::Identity) { cutlass::reference::host::TensorFillIdentity(view); - } else if (dist_kind == cutlass::Distribution::Gaussian) { @@ -489,7 +488,8 @@ class TestbedConv2d { fname << "error_Conv2d_ImplicitGemm_device_" << (split_k_mode == cutlass::conv::SplitKMode::kSerial ? "serial_reduction_" : "parallel_reduction_") << (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kFprop ? "fprop_" : - (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kDgrad ? "dgrad_" : "wgrad_")) + (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kDgrad ? "dgrad_" : + (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kDeconv ? 
"deconv_" : "wgrad_"))) << ss_problem_size_text.str() << Conv2d::ThreadblockShape::kM << "x" << Conv2d::ThreadblockShape::kN << "x" @@ -635,8 +635,8 @@ bool TestAllConv2d( // // CUTLASS DGRAD's *unity* stride specialization only support stride {1, 1} - if ((ImplicitGemm::kConvolutionalOperator == - cutlass::conv::Operator::kDgrad) && + if ((ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDgrad || + ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDeconv) && (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport == cutlass::conv::StrideSupport::kUnity)) { if (!((conv_problem.stride_h == 1) && (conv_problem.stride_w == 1))) { @@ -663,8 +663,8 @@ bool TestAllConv2d( // CUTLASS DGRAD's *strided* stride specialization supports all stride {stride_h, stride_w} // Although strided dgrad works for all stride combinations, we are only going // to run strided dgrad for non-unity strides - if ((ImplicitGemm::kConvolutionalOperator == - cutlass::conv::Operator::kDgrad) && + if ((ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDgrad || + ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDeconv) && (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport == cutlass::conv::StrideSupport::kStrided)) { if (((conv_problem.stride_h == 1) && (conv_problem.stride_w == 1))) { @@ -718,8 +718,8 @@ bool TestAllConv2d( } // CUTLASS DGRAD's *strided* specialization does not support split-k mode - if ((ImplicitGemm::kConvolutionalOperator == - cutlass::conv::Operator::kDgrad) && + if ((ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDgrad || + ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDeconv) && (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport == cutlass::conv::StrideSupport::kStrided)) { diff --git a/test/unit/conv/device/conv2d_with_broadcast_testbed.h b/test/unit/conv/device/conv2d_with_broadcast_testbed.h index 52771a7a5a..278d447f80 100644 --- a/test/unit/conv/device/conv2d_with_broadcast_testbed.h +++ b/test/unit/conv/device/conv2d_with_broadcast_testbed.h @@ -404,9 +404,9 @@ class TestbedConv2dWithBroadcast { // compute tensor Z and tensor T for (int n = 0; n < problem_size.N; ++n) { - for (int p = 0; p < problem_size.P; ++p) { - for (int q = 0; q < problem_size.Q; ++q) { - for (int k = 0; k < problem_size.K; ++k) { + for (int p = 0; p < (kConvolutionalOperator == cutlass::conv::Operator::kFprop ? problem_size.P : problem_size.H); ++p) { + for (int q = 0; q < (kConvolutionalOperator == cutlass::conv::Operator::kFprop ? problem_size.Q : problem_size.W); ++q) { + for (int k = 0; k < (kConvolutionalOperator == cutlass::conv::Operator::kFprop ? problem_size.K : problem_size.C); ++k) { ElementZ z{}; ElementT t{}; @@ -449,7 +449,8 @@ class TestbedConv2dWithBroadcast { fname << "error_Conv2d_ImplicitGemm_device_" << (split_k_mode == cutlass::conv::SplitKMode::kSerial ? "serial_reduction_" : "parallel_reduction_") << (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kFprop ? "fprop_" : - (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kDgrad ? "dgrad_" : "wgrad_")) + (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kDgrad ? "dgrad_" : + (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kDeconv ? 
"deconv_" : "wgrad_"))) << "nhwc_" << problem_size.N << "x" << problem_size.H << "x" @@ -602,8 +603,8 @@ bool TestAllConv2dWithBroadcast( // // CUTLASS DGRAD's *unity* stride specialization only support stride {1, 1} - if ((ImplicitGemm::kConvolutionalOperator == - cutlass::conv::Operator::kDgrad) && + if ((ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDgrad || + ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDeconv) && (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport == cutlass::conv::StrideSupport::kUnity)) { if (!((conv_problem.stride_h == 1) && (conv_problem.stride_w == 1))) { @@ -613,8 +614,8 @@ bool TestAllConv2dWithBroadcast( #if 0 // relax restrictions on analytic strided dgrad // CUTLASS DGRAD's *strided* specialization only support stride >= {2, 2} - if ((ImplicitGemm::kConvolutionalOperator == - cutlass::conv::Operator::kDgrad) && + if ((ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDgrad || + ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDeconv) && (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport == cutlass::conv::StrideSupport::kStrided)) { if (((conv_problem.stride_h == 1) && (conv_problem.stride_w == 1))) { @@ -650,8 +651,8 @@ bool TestAllConv2dWithBroadcast( } // CUTLASS DGRAD's *strided* specialization does not support split-k mode - if ((ImplicitGemm::kConvolutionalOperator == - cutlass::conv::Operator::kDgrad) && + if ((ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDgrad || + ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDeconv) && (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport == cutlass::conv::StrideSupport::kStrided)) { diff --git a/test/unit/conv/device/conv3d_fprop_implicit_gemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_f32_sm80.cu b/test/unit/conv/device/conv3d_fprop_implicit_gemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_f32_sm80.cu index 910e92e4b8..27ae274c22 100644 --- a/test/unit/conv/device/conv3d_fprop_implicit_gemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_f32_sm80.cu +++ b/test/unit/conv/device/conv3d_fprop_implicit_gemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_f32_sm80.cu @@ -111,7 +111,8 @@ TEST(SM80_Device_Conv3d_Fprop_Optimized_ImplicitGemm_f16ndhwc_f16ndhwc_f32ndhwc_ cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4, cutlass::arch::OpMultiplyAdd, - cutlass::conv::IteratorAlgorithm::kOptimized + cutlass::conv::IteratorAlgorithm::kOptimized, + cutlass::conv::StrideSupport::kStrided >::Kernel; using Conv3dFprop = cutlass::conv::device::ImplicitGemmConvolution; diff --git a/test/unit/conv/device/conv3d_fprop_with_broadcast_simt_sm80.cu b/test/unit/conv/device/conv3d_fprop_with_broadcast_simt_sm80.cu index e401d113ed..bc0dee0e03 100644 --- a/test/unit/conv/device/conv3d_fprop_with_broadcast_simt_sm80.cu +++ b/test/unit/conv/device/conv3d_fprop_with_broadcast_simt_sm80.cu @@ -81,7 +81,8 @@ TEST(SM80_Device_Conv3d_Fprop_With_Broadcast_Analytic_ImplicitGemm_f32ndhwc_f32n cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, 4, cutlass::arch::OpMultiplyAdd, - cutlass::conv::IteratorAlgorithm::kAnalytic + cutlass::conv::IteratorAlgorithm::kAnalytic, + cutlass::conv::StrideSupport::kStrided >::Kernel; using Conv3dFprop = cutlass::conv::device::ImplicitGemmConvolution; @@ -103,7 +104,7 @@ template < template class UnaryOp, bool TestSplitK = true > -void TestResidaulBlock() { +static void Conv3dFpropSM80TestResidaulBlock() { using ElementA = float; using ElementB = float; using ElementC = float; @@ 
-161,7 +162,7 @@ void TestResidaulBlock() { TEST(SM80_Device_Conv3d_Fprop_With_Residual_Block_Plus_Analytic_ImplicitGemm_f32ndhwc_f32ndhwc_f32ndhwc_simt_f32, 128x128_8x4_32x64x8) { // Resnet - TestResidaulBlock(); + Conv3dFpropSM80TestResidaulBlock(); } //////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/conv/device/conv3d_testbed.h b/test/unit/conv/device/conv3d_testbed.h index bfbe892163..54bf936333 100644 --- a/test/unit/conv/device/conv3d_testbed.h +++ b/test/unit/conv/device/conv3d_testbed.h @@ -169,7 +169,7 @@ class TestbedConv3d { tensor_D_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); initialize_tensor(tensor_A.host_view(), init_A, seed); - initialize_tensor(tensor_B.host_view(), init_B, seed * 17); + initialize_tensor(tensor_B.host_view(), init_B, seed * 17); initialize_tensor(tensor_C.host_view(), init_C, seed * 39); tensor_A.sync_device(); @@ -358,12 +358,12 @@ class TestbedConv3d { bool cached_result_loaded = false; CachedTestResult cached_test_result; - std::string conv2d_result_cache_name = + std::string conv3d_result_cache_name = std::string("cached_results_") + CUTLASS_TARGET_NAME + ".txt"; - + if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) { - CachedTestResultListing cached_results(conv2d_result_cache_name); + CachedTestResultListing cached_results(conv3d_result_cache_name); auto cached = cached_results.find(cached_test_key); @@ -376,7 +376,7 @@ class TestbedConv3d { if (!cached_result_loaded) { #if CUTLASS_CONV_TEST_UNIT_REFERENCE_DEVICE_ENABLED - + cutlass::reference::device::Conv3d< ElementA, LayoutA, @@ -426,15 +426,14 @@ class TestbedConv3d { cached_test_result.D = TensorHash(tensor_D_reference.host_view()); - CachedTestResultListing cached_results(conv2d_result_cache_name); + CachedTestResultListing cached_results(conv3d_result_cache_name); cached_results.append(cached_test_key, cached_test_result); - cached_results.write(conv2d_result_cache_name); + cached_results.write(conv3d_result_cache_name); } } // if (!cached_result_loaded) uint32_t tensor_D_hash = TensorHash(tensor_D_computed.host_view()); - if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) { passed = (tensor_D_hash == cached_test_result.D); @@ -456,7 +455,8 @@ class TestbedConv3d { fname << "error_Conv3d_ImplicitGemm_device_" << (split_k_mode == cutlass::conv::SplitKMode::kSerial ? "serial_reduction_" : "parallel_reduction_") << (Conv3d::kConvolutionalOperator == cutlass::conv::Operator::kFprop ? "fprop_" : - (Conv3d::kConvolutionalOperator == cutlass::conv::Operator::kDgrad ? "dgrad_" : "wgrad_")) + (Conv3d::kConvolutionalOperator == cutlass::conv::Operator::kDgrad ? "dgrad_" : + (Conv3d::kConvolutionalOperator == cutlass::conv::Operator::kDeconv ? 
"deconv_" : "wgrad_"))) << "ndhwc_" << problem_size.N << "x" << problem_size.D << "x" @@ -571,8 +571,8 @@ bool TestAllConv3d( // // CUTLASS DGRAD's unity stride specialization only support stride {1, 1, 1} - if ((ImplicitGemm::kConvolutionalOperator == - cutlass::conv::Operator::kDgrad) && + if ((ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDgrad || + ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDeconv) && ((ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport == cutlass::conv::StrideSupport::kUnity) || (ImplicitGemm::UnderlyingKernel::Mma::IteratorB::kStrideSupport == diff --git a/test/unit/conv/device/conv3d_with_broadcast_testbed.h b/test/unit/conv/device/conv3d_with_broadcast_testbed.h index 93acfcabff..cc7c06f7da 100644 --- a/test/unit/conv/device/conv3d_with_broadcast_testbed.h +++ b/test/unit/conv/device/conv3d_with_broadcast_testbed.h @@ -227,7 +227,6 @@ class TestbedConv3dWithBroadcast { initialize_tensor(tensor_B.host_view(), init_B, seed * 17); initialize_tensor(tensor_C.host_view(), init_C, seed * 39); initialize_tensor(tensor_Broadcast.host_view(), init_C, seed * 39); - for (int n = 0; n < tensor_C_reference.extent().n(); ++n) { for (int o = 0; o < tensor_C_reference.extent().d(); ++o) { for (int p = 0; p < tensor_C_reference.extent().h(); ++p) { @@ -239,7 +238,6 @@ class TestbedConv3dWithBroadcast { } } } - tensor_A.sync_device(); tensor_B.sync_device(); tensor_C.sync_device(); @@ -407,10 +405,10 @@ class TestbedConv3dWithBroadcast { // compute tensor Z and tensor T for (int n = 0; n < problem_size.N; ++n) { - for (int o = 0; o < problem_size.Z; ++o) { - for (int p = 0; p < problem_size.P; ++p) { - for (int q = 0; q < problem_size.Q; ++q) { - for (int k = 0; k < problem_size.K; ++k) { + for (int o = 0; o < (kConvolutionalOperator == cutlass::conv::Operator::kFprop ? problem_size.Z : problem_size.D); ++o) { + for (int p = 0; p < (kConvolutionalOperator == cutlass::conv::Operator::kFprop ? problem_size.P : problem_size.H); ++p) { + for (int q = 0; q < (kConvolutionalOperator == cutlass::conv::Operator::kFprop ? problem_size.Q : problem_size.W); ++q) { + for (int k = 0; k < (kConvolutionalOperator == cutlass::conv::Operator::kFprop ? problem_size.K : problem_size.C); ++k) { ElementZ z{}; ElementT t{}; @@ -454,7 +452,8 @@ class TestbedConv3dWithBroadcast { fname << "error_Conv3d_ImplicitGemm_device_" << (split_k_mode == cutlass::conv::SplitKMode::kSerial ? "serial_reduction_" : "parallel_reduction_") << (Conv3d::kConvolutionalOperator == cutlass::conv::Operator::kFprop ? "fprop_" : - (Conv3d::kConvolutionalOperator == cutlass::conv::Operator::kDgrad ? "dgrad_" : "wgrad_")) + (Conv3d::kConvolutionalOperator == cutlass::conv::Operator::kDgrad ? "dgrad_" : + (Conv3d::kConvolutionalOperator == cutlass::conv::Operator::kDeconv ? 
"deconv_" : "wgrad_"))) << "nnhwc_" << problem_size.N << "x" << problem_size.D << "x" @@ -563,8 +562,8 @@ bool TestAllConv3dWithBroadcast( // // CUTLASS DGRAD's *unity* stride specialization only support stride {1, 1} - if ((ImplicitGemm::kConvolutionalOperator == - cutlass::conv::Operator::kDgrad) && + if ((ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDgrad || + ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDeconv) && (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport == cutlass::conv::StrideSupport::kUnity)) { if (!((conv_problem.stride_d == 1) && @@ -577,8 +576,8 @@ bool TestAllConv3dWithBroadcast( #if 0 // relax restrictions on analytic strided dgrad // CUTLASS DGRAD's *strided* specialization only support stride >= {2, 2} - if ((ImplicitGemm::kConvolutionalOperator == - cutlass::conv::Operator::kDgrad) && + if ((ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDgrad || + ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDeconv) && (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport == cutlass::conv::StrideSupport::kStrided)) { if (((conv_problem.stride_d == 1) && (conv_problem.stride_h == 1) && (conv_problem.stride_w == 1))) { diff --git a/test/unit/conv/device/deconv2d_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.cu b/test/unit/conv/device/deconv2d_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.cu new file mode 100644 index 0000000000..73a78d3330 --- /dev/null +++ b/test/unit/conv/device/deconv2d_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.cu @@ -0,0 +1,139 @@ +/*************************************************************************************************** + * Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Tests for device-wide Implicit GEMM interface +*/ + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" + + +#include "cutlass/conv/kernel/default_deconv2d.h" +#include "cutlass/conv/device/implicit_gemm_convolution.h" + +#include "conv2d_testbed.h" + +#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +//////////////////////////////////////////////////////////////////////////////// +TEST(SM80_Device_Deconv2d_Analytic_ImplicitGemm_f32nhwc_f32nhwc_f32nhwc_simt_f32, + 128x128_8x4_32x64x8) { + + /// Conv operation element types for the Gemm equivalent (ImplicitGemm) + using ElementA = float; + using ElementB = float; + using ElementC = float; + using ElementAccumulator = float; + using ElementCompute = float; + + + /// Device-level Conv2d instance + using Deconv2dKernel = typename cutlass::conv::kernel::DefaultDeconv2d< + ElementA, + cutlass::layout::TensorNHWC, + ElementB, + cutlass::layout::TensorNHWC, + ElementC, + cutlass::layout::TensorNHWC, + ElementAccumulator, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 8>, + cutlass::gemm::GemmShape<32, 64, 8>, + cutlass::gemm::GemmShape<1, 1, 1>, + cutlass::epilogue::thread::LinearCombination< + ElementC, + 1, + ElementAccumulator, + ElementCompute + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 4, + cutlass::arch::OpMultiplyAdd, + cutlass::conv::IteratorAlgorithm::kAnalytic, + cutlass::conv::StrideSupport::kUnity + >::Kernel; + + using Deconv2d = cutlass::conv::device::ImplicitGemmConvolution; + + /// Run all unit test sizes with device-level Conv2d instance + EXPECT_TRUE(test::conv::device::TestAllConv2d()); + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST(SM80_Device_Deconv2d_Fprop_Optimized_ImplicitGemm_f32nhwc_f32nhwc_f32nhwc_simt_f32, + 128x128_8x4_64x32x8) { + + /// Conv operation element types for the Gemm equivalent (ImplicitGemm) + using ElementA = float; + using ElementB = float; + using ElementC = float; + using ElementAccumulator = float; + using ElementCompute = float; + + + /// Device-level Conv2d instance + using Deconv2dKernel = typename cutlass::conv::kernel::DefaultDeconv2d< + ElementA, + cutlass::layout::TensorNHWC, + ElementB, + cutlass::layout::TensorNHWC, + ElementC, + cutlass::layout::TensorNHWC, + ElementAccumulator, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 8>, + cutlass::gemm::GemmShape<64, 32, 8>, + cutlass::gemm::GemmShape<1, 1, 1>, + cutlass::epilogue::thread::LinearCombination< + ElementC, + 1, + ElementAccumulator, + ElementCompute + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 4, + cutlass::arch::OpMultiplyAdd, + cutlass::conv::IteratorAlgorithm::kOptimized, + cutlass::conv::StrideSupport::kUnity + >::Kernel; + + using Deconv2d = cutlass::conv::device::ImplicitGemmConvolution; + + /// Run all unit test sizes with device-level Conv2d instance + EXPECT_TRUE(test::conv::device::TestAllConv2d()); + +} + +//////////////////////////////////////////////////////////////////////////////// +#endif // CUTLASS_ARCH_MMA_SM80_SUPPORTED diff --git a/test/unit/conv/device/deconv2d_with_broadcast_simt_sm80.cu b/test/unit/conv/device/deconv2d_with_broadcast_simt_sm80.cu new file mode 100644 index 0000000000..7872f8a466 --- /dev/null +++ b/test/unit/conv/device/deconv2d_with_broadcast_simt_sm80.cu @@ -0,0 +1,173 @@ +/*************************************************************************************************** + * 
Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide Implicit GEMM interface +*/ + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/epilogue/thread/linear_combination_bias_elementwise.h" +#include "cutlass/epilogue/thread/linear_combination_residual_block.h" +#include "cutlass/epilogue/thread/activation.h" +#include "cutlass/conv/kernel/default_deconv2d_with_broadcast.h" +#include "cutlass/conv/device/implicit_gemm_convolution.h" + +#include "conv2d_with_broadcast_testbed.h" + +#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + + +TEST(SM80_Device_Deconv2d_With_Broadcast_Analytic_ImplicitGemm_f32nhwc_f32nhwc_f32nhwc_simt_f32, + 128x128_32x2_64x64x32) { + + /// Conv operation element types for the Gemm equivalent (ImplicitGemm) + using ElementA = float; + using ElementB = float; + using ElementC = float; + using ElementCompute = float; + using ElementAccumulator = float; + + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombinationBiasElementwise< + ElementC, + ElementAccumulator, + ElementCompute, + ElementC, + ElementC, + 1, + cutlass::epilogue::thread::ReLu + >; + + /// Device-level Conv2d instance + using Deconv2dKernel = typename cutlass::conv::kernel::DefaultDeconv2dWithBroadcast< + ElementA, cutlass::layout::TensorNHWC, + ElementB, cutlass::layout::TensorNHWC, + ElementC, cutlass::layout::TensorNHWC, + ElementAccumulator, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 8>, + cutlass::gemm::GemmShape<32, 64, 8>, + cutlass::gemm::GemmShape<1, 1, 1>, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 4, + cutlass::arch::OpMultiplyAdd, + cutlass::conv::IteratorAlgorithm::kAnalytic, + cutlass::conv::StrideSupport::kUnity + >::Kernel; 
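  // DefaultDeconv2dWithBroadcast extends DefaultDeconv2d with an epilogue that
  // supports a broadcast operand: LinearCombinationBiasElementwise fuses the
  // broadcast bias vector and the elementwise ReLu into the epilogue and emits
  // two outputs (Z and T), which TestAllConv2dWithBroadcast below verifies
  // against a host reference.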
+ + using Deconv2d = cutlass::conv::device::ImplicitGemmConvolution; + + /// Run all unit test sizes with device-level Conv2d instance + EXPECT_TRUE(test::conv::device::TestAllConv2dWithBroadcast()); +} + +// Test residual block fusion: UnaryOp(BinaryOp(ActivationOp(Conv2d(X) + bias), residual)) +// LinearCombinationResidualBlock does not support the split-k mode unless ActivationOp is Identity. +// This is because the activation needs to be applied to the fully accumulated output of the Conv2d op, +// which only the last thread block would have an access to, before applying BinaryOp. +// The epilogue functor in the last thread block would have to be given three inputs, namely +// partial outputs, bias, and residual, but this is not supported in the current interface. +// Set TestSplitK = false to skip split-k tests with non-trivial ActivationOp. +template < + template class ActivationOp, + template class BinaryOp, + template class UnaryOp, + bool TestSplitK = true +> +static void Deconv2dSM80TestResidaulBlock() { + using ElementA = float; + using ElementB = float; + using ElementC = float; + using ElementD = ElementC; + using ElementCompute = float; + using ElementAccumulator = float; + + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombinationResidualBlock< + ElementD, + ElementAccumulator, + ElementCompute, + ElementC, + 1, + ActivationOp, + BinaryOp, + UnaryOp + >; + + using Deconv2dKernel = typename cutlass::conv::kernel::DefaultDeconv2dWithBroadcast< + ElementA, cutlass::layout::TensorNHWC, + ElementB, cutlass::layout::TensorNHWC, + ElementC, cutlass::layout::TensorNHWC, + ElementAccumulator, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 8>, + cutlass::gemm::GemmShape<32, 64, 8>, + cutlass::gemm::GemmShape<1, 1, 1>, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 4, + cutlass::arch::OpMultiplyAdd, + cutlass::conv::IteratorAlgorithm::kAnalytic, + cutlass::conv::StrideSupport::kUnity + >::Kernel; + + using Deconv2d = cutlass::conv::device::ImplicitGemmConvolution; + + struct ReferenceOp { + using OutputOp = typename Deconv2d::EpilogueOutputOp; + using ElementZ = typename OutputOp::ElementZ; + + ActivationOp activation; + BinaryOp binary_op; + UnaryOp unary_op; + + void operator()(ElementZ &Z, ElementZ&, ElementCompute conv2d, ElementCompute residual) { + Z = ElementZ(unary_op(binary_op(activation(conv2d), residual))); + } + }; + + bool passed = test::conv::device::TestAllConv2dWithBroadcast(); + EXPECT_TRUE(passed); +} + +TEST(SM80_Device_Deconv2d_With_Residual_Block_Plus_Analytic_ImplicitGemm_f32nhwc_f32nhwc_f32nhwc_simt_f32, + 128x128_8x4_32x64x8) { + // Resnet + Deconv2dSM80TestResidaulBlock(); +} + +//////////////////////////////////////////////////////////////////////////////// + +#endif // CUTLASS_ARCH_MMA_SM80_SUPPORTED + +//////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/conv/device/deconv3d_implicit_gemm_f32ndhwc_f32ndhwc_f32ndhwc_simt_f32_sm80.cu b/test/unit/conv/device/deconv3d_implicit_gemm_f32ndhwc_f32ndhwc_f32ndhwc_simt_f32_sm80.cu new file mode 100644 index 0000000000..929a515165 --- /dev/null +++ b/test/unit/conv/device/deconv3d_implicit_gemm_f32ndhwc_f32ndhwc_f32ndhwc_simt_f32_sm80.cu @@ -0,0 +1,141 @@ +/*************************************************************************************************** + * Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide Implicit GEMM interface +*/ + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" + +#include "cutlass/conv/kernel/default_deconv3d.h" +#include "cutlass/conv/device/implicit_gemm_convolution.h" + +#include "conv3d_testbed.h" + +#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +//////////////////////////////////////////////////////////////////////////////// + + +//////////////////////////////////////////////////////////////////////////////// +TEST(SM80_Device_Deconv3d_Analytic_ImplicitGemm_f32ndhwc_f32ndhwc_f32ndhwc_simt_f32, + 128x128_8x4_32x64x8) { + + /// Conv operation element types for the Gemm equivalent (ImplicitGemm) + using ElementA = float; + using ElementB = float; + using ElementC = float; + using ElementAccumulator = float; + using ElementCompute = float; + + + /// Device-level Conv3d instance + using Deconv3dKernel = typename cutlass::conv::kernel::DefaultDeconv3d< + ElementA, + cutlass::layout::TensorNDHWC, + ElementB, + cutlass::layout::TensorNDHWC, + ElementC, + cutlass::layout::TensorNDHWC, + ElementAccumulator, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 8>, + cutlass::gemm::GemmShape<32, 64, 8>, + cutlass::gemm::GemmShape<1, 1, 1>, + cutlass::epilogue::thread::LinearCombination< + ElementC, + 1, + ElementAccumulator, + ElementCompute + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 4, + cutlass::arch::OpMultiplyAdd, + cutlass::conv::IteratorAlgorithm::kAnalytic, + cutlass::conv::StrideSupport::kStrided + >::Kernel; + + using Deconv3d = cutlass::conv::device::ImplicitGemmConvolution; + + /// Run all unit test sizes with device-level Conv3d instance + EXPECT_TRUE(test::conv::device::TestAllConv3d()); + +} + +//////////////////////////////////////////////////////////////////////////////// 
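For reference, each TestAllConv3d call above sweeps a predefined set of problem sizes through the testbed in conv3d_testbed.h. A single deconvolution problem can be exercised the same way; the fragment below is an illustrative sketch only, assuming the Deconv3d alias from the preceding test is in scope, that Conv3dProblemSize and TestbedConv3d::run() have the signatures used elsewhere in these tests, and using arbitrary problem extents:

  // Run one NDHWC deconvolution problem through the same testbed machinery.
  test::conv::device::TestbedConv3d<Deconv3d> testbed;

  cutlass::conv::Conv3dProblemSize problem_size(
      {1, 8, 8, 8, 32},               // input size   (N, D, H, W, C)
      {32, 3, 3, 3, 32},              // filter size  (K, T, R, S, C)
      cutlass::Coord<3>({1, 1, 1}),   // padding      (pad_d, pad_h, pad_w)
      cutlass::Coord<3>({1, 1, 1}),   // stride       (stride_d, stride_h, stride_w)
      cutlass::Coord<3>({1, 1, 1}));  // dilation     (dilation_d, dilation_h, dilation_w)

  EXPECT_TRUE(testbed.run(problem_size, cutlass::conv::SplitKMode::kSerial));

Internally, the testbed builds the ImplicitGemmConvolution arguments (problem size, A/B/C/D tensor references, alpha/beta, split-k mode), queries get_workspace_size(), calls initialize(), launches the device operator, and compares the computed D tensor against a device reference or a cached result hash, as seen in conv3d_testbed.h above.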
+TEST(SM80_Device_Deconv3d_Optimized_ImplicitGemm_f32ndhwc_f32ndhwc_f32ndhwc_simt_f32, + 128x128_8x4_64x32x8) { + + /// Conv operation element types for the Gemm equivalent (ImplicitGemm) + using ElementA = float; + using ElementB = float; + using ElementC = float; + using ElementAccumulator = float; + using ElementCompute = float; + + + /// Device-level Conv3d instance + using Deconv3dKernel = typename cutlass::conv::kernel::DefaultDeconv3d< + ElementA, + cutlass::layout::TensorNDHWC, + ElementB, + cutlass::layout::TensorNDHWC, + ElementC, + cutlass::layout::TensorNDHWC, + ElementAccumulator, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 8>, + cutlass::gemm::GemmShape<64, 32, 8>, + cutlass::gemm::GemmShape<1, 1, 1>, + cutlass::epilogue::thread::LinearCombination< + ElementC, + 1, + ElementAccumulator, + ElementCompute + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 4, + cutlass::arch::OpMultiplyAdd, + cutlass::conv::IteratorAlgorithm::kOptimized, + cutlass::conv::StrideSupport::kUnity + >::Kernel; + + using Deconv3d = cutlass::conv::device::ImplicitGemmConvolution; + + /// Run all unit test sizes with device-level Conv3d instance + EXPECT_TRUE(test::conv::device::TestAllConv3d()); + +} + +//////////////////////////////////////////////////////////////////////////////// +#endif // CUTLASS_ARCH_MMA_SM80_SUPPORTED diff --git a/test/unit/conv/device/deconv3d_with_broadcast_simt_sm80.cu b/test/unit/conv/device/deconv3d_with_broadcast_simt_sm80.cu new file mode 100644 index 0000000000..e0d0171f7f --- /dev/null +++ b/test/unit/conv/device/deconv3d_with_broadcast_simt_sm80.cu @@ -0,0 +1,172 @@ +/*************************************************************************************************** + * Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Tests for device-wide Implicit GEMM interface +*/ + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/epilogue/thread/linear_combination_bias_elementwise.h" +#include "cutlass/epilogue/thread/linear_combination_residual_block.h" +#include "cutlass/epilogue/thread/activation.h" +#include "cutlass/conv/kernel/default_deconv3d_with_broadcast.h" +#include "cutlass/conv/device/implicit_gemm_convolution.h" + +#include "conv3d_with_broadcast_testbed.h" + +#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +TEST(SM80_Device_Deconv3d_With_Broadcast_Optimized_ImplicitGemm_f32ndhwc_f32ndhwc_f32ndhwc_simt_f32, + 128x128_32x2_64x64x32) { + + /// Conv operation element types for the Gemm equivalent (ImplicitGemm) + using ElementA = float; + using ElementB = float; + using ElementC = float; + using ElementCompute = float; + using ElementAccumulator = float; + + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombinationBiasElementwise< + ElementC, + ElementAccumulator, + ElementCompute, + ElementC, + ElementC, + 1, + cutlass::epilogue::thread::ReLu + >; + + /// Device-level Conv3d instance + using Deconv3dKernel = typename cutlass::conv::kernel::DefaultDeconv3dWithBroadcast< + ElementA, cutlass::layout::TensorNDHWC, + ElementB, cutlass::layout::TensorNDHWC, + ElementC, cutlass::layout::TensorNDHWC, + ElementAccumulator, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 8>, + cutlass::gemm::GemmShape<32, 64, 8>, + cutlass::gemm::GemmShape<1, 1, 1>, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 4, + cutlass::arch::OpMultiplyAdd, + cutlass::conv::IteratorAlgorithm::kOptimized, + cutlass::conv::StrideSupport::kUnity + >::Kernel; + + using Deconv3d = cutlass::conv::device::ImplicitGemmConvolution; + + /// Run all unit test sizes with device-level Conv3d instance + EXPECT_TRUE(test::conv::device::TestAllConv3dWithBroadcast()); +} + +// Test residual block fusion: UnaryOp(BinaryOp(ActivationOp(Conv3d(X) + bias), residual)) +// LinearCombinationResidualBlock does not support the split-k mode unless ActivationOp is Identity. +// This is because the activation needs to be applied to the fully accumulated output of the Conv3d op, +// which only the last thread block would have an access to, before applying BinaryOp. +// The epilogue functor in the last thread block would have to be given three inputs, namely +// partial outputs, bias, and residual, but this is not supported in the current interface. +// Set TestSplitK = false to skip split-k tests with non-trivial ActivationOp. 
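// In the ReferenceOp defined below, the `conv3d` argument plays the role of
// (Conv3d(X) + bias) in the formula above and `residual` is the broadcast
// residual tensor, so the host reference evaluates, per output element,
//   Z = UnaryOp(BinaryOp(ActivationOp(conv3d), residual))
// mirroring the fusion performed by the LinearCombinationResidualBlock epilogue.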
+template < + template class ActivationOp, + template class BinaryOp, + template class UnaryOp, + bool TestSplitK = true +> +static void Deconv3dSM80TestResidaulBlock() { + using ElementA = float; + using ElementB = float; + using ElementC = float; + using ElementD = ElementC; + using ElementCompute = float; + using ElementAccumulator = float; + + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombinationResidualBlock< + ElementD, + ElementAccumulator, + ElementCompute, + ElementC, + 1, + ActivationOp, + BinaryOp, + UnaryOp + >; + + using Deconv3dKernel = typename cutlass::conv::kernel::DefaultDeconv3dWithBroadcast< + ElementA, cutlass::layout::TensorNDHWC, + ElementB, cutlass::layout::TensorNDHWC, + ElementC, cutlass::layout::TensorNDHWC, + ElementAccumulator, + cutlass::arch::OpClassSimt, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 8>, + cutlass::gemm::GemmShape<32, 64, 8>, + cutlass::gemm::GemmShape<1, 1, 1>, + EpilogueOutputOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 4, + cutlass::arch::OpMultiplyAdd, + cutlass::conv::IteratorAlgorithm::kOptimized, + cutlass::conv::StrideSupport::kUnity + >::Kernel; + + using Deconv3d = cutlass::conv::device::ImplicitGemmConvolution; + + struct ReferenceOp { + using OutputOp = typename Deconv3d::EpilogueOutputOp; + using ElementZ = typename OutputOp::ElementZ; + + ActivationOp activation; + BinaryOp binary_op; + UnaryOp unary_op; + + void operator()(ElementZ &Z, ElementZ&, ElementCompute conv3d, ElementCompute residual) { + Z = ElementZ(unary_op(binary_op(activation(conv3d), residual))); + } + }; + + bool passed = test::conv::device::TestAllConv3dWithBroadcast(); + EXPECT_TRUE(passed); +} + +TEST(SM80_Device_Deconv3d_With_Residual_Block_Plus_Analytic_ImplicitGemm_f32ndhwc_f32ndhwc_f32ndhwc_simt_f32, + 128x128_8x4_32x64x8) { + // Resnet + Deconv3dSM80TestResidaulBlock(); +} + +//////////////////////////////////////////////////////////////////////////////// + +#endif // CUTLASS_ARCH_MMA_SM80_SUPPORTED + +//////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/conv/device_3x/CMakeLists.txt b/test/unit/conv/device_3x/CMakeLists.txt index d6d8e2157c..dddeba6f11 100644 --- a/test/unit/conv/device_3x/CMakeLists.txt +++ b/test/unit/conv/device_3x/CMakeLists.txt @@ -26,6 +26,7 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ add_subdirectory(fprop) add_subdirectory(wgrad) add_subdirectory(dgrad) diff --git a/test/unit/conv/device_3x/testbed_conv.hpp b/test/unit/conv/device_3x/testbed_conv.hpp index 67501ea1b0..e22c2cfffc 100644 --- a/test/unit/conv/device_3x/testbed_conv.hpp +++ b/test/unit/conv/device_3x/testbed_conv.hpp @@ -53,7 +53,6 @@ #include "cutlass/util/reference/host/tensor_norm.h" #include "cutlass/util/reference/device/tensor_fill.h" #include "cutlass/util/reference/device/tensor_compare.h" - #include "conv_problem_sizes.hpp" #include "../cache_testbed_output.h" @@ -195,7 +194,8 @@ struct ConvTestbed { bool run( ProblemShape const& problem_shape, ElementScalar alpha = ElementScalar(1), - ElementScalar beta = ElementScalar(0)) { + ElementScalar beta = ElementScalar(0) + ) { // Waive test if insufficient CUDA device if (!sufficient()) { @@ -250,14 +250,16 @@ struct ConvTestbed { auto &fusion_args = args.epilogue.thread; - // some fused patterns have no linear combination + fusion_args.alpha = alpha; + fusion_args.beta = beta; + if constexpr (IsBiasEnabled) { fusion_args.bias_ptr = tensor_bias.data().get(); } // Clamp bound if constexpr (cute::is_same_v>) { - fusion_args.activation.lower_bound = ElementCompute{0}; + fusion_args.activation.lower_bound = CUTLASS_STL_NAMESPACE::numeric_limits::lowest(); fusion_args.activation.upper_bound = CUTLASS_STL_NAMESPACE::numeric_limits::max(); } @@ -422,17 +424,11 @@ struct ConvTestbed { reference_impl.compute_reference(); } // Validate kernel against reference - passed = compare_reference( - mD_ref, mD_computed, mA, mB, mAlpha, - mBeta, mBias, - this->epsilon); + passed = compare_reference(mD_ref, mD_computed, mA, mB, mAlpha, mBeta, mBias, this->epsilon); } #else // Validate kernel against reference - passed = compare_reference( - mD_ref, mD_computed, mA, mB, mAlpha, - mBeta, mBias, - this->epsilon); + passed = compare_reference(mD_ref, mD_computed, mA, mB, mAlpha, mBeta, mBias, this->epsilon); #endif EXPECT_TRUE(passed); @@ -445,8 +441,7 @@ struct ConvTestbed { class EngineB, class LayoutB, class EngineAlpha, class LayoutAlpha, class EngineBeta, class LayoutBeta, - class EngineBias, class LayoutBias - > + class EngineBias, class LayoutBias> static constexpr bool compare_reference( cute::Tensor const& reference, @@ -503,7 +498,6 @@ struct ConvTestbed { printf("[%ld]: bias = %f\n", i, float(tensor_bias(i))); } } - for (size_t i = 0; i < size_t(size(reference)); ++i) { printf("[%ld]: ref = %f, computed = %f\n", i, float(reference(i)), float(computed(i))); } diff --git a/test/unit/core/CMakeLists.txt b/test/unit/core/CMakeLists.txt index 4060364966..d0e10a7b47 100644 --- a/test/unit/core/CMakeLists.txt +++ b/test/unit/core/CMakeLists.txt @@ -45,30 +45,3 @@ cutlass_test_unit_add_executable( fast_numeric_conversion.cu functional.cu ) - -# -# CUTLASS 3x increases the host compiler requirements to C++17. However, there are -# certain existing integrations that will benefit from maintaining C++11 compatibility. -# -# This requirement only applies to select .h files which are explicitly annotated. It -# does not apply to any .hpp file. -# -# `cutlass_test_unit_core_cpp11` enforces the C++11 requirement. 
-# - -set(CMAKE_CUDA_STANDARD 11) -set(CMAKE_CUDA_STANDARD_REQUIRED ON) - -add_executable( - cutlass_test_unit_core_cpp11 - - cpp11.cu -) - -if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU") - target_compile_options( - cutlass_test_unit_core_cpp11 - PRIVATE - $<$:-Xcompiler -Werror> - ) -endif() diff --git a/test/unit/core/cpp11.cu b/test/unit/core/cpp11.cu deleted file mode 100644 index 553b031c80..0000000000 --- a/test/unit/core/cpp11.cu +++ /dev/null @@ -1,87 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/* - Note: CUTLASS 3x increases the host compiler requirements to C++17. However, certain - existing integrations of CUTLASS require C++11 host compilers. - - Until this requirement can be lifted, certain headers with this annotation are required - to be remain consistent with C++11 syntax. - - C++11 compatibility is enforced by this unit test: `cutlass_test_unit_core_cpp11`. -*/ - -///////////////////////////////////////////////////////////////////////////////////////////////// - -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include - -#include -#include -#include - -#include -#include - -#include - -///////////////////////////////////////////////////////////////////////////////////////////////// - -#if (201700L <= __cplusplus ) -#error "This file and all of its includes must be compilable as C++11." 
-#endif - -///////////////////////////////////////////////////////////////////////////////////////////////// - -int main() { - return 0; -} - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/core/numeric_conversion.cu b/test/unit/core/numeric_conversion.cu index dd72b840bf..75e12bdf14 100644 --- a/test/unit/core/numeric_conversion.cu +++ b/test/unit/core/numeric_conversion.cu @@ -108,8 +108,8 @@ __global__ void convert_with_scale_factor( ///////////////////////////////////////////////////////////////////////////////////////////////// -template -void run_test_with_scalefactor(const char dest_name[], const char source_name[], const char scale_factor_name[]) { +template +void run_test_with_scalefactor(const char dest_name[], const char source_name[], const char scale_factor_name[], const int range = 4, const int offset = 0) { const int kN = Count; dim3 grid(1, 1); @@ -124,7 +124,7 @@ void run_test_with_scalefactor(const char dest_name[], const char source_name[], for (int i = 0; i < kN; ++i) { - source_ref.at({0, i}) = Source(i % Range); + source_ref.at({0, i}) = Source(i % range + offset); } for (int i = 0; i < kN; ++i) { @@ -144,10 +144,12 @@ void run_test_with_scalefactor(const char dest_name[], const char source_name[], for (int i = 0; i < kN; ++i) { float ref = float(source_ref.at({0, i})) / float(scale_factor_ref.at({0, i})); - EXPECT_TRUE(float(destination_ref.at({0, i})) == ref) - << "Destination type: " << dest_name << " "<< float(destination_ref.at({0, i})) - << ", Source type: " << source_name << " " << float(source_ref.at({0, i})) - << ", Count: " << Count; + bool pass = float(destination_ref.at({0, i})) == ref; + EXPECT_TRUE(pass) + << "Destination type: " << dest_name << " "<< float(destination_ref.at({0, i})) << std::endl + << ", Source type: " << source_name << " " << float(source_ref.at({0, i})) << std::endl + << ", Scalefactor type: " << source_name << " " << float(scale_factor_ref.at({0, i})) << std::endl + << ", idx: " << i << std::endl; } } diff --git a/test/unit/cute/CMakeLists.txt b/test/unit/cute/CMakeLists.txt index d001f18ef2..601c0c0d96 100644 --- a/test/unit/cute/CMakeLists.txt +++ b/test/unit/cute/CMakeLists.txt @@ -28,6 +28,7 @@ add_subdirectory(core) add_subdirectory(volta) +add_subdirectory(turing) add_subdirectory(ampere) add_subdirectory(hopper) add_subdirectory(layout) @@ -39,6 +40,7 @@ add_custom_target( cutlass_test_unit_cute_layout cutlass_test_unit_cute_core cutlass_test_unit_cute_volta + cutlass_test_unit_cute_turing cutlass_test_unit_cute_ampere cutlass_test_unit_cute_hopper cutlass_test_unit_cute_msvc_compilation @@ -51,6 +53,7 @@ add_custom_target( test_unit_cute_core test_unit_cute_volta test_unit_cute_ampere + test_unit_cute_turing test_unit_cute_hopper test_unit_cute_msvc_compilation ) diff --git a/test/unit/cute/ampere/CMakeLists.txt b/test/unit/cute/ampere/CMakeLists.txt index d05a73c08b..fd701de656 100644 --- a/test/unit/cute/ampere/CMakeLists.txt +++ b/test/unit/cute/ampere/CMakeLists.txt @@ -30,6 +30,7 @@ cutlass_test_unit_add_executable( cutlass_test_unit_cute_ampere cp_async.cu ldsm.cu + cooperative_gemm.cu ) cutlass_test_unit_add_executable( diff --git a/test/unit/cute/ampere/cooperative_gemm.cu b/test/unit/cute/ampere/cooperative_gemm.cu new file mode 100644 index 0000000000..2fcd01205d --- /dev/null +++ b/test/unit/cute/ampere/cooperative_gemm.cu @@ -0,0 +1,300 @@ +/*************************************************************************************************** + 
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#include "cutlass_unit_test.h" + +#include + +#include "../cooperative_gemm_common.hpp" + +using namespace cute; + +TEST(SM80_CuTe_Ampere, CooperativeGemm1_Half_MMA) { + using value_type = cutlass::half_t; + + constexpr uint32_t m = 64; + constexpr uint32_t n = 64; + constexpr uint32_t k = 64; + + constexpr uint32_t thread_block_size = 128; + + using tiled_mma_t = + TiledMMA< + MMA_Atom, + Layout> + >; + + test_cooperative_gemm_col_major_layout(); +} + +TEST(SM80_CuTe_Ampere, CooperativeGemm2_Double_MMA) { + using value_type = double; + + constexpr uint32_t m = 64; + constexpr uint32_t n = 64; + constexpr uint32_t k = 64; + + constexpr uint32_t thread_block_size = 128; + + using tiled_mma_t = + TiledMMA< + MMA_Atom, + Layout> + >; + + test_cooperative_gemm_col_major_layout(); +} + +TEST(SM80_CuTe_Ampere, CooperativeGemm3_Half_MMA_CustomSmemLayouts) { + using value_type = cutlass::half_t; + + constexpr uint32_t m = 128; + constexpr uint32_t n = 128; + constexpr uint32_t k = 128; + + constexpr uint32_t thread_block_size = 128; + + using tiled_mma_t = + TiledMMA< + MMA_Atom, + Layout>, // 2x2x1 thread group + Tile<_32, _32, _16> // 32x32x16 MMA for LDSM, 1x2x1 value group` + >; + + using smem_a_atom_layout_t = Layout, Stride< _1,_64>>; + using smem_b_atom_layout_t = Layout, Stride<_32, _1>>; + using smem_c_atom_layout_t = decltype(make_layout(make_shape(Int{}, Int{}))); + + test_cooperative_gemm_col_major_layout(); +} + +TEST(SM80_CuTe_Ampere, CooperativeGemm4_Half_MMA_SwizzledSmemLayouts) { + using value_type = cutlass::half_t; + + constexpr uint32_t m = 128; + constexpr uint32_t n = 128; + constexpr uint32_t k = 128; + + constexpr uint32_t thread_block_size = 128; + + using tiled_mma_t = + TiledMMA< + MMA_Atom, + Layout>, // 2x2x1 thread group + Tile<_32, _32, _16> // 
32x32x16 MMA for LDSM, 1x2x1 value group` + >; + + // RowMajor + using smem_rowmajor_atom_layout_t = decltype( + composition(Swizzle<3,3,3>{}, + Layout, + Stride<_64, _1>>{})); + // ColMajor + using smem_colmajor_atom_layout_t = decltype( + composition(Swizzle<3,3,3>{}, + Layout, + Stride< _1,_64>>{})); + using smem_a_atom_layout_t = smem_rowmajor_atom_layout_t; + using smem_b_atom_layout_t = smem_colmajor_atom_layout_t; + using smem_c_atom_layout_t = decltype(make_layout(make_shape(Int{}, Int{}), GenRowMajor{})); + + using gmem_a_layout_t = decltype(make_layout(make_shape(Int {}, Int {}), GenRowMajor{})); + using gmem_b_layout_t = decltype(make_layout(make_shape(Int {}, Int {}), GenColMajor{})); + using gmem_c_layout_t = decltype(make_layout(make_shape(Int {}, Int {}), GenRowMajor{})); + + using smem_a_atom_layout_t = smem_a_atom_layout_t; + using smem_a_layout_t = decltype(tile_to_shape( + smem_a_atom_layout_t{}, + make_shape(shape<0>(gmem_a_layout_t{}), shape<1>(gmem_a_layout_t{}))) + ); + + using smem_b_atom_layout_t = smem_b_atom_layout_t; + using smem_b_layout_t = decltype(tile_to_shape( + smem_b_atom_layout_t{}, + make_shape(shape<0>(gmem_b_layout_t{}), shape<1>(gmem_b_layout_t{}))) + ); + + using smem_c_atom_layout_t = smem_c_atom_layout_t; + using smem_c_layout_t = decltype(tile_to_shape( + smem_c_atom_layout_t{}, + make_shape(shape<0>(gmem_c_layout_t{}), shape<1>(gmem_c_layout_t{}))) + ); + + test_cooperative_gemm, // C + thread_block_size, + tiled_mma_t, + 128, + value_type, + value_type, + value_type>(); +} + +TEST(SM80_CuTe_Ampere, CooperativeGemm5_Double_MMA_SwizzledSmemLayouts) { + using value_type = double; + + constexpr uint32_t m = 128; + constexpr uint32_t n = 64; + constexpr uint32_t k = 16; + + constexpr uint32_t thread_block_size = 128; + + using tiled_mma_t = + TiledMMA, // Atom + Layout>, // Atom layout + Tile, Stride<_2, _1>>, // 32x32x4 MMA with perm for load vectorization + Layout, Stride<_2, _1>>, + Underscore>>; + + using smem_a_atom_layout_t = decltype( + composition(Swizzle<2,2,2>{}, + Layout, + Stride< _1,_16>>{})); // M, K + using smem_b_atom_layout_t = decltype( + composition(Swizzle<2,2,2>{}, + Layout, + Stride< _1,_16>>{})); // N, K + using smem_c_atom_layout_t = decltype(make_layout(make_shape(Int{}, Int{}), GenRowMajor{})); + + using gmem_a_layout_t = decltype(make_layout(make_shape(Int {}, Int {}), GenRowMajor{})); + using gmem_b_layout_t = decltype(make_layout(make_shape(Int {}, Int {}), GenColMajor{})); + using gmem_c_layout_t = decltype(make_layout(make_shape(Int {}, Int {}), GenRowMajor{})); + + using smem_a_atom_layout_t = smem_a_atom_layout_t; + using smem_a_layout_t = decltype(tile_to_shape( + smem_a_atom_layout_t{}, + make_shape(shape<0>(gmem_a_layout_t{}), shape<1>(gmem_a_layout_t{}))) + ); + using smem_b_atom_layout_t = smem_b_atom_layout_t; + using smem_b_layout_t = decltype(tile_to_shape( + smem_b_atom_layout_t{}, + make_shape(shape<0>(gmem_b_layout_t{}), shape<1>(gmem_b_layout_t{}))) + ); + using smem_c_atom_layout_t = smem_c_atom_layout_t; + using smem_c_layout_t = decltype(tile_to_shape( + smem_c_atom_layout_t{}, + make_shape(shape<0>(gmem_c_layout_t{}), shape<1>(gmem_c_layout_t{}))) + ); + + test_cooperative_gemm, // A + AutoVectorizingCopyWithAssumedAlignment<128>, // B + AutoVectorizingCopyWithAssumedAlignment<128>, // C + thread_block_size, + tiled_mma_t, + 128, + value_type, + value_type, + value_type>(); +} + +TEST(SM80_CuTe_Ampere, CooperativeGemm6_MixedPrecisionFP16FP32_MMA) { + using TA = cutlass::half_t; + using TB = 
cutlass::half_t; + using TC = float; + + constexpr uint32_t m = 64; + constexpr uint32_t n = 64; + constexpr uint32_t k = 64; + + constexpr uint32_t thread_block_size = 128; + + using tiled_mma_t = + TiledMMA< + MMA_Atom, + Layout> + >; + + test_cooperative_gemm_col_major_layout(); +} + +TEST(SM80_CuTe_Ampere, CooperativeGemm7_MixedPrecisionBF16FP32_MMA) { + using TA = cutlass::bfloat16_t; + using TB = cutlass::bfloat16_t; + using TC = float; + + constexpr uint32_t m = 64; + constexpr uint32_t n = 64; + constexpr uint32_t k = 64; + + constexpr uint32_t thread_block_size = 128; + + using tiled_mma_t = + TiledMMA< + MMA_Atom, + Layout> + >; + + test_cooperative_gemm_col_major_layout(); +} + +TEST(SM80_CuTe_Ampere, CooperativeGemm8_MixedPrecisionTF32FP32_MMA) { + using TA = cutlass::tfloat32_t; + using TB = cutlass::tfloat32_t; + using TC = float; + + constexpr uint32_t m = 64; + constexpr uint32_t n = 64; + constexpr uint32_t k = 64; + + constexpr uint32_t thread_block_size = 128; + + using tiled_mma_t = + TiledMMA< + MMA_Atom, + Layout> + >; + + test_cooperative_gemm_col_major_layout(); +} diff --git a/test/unit/cute/cooperative_gemm_common.hpp b/test/unit/cute/cooperative_gemm_common.hpp new file mode 100644 index 0000000000..9f7f694619 --- /dev/null +++ b/test/unit/cute/cooperative_gemm_common.hpp @@ -0,0 +1,414 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#pragma once + +#include "cutlass_unit_test.h" +#include "cutlass/util/reference/host/tensor_compare.h" + +#include + +#include +#include + +#include + +using namespace cute; + +template +__launch_bounds__(ThreadBlockSize) __global__ void +cooperative_gemm_kernel(TA const* a, + TB const* b, + TC* c, + TC* c_out, + Alpha const alpha, + Beta const beta, + ALoadTransform a_load_transform, + BLoadTransform b_load_transform, + CLoadTransform c_load_transform, + CStoreTransform c_store_transform) +{ + using namespace cute; + + Tensor g_a_tensor = make_tensor(make_gmem_ptr(a), ALayout{}); + Tensor g_b_tensor = make_tensor(make_gmem_ptr(b), BLayout{}); + Tensor g_c_tensor = make_tensor(make_gmem_ptr(c), CLayout{}); + Tensor g_c_out_tensor = make_tensor(make_gmem_ptr(c_out), CLayout{}); + + constexpr uint32_t copy_max_vec_bytes = CopyMaxVecBits / 8; + + extern __shared__ float4 smem_buf[]; + auto* smem_ptr = reinterpret_cast(smem_buf); + auto* smem_ptr_a = smem_ptr; + auto* smem_ptr_b = smem_ptr_a + round_up((sizeof(TA) * cosize(SMemALayout {})), copy_max_vec_bytes); + auto* smem_ptr_c = smem_ptr_b + round_up((sizeof(TB) * cosize(SMemBLayout {})), copy_max_vec_bytes); + + Tensor s_a_tensor = make_tensor(make_smem_ptr(smem_ptr_a), SMemALayout{}); + Tensor s_b_tensor = make_tensor(make_smem_ptr(smem_ptr_b), SMemBLayout{}); + Tensor s_c_tensor = make_tensor(make_smem_ptr(smem_ptr_c), SMemCLayout{}); + + cooperative_copy(threadIdx.x, g_a_tensor, s_a_tensor); + cooperative_copy(threadIdx.x, g_b_tensor, s_b_tensor); + cooperative_copy(threadIdx.x, g_c_tensor, s_c_tensor); + + cp_async_fence(); + cp_async_wait<0>(); + __syncthreads(); + + TiledMma tiled_mma; + cooperative_gemm( + threadIdx.x, tiled_mma, + alpha, s_a_tensor, s_b_tensor, beta, s_c_tensor, + a_load_transform, b_load_transform, c_load_transform, c_store_transform + ); + __syncthreads(); + + cooperative_copy(threadIdx.x, s_c_tensor, g_c_out_tensor); +} + +template +void test_cooperative_gemm(ALoadTransform const& a_load_transform = {}, + BLoadTransform const& b_load_transform = {}, + CLoadTransform const& c_load_transform = {}, + CStoreTransform const& c_store_transform = {}) +{ + using gmem_a_layout_t = ALayout; + using gmem_b_layout_t = BLayout; + using gmem_c_layout_t = CLayout; + + using smem_a_layout_t = SMemALayout; + using smem_b_layout_t = SMemBLayout; + using smem_c_layout_t = SMemCLayout; + + static_assert(size<0>(gmem_a_layout_t{}) == size<0>(gmem_c_layout_t{})); // AM == CM + static_assert(size<0>(gmem_b_layout_t{}) == size<1>(gmem_c_layout_t{})); // BN == CN + static_assert(size<1>(gmem_a_layout_t{}) == size<1>(gmem_b_layout_t{})); // AK == BK + + static_assert(size<0>(smem_a_layout_t{}) == size<0>(smem_c_layout_t{})); // AM == CM + static_assert(size<0>(smem_b_layout_t{}) == size<1>(smem_c_layout_t{})); // BN == CN + static_assert(size<1>(smem_a_layout_t{}) == size<1>(smem_b_layout_t{})); // AK == BK + + static_assert(cute::size(gmem_a_layout_t {}) == cute::size(smem_a_layout_t {})); + static_assert(cute::size(gmem_b_layout_t {}) == cute::size(smem_b_layout_t {})); + static_assert(cute::size(gmem_c_layout_t {}) == cute::size(smem_c_layout_t {})); + +#if 0 + print(" "); print("gmem: "); print(gmem_layout_t{}); print("\n"); + print(" "); print("smem: "); print(smem_layout_t{}); print("\n"); + print(" "); print("threads: "); print(ThreadBlockSize); print("\n"); +#endif + + const auto alpha = static_cast(1.1); + const auto 
beta = static_cast(1.2); + + thrust::host_vector h_a(cosize(gmem_a_layout_t{})); + thrust::host_vector h_b(cosize(gmem_b_layout_t{})); + thrust::host_vector h_c(cosize(gmem_c_layout_t{})); + thrust::host_vector h_c_out(cosize(gmem_c_layout_t{})); + + auto h_a_tensor = make_tensor(h_a.data(), gmem_a_layout_t{}); + auto h_b_tensor = make_tensor(h_b.data(), gmem_b_layout_t{}); + auto h_c_tensor = make_tensor(h_c.data(), gmem_c_layout_t{}); + size_t max_size = std::max({static_cast(size(gmem_a_layout_t {})), + static_cast(size(gmem_b_layout_t {})), + static_cast(size(gmem_c_layout_t {}))}); + for (size_t i = 0; i < max_size; ++i) { + double di = static_cast(i); + if(i < size(gmem_a_layout_t{})) { + h_a_tensor(i) = static_cast(di / size(gmem_a_layout_t{})); + } + if(i < size(gmem_b_layout_t{})) { + h_b_tensor(i) = static_cast(di / size(gmem_a_layout_t{})); + } + if(i < size(gmem_c_layout_t{})) { + h_c_tensor(i) = static_cast((di*di) / size(gmem_a_layout_t{})); + } + } + + thrust::device_vector d_a(h_a); + thrust::device_vector d_b(h_b); + thrust::device_vector d_c(h_c); + thrust::device_vector d_c_out(h_c_out.size(), TC(float(-1))); + + const size_t shared_memory_size = + (sizeof(TA) * h_a.size()) + (sizeof(TB) * h_b.size()) + (sizeof(TC) * h_c.size()); + auto kernel = cooperative_gemm_kernel< + gmem_a_layout_t, gmem_b_layout_t, gmem_c_layout_t, + smem_a_layout_t, smem_b_layout_t, smem_c_layout_t, + SmemCopyOpA, SmemCopyOpB, SmemCopyOpC, + ThreadBlockSize, TiledMma, CopyMaxVecBits, + TA, TB, TC, decltype(alpha), decltype(beta), + ALoadTransform, BLoadTransform, CLoadTransform, CStoreTransform + >; + ASSERT_EQ(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, static_cast(shared_memory_size)), 0); + + kernel<<<1, ThreadBlockSize, shared_memory_size>>>( + thrust::raw_pointer_cast(d_a.data()), + thrust::raw_pointer_cast(d_b.data()), + thrust::raw_pointer_cast(d_c.data()), + thrust::raw_pointer_cast(d_c_out.data()), + alpha, + beta, + a_load_transform, + b_load_transform, + c_load_transform, + c_store_transform + ); + cudaError_t result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + cudaError_t error = cudaGetLastError(); + FAIL() << "Error at kernel sync: " << cudaGetErrorString(error) << "\n"; + } + + thrust::host_vector h_c_ref(h_c.size(), static_cast(0.0)); + auto h_c_ref_tensor = make_tensor(h_c_ref.data(), gmem_c_layout_t{}); + // A * B + for (int k = 0; k < size<1>(h_a_tensor); k++) { + for (int m = 0; m < size<0>(h_a_tensor); m++) { + for (int n = 0; n < size<0>(h_b_tensor); n++) { + const auto a_value = a_load_transform(h_a_tensor(m, k)); + const auto b_value = b_load_transform(h_b_tensor(n, k)); + const auto a_value_fp64 = static_cast(a_value); + const auto b_value_fp64 = static_cast(b_value); + h_c_ref_tensor(m, n) += static_cast(a_value_fp64 * b_value_fp64); + } + } + } + // C = A*B + C + for (int i = 0; i < size(h_c_ref_tensor); i++) { + const auto ab_value_fp64 = static_cast(h_c_ref_tensor(i)); + const auto c_value_fp64 = static_cast(c_load_transform(h_c_tensor(i))); + h_c_ref_tensor(i) = c_store_transform(static_cast(alpha * ab_value_fp64 + beta * c_value_fp64)); + } + + h_c_out = d_c_out; + auto h_c_out_tensor = make_tensor(h_c_out.data(), gmem_c_layout_t{}); + for (int i = 0; i < size(h_c_ref_tensor); i++) { + double h_c_ref_i = h_c_ref_tensor(i); + double h_c_out_i = h_c_out_tensor(i); + double epsilon(0.1f); + double nonzero_floor(std::numeric_limits::min()); + bool passed = cutlass::relatively_equal(h_c_out_i, h_c_ref_i, epsilon, 
nonzero_floor); + ASSERT_TRUE(passed) << i << " - result:" << h_c_out_i << " expected:" << h_c_ref_i; + } +} + +template +void test_cooperative_gemm_col_major_layout(ALoadTransform const& a_load_transform = {}, + BLoadTransform const& b_load_transform = {}, + CLoadTransform const& c_load_transform = {}, + CStoreTransform const& c_store_transform = {}) +{ + using gmem_a_layout_t = decltype(make_layout(make_shape(Int {}, Int {}))); + using gmem_b_layout_t = decltype(make_layout(make_shape(Int {}, Int {}), GenRowMajor{})); + using gmem_c_layout_t = decltype(make_layout(make_shape(Int {}, Int {}))); + + using smem_a_layout_t = decltype(make_layout(make_shape(Int {}, Int {}))); + using smem_b_layout_t = decltype(make_layout(make_shape(Int {}, Int {}), GenRowMajor{})); + using smem_c_layout_t = decltype(make_layout(make_shape(Int {}, Int {}))); + + test_cooperative_gemm>, + AutoVectorizingCopyWithAssumedAlignment>, + AutoVectorizingCopyWithAssumedAlignment>, + ThreadBlockSize, + TiledMMAType, + CopyMaxVecBits, + TA, + TB, + TC>(a_load_transform, b_load_transform, c_load_transform, c_store_transform); +} + +template +void test_cooperative_gemm_col_major_layout(ALoadTransform const& a_load_transform = {}, + BLoadTransform const& b_load_transform = {}, + CLoadTransform const& c_load_transform = {}, + CStoreTransform const& c_store_transform = {}) +{ + test_cooperative_gemm_col_major_layout, T, T, T>( + a_load_transform, b_load_transform, c_load_transform, c_store_transform); +} + +template +void test_cooperative_gemm_col_major_layout(ALoadTransform const& a_load_transform = {}, + BLoadTransform const& b_load_transform = {}, + CLoadTransform const& c_load_transform = {}, + CStoreTransform const& c_store_transform = {}) +{ + using gmem_a_layout_t = decltype(make_layout(make_shape(Int {}, Int {}))); + using gmem_b_layout_t = decltype(make_layout(make_shape(Int {}, Int {}), GenRowMajor{})); + using gmem_c_layout_t = decltype(make_layout(make_shape(Int {}, Int {}))); + + using smem_a_atom_layout_t = SMemAAtomLayout; + using smem_a_layout_t = decltype(tile_to_shape( + smem_a_atom_layout_t{}, + make_shape(shape<0>(gmem_a_layout_t{}), shape<1>(gmem_a_layout_t{}))) + ); + + using smem_b_atom_layout_t = SMemBAtomLayout; + using smem_b_layout_t = decltype(tile_to_shape( + smem_b_atom_layout_t{}, + make_shape(shape<0>(gmem_b_layout_t{}), shape<1>(gmem_b_layout_t{}))) + ); + + using smem_c_atom_layout_t = SMemCAtomLayout; + using smem_c_layout_t = decltype(tile_to_shape( + smem_c_atom_layout_t{}, + make_shape(shape<0>(gmem_c_layout_t{}), shape<1>(gmem_c_layout_t{}))) + ); + + test_cooperative_gemm>, + AutoVectorizingCopyWithAssumedAlignment>, + AutoVectorizingCopyWithAssumedAlignment>, + ThreadBlockSize, + TiledMMAType, + CopyMaxVecBits, + TA, + TB, + TC>(a_load_transform, b_load_transform, c_load_transform, c_store_transform); +} + +template +void test_cooperative_gemm_col_major_layout(ALoadTransform const& a_load_transform = {}, + BLoadTransform const& b_load_transform = {}, + CLoadTransform const& c_load_transform = {}, + CStoreTransform const& c_store_transform = {}) +{ + test_cooperative_gemm_col_major_layout, + T, + T, + T>(a_load_transform, b_load_transform, c_load_transform, c_store_transform); +} diff --git a/test/unit/cute/core/complement.cpp b/test/unit/cute/core/complement.cpp index 460fdedef8..cba486f69d 100644 --- a/test/unit/cute/core/complement.cpp +++ b/test/unit/cute/core/complement.cpp @@ -35,22 +35,22 @@ #include -template +template void -test_complement(Layout const& layout, CoSizeHi const& 
cosize_hi) +test_complement(Layout const& layout, CoTarget const& cotarget) { using namespace cute; - auto result = complement(layout, cosize_hi); + auto result = complement(layout, cotarget); - CUTLASS_TRACE_HOST("complement(" << layout << ", " << cosize_hi << ") => " << result); + CUTLASS_TRACE_HOST("complement(" << layout << ", " << cotarget << ") => " << result); auto completed = make_layout(layout, result); // Lower-bound on the codomain size of the layout ++ complement (1) - EXPECT_GE(cosize(completed), cosize_hi); + EXPECT_GE(cosize(completed), size(cotarget)); // Upper-bound on the codomain size of the complement (2) - EXPECT_LE(cosize(result), cute::round_up(cosize_hi, cosize(layout))); + EXPECT_LE(cosize(result), cute::round_up(size(cotarget), cosize(layout))); // Post-condition on the codomain of the complement for (int i = 1; i < size(result); ++i) { @@ -62,9 +62,9 @@ test_complement(Layout const& layout, CoSizeHi const& cosize_hi) // Other observations EXPECT_LE(size(result), cosize(result)); // As a result of the ordered condition (3) - EXPECT_GE(size(result), cosize_hi / size(filter(layout))); + EXPECT_GE(size(result), size(cotarget) / size(filter(layout))); EXPECT_LE(cosize(completed), cosize(result) + cosize(layout)); - EXPECT_GE(cosize(result), cosize_hi / size(filter(layout))); + EXPECT_GE(cosize(result), size(cotarget) / size(filter(layout))); if constexpr (is_static::value) { // If we can apply complement again EXPECT_EQ(size(complement(completed)), 1); // There's no more codomain left over } @@ -90,6 +90,8 @@ TEST(CuTe_core, Complement) test_complement(layout); test_complement(layout, Int<2>{}); + test_complement(layout, Int<5>{}); + test_complement(layout, make_shape(Int<2>{}, 2)); } { @@ -97,6 +99,8 @@ TEST(CuTe_core, Complement) test_complement(layout); test_complement(layout, Int<2>{}); + test_complement(layout, Int<5>{}); + test_complement(layout, make_shape(Int<2>{}, 2)); } { @@ -105,6 +109,8 @@ TEST(CuTe_core, Complement) test_complement(layout, Int<1>{}); test_complement(layout, Int<2>{}); test_complement(layout, Int<8>{}); + test_complement(layout, Int<5>{}); + test_complement(layout, make_shape(Int<2>{}, 2)); } { @@ -130,6 +136,7 @@ TEST(CuTe_core, Complement) test_complement(layout); test_complement(layout, Int<16>{}); test_complement(layout, Int<19>{}); + test_complement(layout, make_shape(Int<2>{}, 2)); } { @@ -138,6 +145,7 @@ TEST(CuTe_core, Complement) test_complement(layout, Int<1>{}); test_complement(layout); test_complement(layout, Int<17>{}); + test_complement(layout, make_shape(Int<2>{}, 2)); } { @@ -193,8 +201,8 @@ TEST(CuTe_core, Complement) // Fails due to non-injective layout // { - // auto layout = make_layout(Shape,Shape<_2, _2>>{}, - // Stride,Stride<_8,_4>>{}); + // auto layout = make_layout(Shape ,Shape <_2,_2>>{}, + // Stride,Stride<_8,_4>>{}); // test_complement(layout); // } @@ -289,4 +297,11 @@ TEST(CuTe_core, Complement) test_complement(layout); } + + { + auto layout = make_layout(Int<64>{}); + + test_complement(layout, make_shape(Int<32>{}, Int<4>{}, Int<4>{})); + test_complement(layout, make_shape(Int<32>{}, Int<4>{}, 4)); + } } diff --git a/test/unit/cute/core/composition.cpp b/test/unit/cute/core/composition.cpp index 8f50ba5e8f..8e043f89e7 100644 --- a/test/unit/cute/core/composition.cpp +++ b/test/unit/cute/core/composition.cpp @@ -212,13 +212,12 @@ TEST(CuTe_core, Composition) test_composition(a, b); } - // FAILS due to b not "dividing into" a properly - //{ - // auto a = make_layout(Shape<_4,_3>{}); - // auto b = 
make_layout(Shape<_6>{}); + { + auto a = make_layout(Shape<_4,_3>{}); + auto b = make_layout(Shape<_6>{}); - // test_composition(a, b); - //} + test_composition(a, b); + } { auto a = make_layout(Shape<_4,_3>{}); @@ -234,13 +233,12 @@ TEST(CuTe_core, Composition) test_composition(a, b); } - // FAILS due to b not "dividing into" a properly - //{ - // auto a = make_layout(Shape<_4,_3>{}); - // auto b = make_layout(Shape<_4,_3>{}, Stride<_3,_1>{}); + { + auto a = make_layout(Shape<_4,_3>{}); + auto b = make_layout(Shape<_4,_3>{}, Stride<_3,_1>{}); - // test_composition(a, b); - //} + test_composition(a, b); + } { auto a = make_layout(Shape<_4,_3>{}, Stride<_3,_1>{}); @@ -523,4 +521,21 @@ TEST(CuTe_core, Composition) test_composition(a, b); } + CUTLASS_TRACE_HOST("-------------------------------"); + CUTLASS_TRACE_HOST("BETA: Tuple strides" ); + CUTLASS_TRACE_HOST("-------------------------------"); + + { + auto a = make_layout(Shape<_4,_4>{}, Stride<_4,_1>{}); + auto b = make_layout(Shape<_4,_4>{}, Stride,E<0>>{}); + + test_composition(a, b); + } + + { + auto a = make_layout(Shape<_4,Shape<_2,_3>>{}, Stride<_6,Stride<_3,_1>>{}); + auto b = make_layout(Shape<_2,_4>{}, Stride,E<0>>{}); + + test_composition(a, b); + } } diff --git a/test/unit/cute/core/logical_divide.cpp b/test/unit/cute/core/logical_divide.cpp index 6c1e2f291e..061fd5487b 100644 --- a/test/unit/cute/core/logical_divide.cpp +++ b/test/unit/cute/core/logical_divide.cpp @@ -227,27 +227,42 @@ TEST(CuTe_core, Logical_divide) ASSERT_TRUE(decltype(stride<1>(result) == Int<48>{})::value); } - // DISALLOWED - //{ - //auto layout = make_layout(make_shape(128,4,3), make_stride(1,512,0)); - //auto tile = Layout<_32>{}; + { + auto layout = make_layout(make_shape(Int<32>{}, Int<4>{}, 4)); + auto tile = Layout<_64>{}; + + test_logical_divide(layout, tile); + + // Enforcement of result + auto result = logical_divide(layout, tile); + ASSERT_TRUE(bool( shape(result) == make_shape (_64{}, make_shape ( _2{}, 4)))); + ASSERT_TRUE(bool(stride(result) == make_stride( _1{}, make_stride(_64{},_128{})))); + } - //test_logical_divide(layout, tile); - //} - //{ - //auto layout = make_layout(make_shape(128,4,3), make_stride(1,512,0)); - //auto tile = Layout<_32,_2>{}; + // + // ALLOWED, but dangerous due to the dynamic lhs shapes + // Consider disallowing... 
+ // - //CUTLASS_TRACE_HOST("complement: " << complement(tile, size(layout))); - //test_logical_divide(layout, tile); - //} + { + auto layout = make_layout(make_shape(128,4,3), make_stride(1,512,0)); + auto tile = Layout<_32>{}; - //{ - //auto layout = make_layout(make_shape(16,4,3), make_stride(1,512,0)); - //auto tile = Layout<_32>{}; + test_logical_divide(layout, tile); + } - //CUTLASS_TRACE_HOST("complement: " << complement(tile, size(layout))); - //test_logical_divide(layout, tile); - //} + { + auto layout = make_layout(make_shape(128,4,3), make_stride(1,512,0)); + auto tile = Layout<_32,_2>{}; + + test_logical_divide(layout, tile); + } + + { + auto layout = make_layout(make_shape(16,4,3), make_stride(1,512,0)); + auto tile = Layout<_32>{}; + + test_logical_divide(layout, tile); + } } diff --git a/test/unit/cute/hopper/CMakeLists.txt b/test/unit/cute/hopper/CMakeLists.txt index f05d86b9c3..0b6db66f22 100644 --- a/test/unit/cute/hopper/CMakeLists.txt +++ b/test/unit/cute/hopper/CMakeLists.txt @@ -56,6 +56,11 @@ cutlass_test_unit_add_executable( tma_load.cu ) +cutlass_test_unit_add_executable( + cutlass_test_unit_cute_hopper_tma_mcast_load + tma_mcast_load.cu +) + cutlass_test_unit_add_executable( cutlass_test_unit_cute_hopper_tma_store tma_store.cu diff --git a/test/unit/cute/hopper/tma_load.cu b/test/unit/cute/hopper/tma_load.cu index d171d36e61..0105d35144 100644 --- a/test/unit/cute/hopper/tma_load.cu +++ b/test/unit/cute/hopper/tma_load.cu @@ -44,7 +44,6 @@ test_tma_load(GMEM_Layout const& gmem_layout, SMEM_Layout const& smem_layout, CTA_Tile const& cta_tile) { - using namespace cute; return test_tma_load(SM90_TMA_LOAD{}, gmem_layout, smem_layout, cta_tile); } @@ -53,7 +52,6 @@ auto test_tma_load(GMEM_Layout const& gmem_layout, SMEM_Layout const& smem_layout) { - using namespace cute; return test_tma_load(gmem_layout, smem_layout, product_each(shape(smem_layout))); } diff --git a/test/unit/cute/hopper/tma_mcast_load.cu b/test/unit/cute/hopper/tma_mcast_load.cu new file mode 100644 index 0000000000..9a330716c4 --- /dev/null +++ b/test/unit/cute/hopper/tma_mcast_load.cu @@ -0,0 +1,76 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#include "cutlass_unit_test.h" + +#include "../hopper/tma_mcast_load_testbed.hpp" + +using namespace cute; +using namespace cutlass::test; + +#if CUDA_12_0_SM90_FEATURES_SUPPORTED + +template > +auto +test_tma_load(GMEM_Layout const& gmem_layout, + SMEM_Layout const& smem_layout, + CTA_Tile const& cta_tile, + Cluster_Size const& cluster_size = {}) +{ + return test_tma_load(SM90_TMA_LOAD_MULTICAST{}, gmem_layout, smem_layout, cta_tile, cluster_size); +} + +template +auto +test_tma_load(GMEM_Layout const& gmem_layout, + SMEM_Layout const& smem_layout) +{ + return test_tma_load(gmem_layout, smem_layout, product_each(shape(smem_layout))); +} + +TEST(SM90_CuTe_Hopper, Tma_Load_32x32_Col_MCast) +{ + Layout smem_layout = Layout, Stride<_1,_32>>{}; + { + Layout gmem_layout = make_layout(make_shape(32,32), GenColMajor{}); + test_tma_load(gmem_layout, smem_layout, shape(smem_layout), Int<2>{}); + test_tma_load(gmem_layout, smem_layout, shape(smem_layout), Int<2>{}); + test_tma_load< float>(gmem_layout, smem_layout, shape(smem_layout), Int<2>{}); + test_tma_load(gmem_layout, smem_layout, shape(smem_layout), Int<2>{}); + + test_tma_load(gmem_layout, smem_layout, shape(smem_layout), 2); + test_tma_load(gmem_layout, smem_layout, shape(smem_layout), 2); + test_tma_load< float>(gmem_layout, smem_layout, shape(smem_layout), 2); + test_tma_load(gmem_layout, smem_layout, shape(smem_layout), 2); + } +} + +#endif diff --git a/test/unit/cute/hopper/tma_mcast_load_testbed.hpp b/test/unit/cute/hopper/tma_mcast_load_testbed.hpp new file mode 100644 index 0000000000..2fb88de50d --- /dev/null +++ b/test/unit/cute/hopper/tma_mcast_load_testbed.hpp @@ -0,0 +1,242 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass_unit_test.h" + +#include +#include + +#include +#include + +#include +#include +#include + +namespace cutlass::test { + +template +struct SharedStorage +{ + cute::ArrayEngine> smem; + alignas(16) cute::uint64_t tma_load_mbar[1]; +}; + +#if CUDA_12_0_SM90_FEATURES_SUPPORTED + +template +__global__ void +tma_test_device_cute(T const* g_in, T* g_out, GmemLayout gmem_layout, SmemLayout smem_layout, + CUTE_GRID_CONSTANT CopyAtom const tma, CTA_Tiler cta_tiler, Cluster_Size cluster_size) +{ + using namespace cute; + CUTE_STATIC_ASSERT_V(product_each(shape(cta_tiler)) == product_each(shape(smem_layout))); + + // Use Shared Storage structure to allocate and distribute aligned SMEM addresses + extern __shared__ char shared_memory[]; + using SharedStorage = SharedStorage; + SharedStorage& shared_storage = *reinterpret_cast(shared_memory); + + // Construct SMEM tensor + Tensor sA = make_tensor(make_smem_ptr(shared_storage.smem.begin()), smem_layout); // (CTA_TILE_M,CTA_TILE_N,...) + // Shared memory barriers use 64bits in SMEM for synchronization + uint64_t* tma_load_mbar = shared_storage.tma_load_mbar; + + // TMA requires special handling of strides to deal with coord codomain mapping + // Represent the full tensors -- get these from TMA + Tensor mA = tma.get_tma_tensor(shape(gmem_layout)); + Tensor mB = make_tensor(make_gmem_ptr(g_out), gmem_layout); + + Tensor gA = zipped_divide(mA, cta_tiler); // ((CTA_TILE_M,CTA_TILE_N,...),(REST_M,REST_N,...)) + Tensor gB = zipped_divide(mB, cta_tiler); // ((CTA_TILE_M,CTA_TILE_N,...),(REST_M,REST_N,...)) + +#if 1 + if (thread0()) { + print(tma); + print("TILE : "); print(cta_tiler); print("\n"); + print(" mA : "); print( mA); print("\n"); + print(" mB : "); print( mB); print("\n"); + print(" gA : "); print( gA); print("\n"); + print(" gB : "); print( gB); print("\n"); + print(" sA : "); print( sA); print("\n"); + } __syncthreads(); cute::cluster_sync(); +#endif + + // + // Prepare the TMA_LOAD + // + + Tensor sA_x = make_tensor(sA.data(), make_layout(sA.layout(), Layout<_1>{})); // ((CTA_TILE_M,CTA_TILE_N,...),_1) + Tensor tBgB = gB; // ((CTA_TILE_M,CTA_TILE_N,...),(REST_M,REST_N,...)) + + int cta_rank_in_cluster = cute::block_rank_in_cluster(); + auto [tAgA, tAsA] = tma_partition(tma, cta_rank_in_cluster, make_layout(cluster_size), sA_x, gA); + +#if 1 + if (thread0()) { + print("sA_x : "); print(sA_x); print("\n"); + print("tBgB : "); print(tBgB); print("\n"); + print("tAgA : "); print(tAgA); print("\n"); + print("tAsA : "); print(tAsA); print("\n"); + } __syncthreads(); cute::cluster_sync(); +#endif + + // + // TMA Multicast Masks -- Get a mask of the active ctas in each TMA + // + + + int elected_cta_rank = 0; + bool elect_one_cta = (elected_cta_rank == cta_rank_in_cluster); + bool elect_one_thr = cute::elect_one_sync(); + + uint16_t tma_mcast_mask = ((uint16_t(1) << cluster_size) - 1); + 
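  // [Editorial aside, not part of the patch] The statement above builds the TMA
  // multicast destination mask: one bit per CTA rank in the cluster, with the low
  // `cluster_size` bits set. For example, with a cluster of 4 CTAs,
  //   ((uint16_t(1) << 4) - 1) == 0b1111   // ranks 0..3 all receive the load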
+#if 1 + if (thread0()) { + print("tma_mcast_mask : "); print(tma_mcast_mask); print("\n"); + } __syncthreads(); cute::cluster_sync(); +#endif + + // + // Perform the TMA_LOAD + // + + if (elect_one_thr) { + // Initialize TMA barrier + cute::initialize_barrier(tma_load_mbar[0], /* num_threads */ 1); + } + int tma_phase_bit = 0; + // Ensures all CTAs in the Cluster have initialized + __syncthreads(); + cute::cluster_sync(); + + // Loop over the TMA stages, using smem as our buffer + for (int stage = 0; stage < size<1>(tAgA); ++stage) + { + // Set the bytes transferred in this TMA transaction (may involve multiple issues) + constexpr int kTmaTransactionBytes = sizeof(ArrayEngine); + + if (elect_one_thr) + { + cute::set_barrier_transaction_bytes(tma_load_mbar[0], kTmaTransactionBytes); + + copy(tma.with(tma_load_mbar[0], tma_mcast_mask), tAgA(_,stage), tAsA(_,0)); + } + __syncthreads(); + + /// Wait on the shared memory barrier until the phase bit flips from tma_phase_bit value + cute::wait_barrier(tma_load_mbar[0], tma_phase_bit); + tma_phase_bit ^= 1; + + // + // Write out trivially smem -> gmem + // + + // Subbyte elements could cause race conditions, so be even more conservative + if (elect_one_cta && elect_one_thr) { + copy(sA, tBgB(_,stage)); + } + + __syncthreads(); + cute::cluster_sync(); + } +} + +template +auto +test_tma_load(CopyOp const& copy_op, + GMEM_Layout const& gmem_layout, + SMEM_Layout const& smem_layout, + CTA_Tiler const& cta_tiler, + Cluster_Size const& cluster_size) +{ + using namespace cute; + + // Allocate and initialize host test data + size_t N = ceil_div(cosize(gmem_layout) * sizeof_bits::value, 8); + thrust::host_vector h_in(N); + for (size_t i = 0; i < h_in.size(); ++i) { + h_in[i] = uint8_t(i % 13); + } + Tensor hA_in = make_tensor(recast_ptr(h_in.data()), gmem_layout); + + // Allocate and initialize device test data + thrust::device_vector d_in = h_in; + thrust::device_vector d_out(h_in.size(), uint8_t(-1)); // overflow uint + + // Create TMA for this device Tensor + Tensor gA = make_tensor(make_gmem_ptr(raw_pointer_cast(d_in.data())), gmem_layout); + auto tma = make_tma_atom(copy_op, gA, smem_layout, cta_tiler, cluster_size); + //print(tma); + + // Launch + + dim3 dimBlock(32); + dim3 dimCluster(size(cluster_size)); + dim3 dimGrid = dimCluster; + int smem_size = sizeof(SharedStorage); + + void* kernel_ptr = (void*) &tma_test_device_cute; + + cutlass::launch_kernel_on_cluster({dimGrid, dimBlock, dimCluster, smem_size}, + kernel_ptr, + reinterpret_cast(raw_pointer_cast(d_in.data())), + reinterpret_cast(raw_pointer_cast(d_out.data())), + gmem_layout, + smem_layout, + tma, cta_tiler, cluster_size); + + // Copy results back to host + thrust::host_vector h_out = d_out; + Tensor hA_out = make_tensor(recast_ptr(h_out.data()), gmem_layout); + + // Validate the results. Print only the first 3 errors. + int count = 3; + for (int i = 0; i < int(size(hA_out)) && count > 0; ++i) { + EXPECT_EQ(hA_in(i), hA_out(i)); + if (hA_in(i) != hA_out(i)) { + --count; + } + } + + return tma; +} + +#endif + +} // end namespace cutlass::test diff --git a/test/unit/cute/turing/CMakeLists.txt b/test/unit/cute/turing/CMakeLists.txt new file mode 100644 index 0000000000..ac8a0487b3 --- /dev/null +++ b/test/unit/cute/turing/CMakeLists.txt @@ -0,0 +1,32 @@ +# Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +cutlass_test_unit_add_executable( + cutlass_test_unit_cute_turing + cooperative_gemm.cu +) diff --git a/test/unit/cute/turing/cooperative_gemm.cu b/test/unit/cute/turing/cooperative_gemm.cu new file mode 100644 index 0000000000..14ea967074 --- /dev/null +++ b/test/unit/cute/turing/cooperative_gemm.cu @@ -0,0 +1,58 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#include "cutlass_unit_test.h" + +#include + +#include "../cooperative_gemm_common.hpp" + +using namespace cute; + +TEST(SM75_CuTe_Turing, CooperativeGemm1_MixedPrecisionFP16FP32_MMA) { + using TA = cutlass::half_t; + using TB = cutlass::half_t; + using TC = float; + + constexpr uint32_t m = 64; + constexpr uint32_t n = 64; + constexpr uint32_t k = 64; + + constexpr uint32_t thread_block_size = 128; + + using tiled_mma_t = + TiledMMA< + MMA_Atom, + Layout> + >; + + test_cooperative_gemm_col_major_layout(); +} diff --git a/test/unit/cute/volta/CMakeLists.txt b/test/unit/cute/volta/CMakeLists.txt index 0777f5bfaf..d6688aa30d 100644 --- a/test/unit/cute/volta/CMakeLists.txt +++ b/test/unit/cute/volta/CMakeLists.txt @@ -30,4 +30,5 @@ cutlass_test_unit_add_executable( cutlass_test_unit_cute_volta vectorization_auto.cu cooperative_copy.cu + cooperative_gemm.cu ) diff --git a/test/unit/cute/volta/cooperative_copy.cu b/test/unit/cute/volta/cooperative_copy.cu index c1ffe34c1b..2fc80b366a 100644 --- a/test/unit/cute/volta/cooperative_copy.cu +++ b/test/unit/cute/volta/cooperative_copy.cu @@ -263,6 +263,21 @@ TYPED_TEST(SM70_CuTe_Volta, CooperativeCopyDefault1D) value_type>(); } +TYPED_TEST(SM70_CuTe_Volta, CooperativeCopyDefault1DFallback) +{ + using value_type = float; + constexpr uint32_t count = 99; + using gmem_layout_t = decltype(make_layout(make_shape(Int{}))); + using smem_layout_t = decltype(make_layout(make_shape(Int{}))); + constexpr uint32_t thread_block_size = 128; + test_cooperative_copy_default(); +} + TYPED_TEST(SM70_CuTe_Volta, CooperativeCopyDefaultGSSG2D) { using value_type = float; @@ -279,6 +294,22 @@ TYPED_TEST(SM70_CuTe_Volta, CooperativeCopyDefaultGSSG2D) value_type>(); } +TYPED_TEST(SM70_CuTe_Volta, CooperativeCopyDefaultGSSG2DFallback) +{ + using value_type = float; + constexpr uint32_t x = 37; + constexpr uint32_t y = 37; + using gmem_layout_t = decltype(make_layout(make_shape(Int{}, Int{}))); + using smem_layout_t = decltype(make_layout(make_shape(Int{}, Int{}))); + constexpr uint32_t thread_block_size = 64; + test_cooperative_copy_default(); +} + TYPED_TEST(SM70_CuTe_Volta, CooperativeCopyDefaultGSSG2DCustomStride) { using value_type = float; @@ -312,6 +343,23 @@ TYPED_TEST(SM70_CuTe_Volta, CooperativeCopyDefaultGSSG3D) value_type>(); } +TYPED_TEST(SM70_CuTe_Volta, CooperativeCopyDefaultGSSG3DFallback) +{ + using value_type = cute::half_t; + constexpr uint32_t x = 44; + constexpr uint32_t y = 24; + constexpr uint32_t z = 14; + using gmem_layout_t = decltype(make_layout(make_shape(Int{}, Int{}, Int{}))); + using smem_layout_t = decltype(make_layout(make_shape(Int{}, Int{}, Int{}))); + constexpr uint32_t thread_block_size = 128; + test_cooperative_copy_default(); +} + TYPED_TEST(SM70_CuTe_Volta, CooperativeCopyDefaultGSSG2Dto3D) { using value_type = double; diff --git a/test/unit/cute/volta/cooperative_gemm.cu 
b/test/unit/cute/volta/cooperative_gemm.cu new file mode 100644 index 0000000000..e8deb8b611 --- /dev/null +++ b/test/unit/cute/volta/cooperative_gemm.cu @@ -0,0 +1,421 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#include "cutlass_unit_test.h" + +#include + +#include "../cooperative_gemm_common.hpp" + +using namespace cute; + +TEST(SM70_CuTe_Volta, CooperativeGemm1_FloatFMA) { + using value_type = float; + + constexpr uint32_t m = 64; + constexpr uint32_t n = 32; + constexpr uint32_t k = 16; + + constexpr uint32_t thread_block_size = 128; + + using tiled_mma_t = + TiledMMA< + MMA_Atom>, + Layout> + >; + + test_cooperative_gemm_col_major_layout(); +} + +TEST(SM70_CuTe_Volta, CooperativeGemm1_FloatFMA_Predication) { + using value_type = float; + + constexpr uint32_t m = 88; + constexpr uint32_t n = 20; + constexpr uint32_t k = 12; + + constexpr uint32_t thread_block_size = 128; + + using tiled_mma_t = + TiledMMA< + MMA_Atom>, + Layout> + >; + + test_cooperative_gemm_col_major_layout(); +} + +TEST(SM70_CuTe_Volta, CooperativeGemm1_FloatFMA_Predication2) { + using value_type = float; + + constexpr uint32_t m = 88; + constexpr uint32_t n = 36; + constexpr uint32_t k = 24; + + constexpr uint32_t thread_block_size = 128; + + using tiled_mma_t = + TiledMMA< + MMA_Atom>, + Layout> + >; + + test_cooperative_gemm_col_major_layout(); +} + +TEST(SM70_CuTe_Volta, CooperativeGemm1_FloatFMA_Predication3) { + using value_type = float; + + constexpr uint32_t m = 67; + constexpr uint32_t n = 13; + constexpr uint32_t k = 11; + + constexpr uint32_t thread_block_size = 128; + + using tiled_mma_t = + TiledMMA< + MMA_Atom>, + Layout> + >; + + test_cooperative_gemm_col_major_layout(); +} + +TEST(SM70_CuTe_Volta, CooperativeGemm2_DoubleFMA) { + using value_type = double; + + constexpr uint32_t m = 16; + constexpr uint32_t n = 32; + constexpr uint32_t k = 32; + + constexpr uint32_t thread_block_size = 128; + + using tiled_mma_t = + TiledMMA< + MMA_Atom>, + Layout> + >; + + test_cooperative_gemm_col_major_layout(); +} + +TEST(SM70_CuTe_Volta, CooperativeGemm3_Float_FMA_CustomPermutationMNK) { + using value_type = float; + + constexpr uint32_t m = 32; + constexpr uint32_t n = 32; + constexpr uint32_t k = 32; + + constexpr uint32_t thread_block_size = 256; + + using tiled_mma_t = TiledMMA< + MMA_Atom< + UniversalFMA + >, + Layout< + Shape<_16, _16, _1> + >, + Tile< + Layout< + Shape<_16,_2>, Stride<_2,_1> + >, // 32x32x1 MMA with perm for load vectorization + Layout< + Shape<_16,_2>, Stride<_2,_1> + >, + Underscore + > + >; + + test_cooperative_gemm_col_major_layout(); +} + +TEST(SM70_CuTe_Volta, CooperativeGemm4_Half_MMA) { + using value_type = cutlass::half_t; + + constexpr uint32_t m = 32; + constexpr uint32_t n = 32; + constexpr uint32_t k = 32; + + constexpr uint32_t thread_block_size = 128; + + using tiled_mma_t = TiledMMA< + MMA_Atom, + Layout> + >; + + using smem_a_atom_layout_t = typename tiled_mma_t::AtomLayoutB_TV; + using smem_b_atom_layout_t = typename tiled_mma_t::AtomLayoutA_TV; + using smem_c_atom_layout_t = decltype(make_layout(make_shape(Int {}, Int {}))); + + test_cooperative_gemm_col_major_layout(); +} + +TEST(SM70_CuTe_Volta, CooperativeGemm5_Half_MMA) { + using value_type = cutlass::half_t; + + constexpr uint32_t m = 32; + constexpr uint32_t n = 32; + constexpr uint32_t k = 32; + + constexpr uint32_t thread_block_size = 128; + + using tiled_mma_t = TiledMMA< + MMA_Atom, + Layout> + >; + + using gmem_a_layout_t = decltype(make_layout(make_shape(Int{}, Int{}))); + using gmem_b_layout_t = decltype(make_layout(make_shape(Int{}, Int{}), GenColMajor{})); + using gmem_c_layout_t = 
decltype(make_layout(make_shape(Int{}, Int{}))); + + using smem_a_layout_t = decltype(make_layout(make_shape(Int{}, Int{}))); + using smem_b_layout_t = decltype(make_layout(make_shape(Int{}, Int{}), GenColMajor{})); + using smem_c_layout_t = decltype(make_layout(make_shape(Int{}, Int{}))); + + test_cooperative_gemm, // A + AutoVectorizingCopyWithAssumedAlignment<128>, // B + AutoVectorizingCopyWithAssumedAlignment<128>, // C + thread_block_size, + tiled_mma_t, + 128, + value_type, + value_type, + value_type>(); +} + +TEST(SM70_CuTe_Volta, CooperativeGemm5_Half_MMA_Predicated) { + using value_type = cutlass::half_t; + + constexpr uint32_t m = 31; + constexpr uint32_t n = 27; + constexpr uint32_t k = 17; + + constexpr uint32_t thread_block_size = 128; + + using tiled_mma_t = TiledMMA< + MMA_Atom, + Layout> + >; + + using gmem_a_layout_t = decltype(make_layout(make_shape(Int{}, Int{}))); + using gmem_b_layout_t = decltype(make_layout(make_shape(Int{}, Int{}), GenColMajor{})); + using gmem_c_layout_t = decltype(make_layout(make_shape(Int{}, Int{}))); + + using smem_a_layout_t = decltype(make_layout(make_shape(Int{}, Int{}))); + using smem_b_layout_t = decltype(make_layout(make_shape(Int{}, Int{}), GenColMajor{})); + using smem_c_layout_t = decltype(make_layout(make_shape(Int{}, Int{}))); + + test_cooperative_gemm, // A + AutoVectorizingCopyWithAssumedAlignment<16>, // B + AutoVectorizingCopyWithAssumedAlignment<16>, // C + thread_block_size, + tiled_mma_t, + 16, + value_type, + value_type, + value_type>(); +} + +TEST(SM70_CuTe_Volta, CooperativeGemm6_Half_MAA_SwizzledSmemLayouts) { + using value_type = cutlass::half_t; + + constexpr uint32_t m = 128; + constexpr uint32_t n = 128; + constexpr uint32_t k = 64; + + constexpr uint32_t thread_block_size = 128; + + using tiled_mma_t = TiledMMA< + MMA_Atom, + Layout> + >; + + using smem_a_atom_layout_t = decltype( + composition(Swizzle<3,3,3>{}, + Layout, + Stride<_64, _1>>{})); + using smem_b_atom_layout_t = decltype( + composition(Swizzle<3,3,3>{}, + Layout, + Stride< _1,_64>>{})); + using smem_c_atom_layout_t = decltype(make_layout(make_shape(Int{}, Int{}), GenRowMajor{})); + + using gmem_a_layout_t = decltype(make_layout(make_shape(Int {}, Int {}), GenRowMajor{})); + using gmem_b_layout_t = decltype(make_layout(make_shape(Int {}, Int {}), GenColMajor{})); + using gmem_c_layout_t = decltype(make_layout(make_shape(Int {}, Int {}), GenRowMajor{})); + + using smem_a_atom_layout_t = smem_a_atom_layout_t; + using smem_a_layout_t = decltype(tile_to_shape( + smem_a_atom_layout_t{}, + make_shape(shape<0>(gmem_a_layout_t{}), shape<1>(gmem_a_layout_t{}))) + ); + + // Transposed + using smem_b_atom_layout_t = smem_b_atom_layout_t; + using smem_b_layout_t = decltype(tile_to_shape( + smem_b_atom_layout_t{}, + make_shape(shape<0>(gmem_b_layout_t{}), shape<1>(gmem_b_layout_t{}))) + ); + + using smem_c_atom_layout_t = smem_c_atom_layout_t; + using smem_c_layout_t = decltype(tile_to_shape( + smem_c_atom_layout_t{}, + make_shape(shape<0>(gmem_c_layout_t{}), shape<1>(gmem_c_layout_t{}))) + ); + + test_cooperative_gemm, // A + AutoVectorizingCopyWithAssumedAlignment<128>, // B + AutoVectorizingCopyWithAssumedAlignment<128>, // C + thread_block_size, + tiled_mma_t, + 128, + value_type, + value_type, + value_type>(); +} + +TEST(SM70_CuTe_Volta, CooperativeGemm7_TransformNegate_FMA) { + using TA = float; + using TB = float; + using TC = double; + + constexpr uint32_t m = 32; + constexpr uint32_t n = 32; + constexpr uint32_t k = 32; + + constexpr uint32_t 
thread_block_size = 128; + + using tiled_mma_t = TiledMMA< + MMA_Atom>, + Layout> + >; + + auto aload = cute::negate {}; + auto bload = cute::negate {}; + auto cload = cute::negate {}; + auto cstore = cute::negate {}; + + test_cooperative_gemm_col_major_layout( + aload, bload, cload, cstore); +} + +TEST(SM70_CuTe_Volta, CooperativeGemm7_TransformNegate_MMA) { + using value_type = cutlass::half_t; + + constexpr uint32_t m = 32; + constexpr uint32_t n = 32; + constexpr uint32_t k = 32; + + constexpr uint32_t thread_block_size = 128; + + using tiled_mma_t = TiledMMA< + MMA_Atom, + Layout> + >; + + auto aload = cute::negate {}; + auto bload = cute::negate {}; + auto cload = cute::negate {}; + auto cstore = cute::negate {}; + + test_cooperative_gemm_col_major_layout( + aload, bload, cload, cstore); +} + +template +struct increment_by_x { + ConstantType x; + + template + CUTE_HOST_DEVICE constexpr + T operator()(const T& arg) const { + return arg + x; + } +}; + +template +struct convert_to { + CUTE_HOST_DEVICE constexpr + To operator()(const From& arg) const { + return static_cast(arg); + } +}; + +TEST(SM70_CuTe_Volta, CooperativeGemm7_TransformCustomOp_FMA) { + using TA = float; + using TB = float; + using TC = double; + + constexpr uint32_t m = 32; + constexpr uint32_t n = 32; + constexpr uint32_t k = 32; + + constexpr uint32_t thread_block_size = 128; + + using tiled_mma_t = TiledMMA< + MMA_Atom>, + Layout> + >; + + auto aload = increment_by_x{1.111f}; + auto bload = convert_to {}; + auto cload = cute::negate {}; + auto cstore = cute::negate {}; + + test_cooperative_gemm_col_major_layout( + aload, bload, cload, cstore); +} diff --git a/test/unit/epilogue/thread/linear_combination_planar_complex.cu b/test/unit/epilogue/thread/linear_combination_planar_complex.cu index c950c12f7c..6cbc9589df 100644 --- a/test/unit/epilogue/thread/linear_combination_planar_complex.cu +++ b/test/unit/epilogue/thread/linear_combination_planar_complex.cu @@ -183,7 +183,7 @@ TEST(Epilogue_thread_linear_combination_planar_complex, f16_f32) { source.imag[i] = ElementOutput(((i * 5 + 2) % 9) - 4); } - cutlass::ArrayPlanarComplex destination = linear_combination_op(accum, source); + cutlass::ArrayPlanarComplex destination{ linear_combination_op(accum, source) }; // Verify each result for (int i = 0; i < kCount; ++i) { diff --git a/test/unit/epilogue/threadblock/testbed.h b/test/unit/epilogue/threadblock/testbed.h index eadda47006..b773d27cbf 100644 --- a/test/unit/epilogue/threadblock/testbed.h +++ b/test/unit/epilogue/threadblock/testbed.h @@ -42,6 +42,7 @@ #include "cutlass/half.h" #include "cutlass/complex.h" #include "cutlass/quaternion.h" +#include "cutlass/platform/platform.h" #include "cutlass/epilogue/thread/linear_combination.h" #include "cutlass/util/host_tensor.h" @@ -193,15 +194,15 @@ class EpilogueTestbed { cutlass::reference::host::TensorFillRandomUniform( accumulator_tensor.host_view(), seed, - 20, - -20, + 2, + -2, 0); cutlass::reference::host::TensorFillRandomUniform( source_tensor.host_view(), seed + 2018, - 20, - -20, + 2, + -2, 0); } @@ -300,7 +301,9 @@ class EpilogueTestbed { output_params.alpha * ElementCompute(accumulator_tensor.at(coord)) + output_params.beta * ElementCompute(source_tensor.at(coord)); - if (std::numeric_limits::is_integer + if ((cutlass::platform::is_same::value + || cutlass::platform::is_same::value + || std::numeric_limits::is_integer) && !std::numeric_limits::is_integer) { std::fesetround(FE_TONEAREST); expected = 
ElementOutput(std::nearbyint(float(cutlass::real(intermediate)))); diff --git a/test/unit/gemm/device/gemm_testbed_3x.hpp b/test/unit/gemm/device/gemm_testbed_3x.hpp index 17a7b3fc16..b0225b812e 100644 --- a/test/unit/gemm/device/gemm_testbed_3x.hpp +++ b/test/unit/gemm/device/gemm_testbed_3x.hpp @@ -1336,7 +1336,6 @@ struct TestbedImpl { { using namespace cute; auto problem_shape_MNKL = cute::append<4>(problem_size, 1); - auto mainloop_params = collective_mma_inputs.to_host_args(problem_size); auto epilogue_params = collective_epilogue.to_host_args(problem_size); diff --git a/test/unit/gemm/device/sm90_gemm_f32_f32_f32_tensor_op_f32.cu b/test/unit/gemm/device/sm90_gemm_f32_f32_f32_tensor_op_f32.cu index 9c07e72dc7..56b85846de 100644 --- a/test/unit/gemm/device/sm90_gemm_f32_f32_f32_tensor_op_f32.cu +++ b/test/unit/gemm/device/sm90_gemm_f32_f32_f32_tensor_op_f32.cu @@ -163,6 +163,50 @@ TEST(SM90_Device_Gemm_f32t_f32t_f32n_tensor_op_gmma_f32, 128x128x32_1x1x1_cooper EXPECT_TRUE(test::gemm::device::TestAll()); } +TEST(SM90_Device_Gemm_f32t_f32t_f32n_tensor_op_gmma_f32, 128x128x32_1x1x1_cooperative_narrow_wgmma) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_128,_128,_32>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + float, LayoutC, 4, + float, LayoutC, 4, + cutlass::epilogue::TmaWarpSpecializedCooperative + >::CollectiveOp; + + // Manually configure a half-tile wide MMA instruction + using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma< + cutlass::gemm::MainloopSm90TmaGmmaWarpSpecialized<5, Shape<_1,_1,_1>, cutlass::gemm::KernelTmaWarpSpecializedCooperative>, + Shape<_128,_128,_32>, + float, + cutlass::detail::TagToStrideA_t, + float, + cutlass::detail::TagToStrideB_t, + decltype(cute::make_tiled_mma(cute::SM90_64x64x8_F32TF32TF32_SS_TN{}, Layout>{})), + cute::SM90_TMA_LOAD, + cute::GMMA::Layout_K_SW128_Atom, + void, + cute::identity, + cute::SM90_TMA_LOAD, + cute::GMMA::Layout_K_SW128_Atom, + void, + cute::identity + >; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + /////////////////////////////////////////////////////////////////////////////// #endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) diff --git a/test/unit/gemm/device/sm90_gemm_f8_f8_f8_tensor_op_fp32.cu b/test/unit/gemm/device/sm90_gemm_f8_f8_f8_tensor_op_fp32.cu index 743266ac9f..9cf8f3126b 100644 --- a/test/unit/gemm/device/sm90_gemm_f8_f8_f8_tensor_op_fp32.cu +++ b/test/unit/gemm/device/sm90_gemm_f8_f8_f8_tensor_op_fp32.cu @@ -54,6 +54,7 @@ #if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) using namespace cute; + /////////////////////////////////////////////////////////////////////////////// //////////////////////////////// output: E4M3 ///////////////////////////////// /////////////////////////////////////////////////////////////////////////////// @@ -760,7 +761,8 @@ TEST(SM90_Device_Gemm_e4m3t_e4m3n_e5m2n_tensor_op_gmma_f32, 64x128x128_2x4x1_non EXPECT_TRUE(test::gemm::device::TestAllBiasElementwise()); } - +// Use Hopper FP8+AUX from 12.1 +#if (!((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ == 0))) 
/////////////////////////////////////////////////////////////////////////////// ///////////////////////// output: E4M3 + Aux Tensor /////////////////////////// @@ -808,6 +810,7 @@ TEST(SM90_Device_Gemm_e4m3t_e4m3n_e4m3n_tensor_op_gmma_f32, 64x128x128_aux_tenso using Gemm = cutlass::gemm::device::GemmUniversalAdapter; EXPECT_TRUE(test::gemm::device::TestAllBiasElementwise()); } +#endif /////////////////////////////////////////////////////////////////////////////// ////////////////////////////////// FP8 Accum ///////////////////////////////// @@ -990,6 +993,10 @@ TEST(SM90_Device_Gemm_e4m3t_e4m3n_e4m3n_tensor_op_gmma_f32, 64x128x128_bias_bf16 EXPECT_TRUE(test::gemm::device::TestAllBiasElementwise()); } + +// Use Hopper FP8+AUX from 12.1 +#if (!((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ == 0))) + /////////////////////////////////////////////////////////////////////////////// ///////////////////// output: E4M3 + Aux Tensor + Bias///////////////////////// /////////////////////////////////////////////////////////////////////////////// @@ -1142,6 +1149,8 @@ TEST(SM90_Device_Gemm_e4m3t_e5m2n_e4m3n_tensor_op_gmma_f32, 64x128x128_aux_tenso EXPECT_TRUE(test::gemm::device::TestAllBiasElementwise()); } +#endif + /////////////////////////////////////////////////////////////////////////////// //////////////////////////////// TMA epilogue ///////////////////////////////// /////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/nvrtc/thread/testbed.h b/test/unit/nvrtc/thread/testbed.h index 7b0b123625..6c59afeb76 100644 --- a/test/unit/nvrtc/thread/testbed.h +++ b/test/unit/nvrtc/thread/testbed.h @@ -275,7 +275,7 @@ struct Testbed { nvrtcAddNameExpression(program, gemm_kernel_instantiation.c_str()); const char *opts[] = {"--gpu-architecture=compute_75", - "--std=c++11", + "--std=c++17", "--include-path=/usr/local/cuda-10.1/include"}; result_nvrtc = nvrtcCompileProgram(program, 3, opts); diff --git a/tools/library/include/cutlass/library/operation_table.h b/tools/library/include/cutlass/library/operation_table.h index f327dbaeda..ee7b65fe10 100644 --- a/tools/library/include/cutlass/library/operation_table.h +++ b/tools/library/include/cutlass/library/operation_table.h @@ -109,7 +109,7 @@ struct GemmFunctionalKey { inline bool operator==(GemmFunctionalKey const &rhs) const { - return + return (provider == rhs.provider) && (gemm_kind == rhs.gemm_kind) && (element_compute == rhs.element_compute) && @@ -165,7 +165,7 @@ struct GemmFunctionalKeyHasher { inline static size_t rotl(size_t key, int shl) { - return (key << shl) | (key >> (sizeof(key)*8 - shl)); + return (key << shl) | (key >> (sizeof(key)*8u - static_cast(shl))); } inline @@ -173,8 +173,8 @@ struct GemmFunctionalKeyHasher { IntHash hash; return - rotl(hash(int(key.provider)), 1) ^ - rotl(hash(int(key.gemm_kind)), 2) ^ + rotl(hash(int(key.provider)), 1) ^ + rotl(hash(int(key.gemm_kind)), 2) ^ rotl(hash(int(key.element_compute)), 3) ^ rotl(hash(int(key.element_scalar)), 4) ^ rotl(hash(int(key.element_A)), 5) ^ @@ -207,7 +207,7 @@ struct GemmPreferenceKey { GemmPreferenceKey(int cc, int alignment): compute_capability(cc), alignment(alignment) { } bool operator<(GemmPreferenceKey const &rhs) const { - return (compute_capability < rhs.compute_capability) || + return (compute_capability < rhs.compute_capability) || ((compute_capability == rhs.compute_capability) && (alignment < rhs.alignment)); } @@ -288,9 +288,9 @@ struct ConvFunctionalKey { layout_C(layout_C), 
element_accumulator(element_accumulator), element_compute(element_compute) - { } + { } - inline + inline bool operator==(ConvFunctionalKey const &rhs) const { return (provider == rhs.provider) && @@ -305,7 +305,7 @@ struct ConvFunctionalKey { (element_compute == rhs.element_compute); } - inline + inline bool operator!=(ConvFunctionalKey const &rhs) const { return !(*this == rhs); } @@ -325,7 +325,7 @@ std::ostream& operator<< (std::ostream& out, const cutlass::library::ConvFunctio << "element_accumulator: " << to_string(key.element_accumulator) << std::endl << "element_compute: " << to_string(key.element_compute) << std::endl << "}"; - + return out; } @@ -335,14 +335,14 @@ struct ConvFunctionalKeyHasher { inline static size_t rotl(size_t key, int shl) { - return (key << shl) | (key >> (sizeof(key)*8 - shl)); + return (key << shl) | (key >> (sizeof(key)*8u - static_cast(shl))); } inline size_t operator()(ConvFunctionalKey const &key) const { IntHash hash; - return + return rotl(hash(int(key.provider)), 1) ^ rotl(hash(int(key.conv_kind)), 2) ^ rotl(hash(int(key.element_A)), 3) ^ @@ -370,11 +370,11 @@ struct ConvPreferenceKey { ConvPreferenceKey(): compute_capability(), iterator_algorithm() { } - ConvPreferenceKey(int cc, IteratorAlgorithmID iterator_algorithm): + ConvPreferenceKey(int cc, IteratorAlgorithmID iterator_algorithm): compute_capability(cc), iterator_algorithm(iterator_algorithm) { } bool operator<(ConvPreferenceKey const &rhs) const { - return (compute_capability < rhs.compute_capability) || + return (compute_capability < rhs.compute_capability) || ((compute_capability == rhs.compute_capability) && (iterator_algorithm < rhs.iterator_algorithm)); } @@ -433,9 +433,9 @@ struct ReductionFunctionalKey { element_compute(element_compute), reduce_math_op(reduce_math_op), epilogue_math_op(epilogue_math_op) - { } + { } - inline + inline bool operator==(ReductionFunctionalKey const &rhs) const { return (provider == rhs.provider) && @@ -447,7 +447,7 @@ struct ReductionFunctionalKey { (epilogue_math_op == rhs.epilogue_math_op); } - inline + inline bool operator!=(ReductionFunctionalKey const &rhs) const { return !(*this == rhs); } @@ -459,14 +459,14 @@ struct ReductionFunctionalKeyHasher { inline static size_t rotl(size_t key, int shl) { - return (key << shl) | (key >> (sizeof(key)*8 - shl)); + return (key << shl) | (key >> (sizeof(key)*8u - static_cast(shl))); } inline size_t operator()(ReductionFunctionalKey const &key) const { IntHash hash; - return + return rotl(hash(int(key.provider)), 1) ^ rotl(hash(int(key.element_workspace)), 2) ^ rotl(hash(int(key.element_accumulator)), 3) ^ @@ -505,19 +505,19 @@ using ReductionOperationFunctionalMap = std::unordered_map< class OperationTable { public: - /// Map of all operations of type kGemm + /// Map of all operations of type kGemm // provider (kCUTLASS) GemmOperationFunctionalMap gemm_operations; - /// Map of all operations of type kConv2d + /// Map of all operations of type kConv2d // provider (kCUTLASS, kReferenceHost, kReferenceDevice) ConvOperationFunctionalMap conv2d_operations; - /// Map of all operations of type kConv3d + /// Map of all operations of type kConv3d // provider (kCUTLASS, kReferenceHost, kReferenceDevice) ConvOperationFunctionalMap conv3d_operations; - /// Map of all operations of type kConv2d + /// Map of all operations of type kConv2d // provider (kCUTLASS) ReductionOperationFunctionalMap reduction_operations; diff --git a/tools/library/src/gemm_operation_3x.hpp b/tools/library/src/gemm_operation_3x.hpp index 
b0c241e7fd..e50f3a1bc8 100644
--- a/tools/library/src/gemm_operation_3x.hpp
+++ b/tools/library/src/gemm_operation_3x.hpp
@@ -38,6 +38,7 @@
 #include "cutlass/library/library.h"
 #include "library_internal.h"
 #include "cutlass/gemm/dispatch_policy.hpp"
+#include

 ///////////////////////////////////////////////////////////////////////////////////////////////////

@@ -271,7 +272,6 @@ class GemmUniversal3xOperation : public GemmOperation3xBase {
   /// Returns success if the operation can proceed
   Status can_implement(
     void const *configuration_ptr, void const *arguments_ptr) const override {
-
     GemmUniversalConfiguration const *configuration =
      static_cast<GemmUniversalConfiguration const *>(configuration_ptr);
     GemmUniversalArguments const *arguments =
@@ -289,7 +289,6 @@ class GemmUniversal3xOperation : public GemmOperation3xBase {
       configuration->problem_size.n(),
       configuration->problem_size.k(),
       configuration->batch_count);
-
     return Operator::can_implement(args);
   }
diff --git a/tools/library/src/library_internal.h b/tools/library/src/library_internal.h
index cd5887c324..2b57dbc317 100644
--- a/tools/library/src/library_internal.h
+++ b/tools/library/src/library_internal.h
@@ -152,6 +152,7 @@ template <> struct NumericTypeMap {
   static NumericTypeID const kId = NumericTypeID::kTF32;
 };
+
 /////////////////////////////////////////////////////////////////////////////////////////////////

 template struct MathOperationMap {
diff --git a/tools/library/src/util.cu b/tools/library/src/util.cu
index 0b37f6d58f..927806d2b1 100644
--- a/tools/library/src/util.cu
+++ b/tools/library/src/util.cu
@@ -422,6 +422,8 @@ Status from_string(std::string const &str) {

 ///////////////////////////////////////////////////////////////////////////////////////////////////

+///////////////////////////////////////////////////////////////////////////////////////////////////
+
 static struct {
   char const *text;
   char const *pretty;
diff --git a/tools/profiler/include/cutlass/profiler/gemm_operation_profiler.h b/tools/profiler/include/cutlass/profiler/gemm_operation_profiler.h
index 90e86218be..7408701801 100644
--- a/tools/profiler/include/cutlass/profiler/gemm_operation_profiler.h
+++ b/tools/profiler/include/cutlass/profiler/gemm_operation_profiler.h
@@ -238,7 +238,9 @@ class GemmOperationProfiler : public OperationProfiler {
     DeviceContext &device_context,
     library::Operation const *operation,
     ProblemSpace const &problem_space,
-    ProblemSpace::Problem const &problem);
+    ProblemSpace::Problem const &problem,
+    cutlass::library::NumericTypeID element_A,
+    cutlass::library::NumericTypeID element_B);

   /// Method to profile a CUTLASS Operation
   Status profile_cutlass_(
diff --git a/tools/profiler/src/gemm_operation_profiler.cu b/tools/profiler/src/gemm_operation_profiler.cu
index 0be2b9fe52..daee075656 100644
--- a/tools/profiler/src/gemm_operation_profiler.cu
+++ b/tools/profiler/src/gemm_operation_profiler.cu
@@ -746,7 +746,13 @@ bool GemmOperationProfiler::verify_cutlass(
   }
 #endif // #if CUTLASS_ENABLE_CUBLAS

-  bool verification_status = verify_with_reference_(options, report, device_context, operation, problem_space, problem);
+  library::GemmDescription const &gemm_desc =
+    static_cast<library::GemmDescription const &>(operation->description());
+
+
+  cutlass::library::NumericTypeID element_A = gemm_desc.A.element;
+  cutlass::library::NumericTypeID element_B = gemm_desc.B.element;
+  bool verification_status = verify_with_reference_(options, report, device_context, operation, problem_space, problem, element_A, element_B);

   // Update disposition to worst case verification outcome among all
   // verification providers which are supported
@@ -912,8 +918,10 @@ bool GemmOperationProfiler::verify_with_reference_(
   DeviceContext &device_context,
   library::Operation const *operation,
   ProblemSpace const &problem_space,
-  ProblemSpace::Problem const &problem) {
-
+  ProblemSpace::Problem const &problem,
+  cutlass::library::NumericTypeID element_A,
+  cutlass::library::NumericTypeID element_B)
+{
   library::GemmDescription const &gemm_desc =
     static_cast<library::GemmDescription const &>(operation->description());

@@ -976,13 +984,13 @@ bool GemmOperationProfiler::verify_with_reference_(
       problem_.alpha.data(),

-      gemm_desc.A.element,
+      element_A,
       gemm_desc.A.layout,
       gemm_desc.transform_A,
       ptr_A,
       int(gemm_workspace_.configuration.lda),

-      gemm_desc.B.element,
+      element_B,
       gemm_desc.B.layout,
       gemm_desc.transform_B,
       ptr_B,
@@ -1010,7 +1018,6 @@ bool GemmOperationProfiler::verify_with_reference_(
         results_.back().verification_map[provider] = Disposition::kNotRun;
         continue;
       }
-
       results_.back().status = status;

       if (provider == library::Provider::kReferenceHost) {
diff --git a/tools/util/include/cutlass/util/print_error.hpp b/tools/util/include/cutlass/util/print_error.hpp
index aeeda92d14..9eed9d1438 100644
--- a/tools/util/include/cutlass/util/print_error.hpp
+++ b/tools/util/include/cutlass/util/print_error.hpp
@@ -62,7 +62,6 @@ template
 matrix_inf_norm_result
 matrix_inf_norm(cute::Tensor const& host_matrix)
 {
-  using std::abs;
   using error_type = decltype(std::declval<matrix_inf_norm_result>().inf_norm);
   using element_type = typename EngineType::value_type;

@@ -74,14 +73,25 @@ matrix_inf_norm(cute::Tensor const& host_matrix)
   const int64_t num_rows = cute::size<0>(host_matrix);
   const int64_t num_cols = cute::size<1>(host_matrix);

-  for(int64_t i = 0; i < num_rows; ++i) {
+  auto abs_fn = [] (element_type A_ij) {
+    if constexpr (not std::is_unsigned_v<element_type>) {
+      using std::abs;
+      return abs(A_ij);
+    }
+    else {
+      return A_ij;
+    }
+  };
+
+  for (int64_t i = 0; i < num_rows; ++i) {
     error_type row_abs_sum = 0.0;
     for(int64_t j = 0; j < num_cols; ++j) {
-      row_abs_sum += abs(host_matrix(i, j));
+      row_abs_sum += abs_fn(host_matrix(i, j));
     }
-    if(std::isnan(row_abs_sum)) {
+    if (std::isnan(row_abs_sum)) {
       found_nan = true;
-    } else {
+    }
+    else {
       inf_norm = row_abs_sum > inf_norm ? row_abs_sum : inf_norm;
     }
   }
@@ -95,10 +105,19 @@
 matrix_inf_norm_result
 matrix_diff_inf_norm(cute::Tensor const& X,
                      cute::Tensor const& Y)
 {
-  using std::abs;
   using error_type = decltype(std::declval<matrix_inf_norm_result>().inf_norm);
   using element_type = typename EngineType::value_type;

+  auto abs_fn = [] (element_type A_ij) {
+    if constexpr (not std::is_unsigned_v<element_type>) {
+      using std::abs;
+      return abs(A_ij);
+    }
+    else {
+      return A_ij;
+    }
+  };
+
   assert(cute::size<0>(X) == cute::size<0>(Y));
   assert(cute::size<1>(X) == cute::size<1>(Y));

@@ -110,15 +129,16 @@ matrix_diff_inf_norm(cute::Tensor const& X,
   error_type inf_norm = 0.0;
   bool found_nan = false;

-  for(int64_t i = 0; i < num_rows; ++i) {
+  for (int64_t i = 0; i < num_rows; ++i) {
     error_type row_abs_sum = 0.0;
-    for(int64_t j = 0; j < num_cols; ++j) {
-      row_abs_sum += error_type(abs(element_type(X(i,j)) -
-                                    element_type(Y(i,j))));
+    for (int64_t j = 0; j < num_cols; ++j) {
+      row_abs_sum += error_type(abs_fn(element_type(X(i,j)) -
+                                       element_type(Y(i,j))));
     }
-    if(std::isnan(row_abs_sum)) {
+    if (std::isnan(row_abs_sum)) {
       found_nan = true;
-    } else {
+    }
+    else {
       inf_norm = row_abs_sum > inf_norm ? row_abs_sum : inf_norm;
     }
   }
@@ -130,7 +150,7 @@
 template
-void
+auto
 print_matrix_multiply_mollified_relative_error(
   char const A_value_type_name[],
   cute::Tensor const& A,
@@ -158,13 +178,13 @@ print_matrix_multiply_mollified_relative_error(
   using std::cout;
   using cute::shape;
   cout << "Matrix A: " << shape<0>(A) << "x" << shape<1>(A) << " of " << A_value_type_name << '\n'
-    << "Matrix B: " << shape<0>(B) << "x" << shape<1>(B) << " of " << B_value_type_name << '\n'
-    << "Matrix C: " << shape<0>(C) << "x" << shape<1>(C) << " of " << C_value_type_name << '\n'
-    << std::scientific
-    << "Infinity norm of A: " << A_norm << '\n'
-    << "Infinity norm of B: " << B_norm << '\n'
-    << "Infinity norm of C: " << C_norm << '\n'
-    << "Infinity norm of (C - C_ref): " << diff_norm << '\n';
+       << "Matrix B: " << shape<0>(B) << "x" << shape<1>(B) << " of " << B_value_type_name << '\n'
+       << "Matrix C: " << shape<0>(C) << "x" << shape<1>(C) << " of " << C_value_type_name << '\n'
+       << std::scientific
+       << "Infinity norm of A: " << A_norm << '\n'
+       << "Infinity norm of B: " << B_norm << '\n'
+       << "Infinity norm of C: " << C_norm << '\n'
+       << "Infinity norm of (C - C_ref): " << diff_norm << '\n';

   if(A_norm_times_B_norm == 0.0) {
     cout << "Mollified relative error: " << relative_error << '\n';
@@ -173,15 +193,16 @@ print_matrix_multiply_mollified_relative_error(
   }

   if (A_has_nan || B_has_nan || C_has_nan || diff_has_nan) {
-    cout << "Did we encounter NaN in A? " << (A_has_nan ? "yes" : "no") << '\n'
-      << "Did we encounter NaN in B? " << (B_has_nan ? "yes" : "no") << '\n'
-      << "Did we encounter NaN in C? " << (C_has_nan ? "yes" : "no") << '\n'
-      << "Did we encounter NaN in (C - C_ref)? " << (diff_has_nan ? "yes" : "no") << '\n';
+    cout << "Did we encounter NaN in A? " << (A_has_nan ? "yes" : "no") << '\n'
+         << "Did we encounter NaN in B? " << (B_has_nan ? "yes" : "no") << '\n'
+         << "Did we encounter NaN in C? " << (C_has_nan ? "yes" : "no") << '\n'
+         << "Did we encounter NaN in (C - C_ref)? " << (diff_has_nan ? "yes" : "no") << '\n';
   }
+  return relative_error;
 }

 template
-void
+auto
 print_matrix_multiply_mollified_relative_error(
   const char value_type_name[],
   const cute::Tensor& A,
@@ -189,7 +210,7 @@ print_matrix_multiply_mollified_relative_error(
   const cute::Tensor& C_computed,
   const cute::Tensor& C_expected)
 {
-  print_matrix_multiply_mollified_relative_error(value_type_name, A, value_type_name, B,
+  return print_matrix_multiply_mollified_relative_error(value_type_name, A, value_type_name, B,
     value_type_name, C_computed, C_expected);
 }

@@ -314,7 +335,7 @@ print_relative_error(
   bool print_error = true,
   double error_margin = 0.00001)
 {
   assert(size(data) == size(reference));
-  return print_relative_error(static_cast(size(data)),
-    data, reference,
+  return print_relative_error(static_cast(size(data)),
+                              data, reference,
     print_verbose, print_error, error_margin);
 }
diff --git a/tools/util/include/cutlass/util/reference/device/tensor_fill.h b/tools/util/include/cutlass/util/reference/device/tensor_fill.h
index 93e559e144..05b877a235 100644
--- a/tools/util/include/cutlass/util/reference/device/tensor_fill.h
+++ b/tools/util/include/cutlass/util/reference/device/tensor_fill.h
@@ -1713,7 +1713,7 @@ void BlockFillSequential(
   Layout layout = Layout::packed(size);
   TensorView view(ptr, layout, size);

-  Array c;
+  Array c{};
   c[0] = v;

   TensorFillLinear(view, c, s);
diff --git a/tools/util/include/cutlass/util/reference/host/conv.hpp b/tools/util/include/cutlass/util/reference/host/conv.hpp
index cbca2df631..202091d95e 100644
--- a/tools/util/include/cutlass/util/reference/host/conv.hpp
+++ b/tools/util/include/cutlass/util/reference/host/conv.hpp
@@ -41,6 +41,8 @@

 #include "cute/tensor.hpp"

+#include
+
 /////////////////////////////////////////////////////////////////////////////////////////////////

 namespace cutlass::reference::host {
@@ -93,7 +95,8 @@ template<
   class TensorAlpha_,
   class TensorBeta_,
   class TensorBias_,
-  class ActivationFunctor_ = cutlass::epilogue::thread::Identity>
+  class ActivationFunctor_ = cutlass::epilogue::thread::Identity
+>
 struct ConvEpilogueFusionParams {
   using ElementAcc = ElementAcc_;
   using ElementScalar = ElementScalar_;
@@ -104,7 +107,6 @@ struct ConvEpilogueFusionParams {
   using TensorBeta = TensorBeta_;
   using TensorBias = TensorBias_;
   using ActivationFunctor = ActivationFunctor_;
-
   ElementScalar alpha = ElementScalar(1);
   ElementScalar beta = ElementScalar(0);

@@ -155,6 +157,7 @@ struct ConvReferenceImpl {
   // Epilogue activation operation
   ActivationFunctor epi_activation;
+
   ConvReferenceImpl(
     TensorA const& tensor_a,
     TensorB const& tensor_b,
@@ -201,7 +204,7 @@ struct ConvReferenceImpl {
 #pragma omp parallel for collapse(2)
 #endif
     for (int32_t n = 0; n < N; ++n) {
-      for (int32_t q = 0; q < Q; ++q) {
+      for (int32_t q = 0; q < Q; ++q) {
         for (int32_t k = 0; k < K; ++k) {
           auto accumulator = ElementAcc(0);
           for (int32_t s = 0; s < S; ++s) {
@@ -226,6 +229,7 @@
       }
     }
   }
+
 }

 // Specialization for 2D fprop kernel
@@ -272,6 +276,7 @@
       }
     }
   }
+
 }

 // Specialization for 3D fprop kernel
@@ -325,6 +330,7 @@
       }
     }
   }
+
 }

 // Specialization for 1D dgrad kernel
@@ -371,6 +377,7 @@
       }
     }
   }
+
 }

 // Specialization for 2D dgrad kernel
@@ -424,11 +431,14 @@ struct ConvReferenceImpl {
           if (raw_pointer_cast(epi_fusion_params_.tensor_bias.data())) {
             output += bias_converter(epi_fusion_params_.tensor_bias[c]);
           }
+
           output = epi_activation(output);
+
           tensor_d_(c, w, h, n) = output_converter(output);
         }
       }
     }
   }
+
 }

 // Specialization for 3D dgrad kernel
@@ -501,6 +511,7 @@
       }
     }
   }
+
 }

 // Specialization for 1D wgrad kernel
diff --git a/tools/util/include/cutlass/util/reference/host/convolution.h b/tools/util/include/cutlass/util/reference/host/convolution.h
index 07d3681f12..f28b4a658a 100644
--- a/tools/util/include/cutlass/util/reference/host/convolution.h
+++ b/tools/util/include/cutlass/util/reference/host/convolution.h
@@ -197,7 +197,7 @@ void Depsep_Fprop(cutlass::TensorView tensor_A,
 }

 ////////////////////////////////////////////////////////////////////////////////////////////////////
-/// Dgrad
+/// Dgrad / Deconv
 ////////////////////////////////////////////////////////////////////////////////////////////////////

 /// dx = dgrad(dy, w)
@@ -221,7 +221,8 @@ void Conv2dDgrad(
   TensorRef tensor_dx_in,
   TensorRef tensor_dx_out,
   ElementCompute alpha,
-  ElementCompute beta) {
+  ElementCompute beta,
+  bool is_deconv = false) {

   ConvertOp convert_op;
   InnerProductOp inner_product_op;
@@ -272,7 +273,8 @@ void Conv2dDgrad(
             if (p < problem_size.P && q < problem_size.Q) {

               ElementA a = tensor_dy.at(cutlass::make_Coord(n, p, q, k));
-              ElementB b = tensor_w.at(cutlass::make_Coord(k, r, s, c));
+              ElementB b = is_deconv ? tensor_w.at(cutlass::make_Coord(c, r, s, k))
+                                     : tensor_w.at(cutlass::make_Coord(k, r, s, c));

               acc = inner_product_op(ElementAccumulator(a), ElementAccumulator(b), acc);
             }
@@ -420,6 +422,7 @@ void Conv2d(
     >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta);
     break;

+  case conv::Operator::kDeconv:
   case conv::Operator::kDgrad:
     Conv2dDgrad<
       ElementA, LayoutA,
@@ -429,7 +432,7 @@ void Conv2d(
       ElementAccumulator,
       ElementD, ConvertOp,
       InnerProductOp
-    >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta);
+    >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, (convolutional_operator == conv::Operator::kDeconv));
     break;

   case conv::Operator::kWgrad:
@@ -537,7 +540,7 @@ void Conv3dFprop(
 }

 ////////////////////////////////////////////////////////////////////////////////////////////////////
-/// Dgrad
+/// Dgrad / Deconv
 ////////////////////////////////////////////////////////////////////////////////////////////////////

 /// dx = dgrad(dy, w)
@@ -560,7 +563,8 @@ void Conv3dDgrad(
   TensorRef tensor_dx_in,
   TensorRef tensor_dx_out,
   ElementCompute alpha,
-  ElementCompute beta) {
+  ElementCompute beta,
+  bool is_deconv = false) {

   ConvertOp convert_op;
   InnerProductOp inner_product_op;
@@ -604,8 +608,8 @@ void Conv3dDgrad(
               if (z < problem_size.Z && p < problem_size.P && q < problem_size.Q) {

                 ElementA a = tensor_dy.at(cutlass::make_Coord(n, z, p, q, k));
-                ElementB b = tensor_w.at(cutlass::make_Coord(k, t, r, s, c));
-
+                ElementB b = is_deconv ? tensor_w.at(cutlass::make_Coord(c, t, r, s, k))
+                                       : tensor_w.at(cutlass::make_Coord(k, t, r, s, c));
                 acc = inner_product_op(ElementAccumulator(a), ElementAccumulator(b), acc);
               }
             }
@@ -760,6 +764,7 @@ void Conv3d(
     >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta);
     break;

+  case conv::Operator::kDeconv:
   case conv::Operator::kDgrad:
     Conv3dDgrad<
       ElementA, LayoutA,
@@ -768,7 +773,7 @@ void Conv3d(
       ElementCompute,
       ElementAccumulator, ConvertOp,
       InnerProductOp
-    >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta);
+    >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, (convolutional_operator == conv::Operator::kDeconv));
     break;

   case conv::Operator::kWgrad:
diff --git a/tools/util/include/cutlass/util/reference/host/gett.hpp b/tools/util/include/cutlass/util/reference/host/gett.hpp
index 978b666c34..84aa93634e 100644
--- a/tools/util/include/cutlass/util/reference/host/gett.hpp
+++ b/tools/util/include/cutlass/util/reference/host/gett.hpp
@@ -35,10 +35,11 @@
 #pragma once

 /////////////////////////////////////////////////////////////////////////////////////////////////
-
+#include "cutlass/gemm/gemm.h"
 #include "cutlass/complex.h"
 #include "cutlass/numeric_conversion.h"
 #include "cutlass/epilogue/thread/activation.h"
+#include "cutlass/relatively_equal.h"

 #include "cute/tensor.hpp"

@@ -115,7 +116,6 @@ struct GettEpilogueParams {
   using LayoutC = typename TensorC::layout_type;
   using EngineD = typename TensorD::engine_type;
   using LayoutD = typename TensorD::layout_type;
-
   static constexpr bool PerColumnBias = PerColumnBias_;

   ElementScalar alpha = ElementScalar(1);