Commit

Updates for CUTLASS 3.5.0 (#1468)

thakkarV authored Apr 12, 2024
1 parent a40e08e commit 7d49e6c
Showing 171 changed files with 7,429 additions and 1,791 deletions.
7 changes: 6 additions & 1 deletion CHANGELOG.md
@@ -1,6 +1,6 @@
# NVIDIA CUTLASS Changelog

-## 3.5 (2024-03-18)
+## [3.5.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.5.0) (2024-04-09)

- Implicit GEMM Convolutions targeting Hopper SM90A via WGMMA + [TMA im2col](./include/cute/atom/copy_traits_sm90_im2col.hpp)
+ Native implementation in CUTLASS 3.x using CuTe, mirroring the [same design hierarchy as that of GEMMs](./media/docs/gemm_api_3x.md).
@@ -12,8 +12,13 @@
- [Ampere gather/scatter convolution example](./examples/59_ampere_gather_scatter_gemm/README.md) in CuTe and CUTLASS 3.x
  + Showcasing how custom kernels can be written and optimized using CUTLASS 3.x and CuTe, and the general strategy for implementing convolutions as specializations of GETTs.
  + Implementation of a coarse-grained sparse gather/scatter kernel achieving peak performance on Ampere class tensor cores.
- 32x and 16x tile sizes are added to CUTLASS 2.x to improve the performance of narrow-tall and wide-short matrices.
+ [Ampere FP16 TN](./test/unit/gemm/device/gemm_f16t_f16n_f16t_tensor_op_f32_sm80.cu) and [NT](./test/unit/gemm/device/gemm_f16n_f16t_f16t_tensor_op_f32_sm80.cu#L227-L301), [Ampere INT8 TN](./test/unit/gemm/device/gemm_s8t_s8n_s8t_tensor_op_s32_sm80.cu#L392-L1342), [Ampere INT4 TN](./test/unit/gemm/device/gemm_s4t_s4n_s4t_tensor_op_s32_sm80.cu#L372-L934).
+ [Turing FP16 TN](./test/unit/gemm/device/gemm_f16t_f16n_f16t_tensor_op_f32_sm75.cu#L55-L394), [Turing INT8 TN](./test/unit/gemm/device/gemm_s8t_s8n_s8t_tensor_op_s32_sm75.cu#L166-L537), [Turing INT4 TN](./test/unit/gemm/device/gemm_s4t_s4n_s4t_tensor_op_s32_sm75.cu#L310-L564).
- Updates to CuTe documentation for [`cute::Tensor<>`](./media/docs/cute/03_tensor.md), [MMA atoms](./media/docs/cute/0t_mma_atom.md), and an overhauled [CuTe GEMM tutorial series](./examples/cute/tutorial).
- Extensions to CuTe to support [L2 prefetching](./include/cute/algorithm/prefetch.hpp) and [TMA store+reductions](./include/cute/arch/copy_sm90_tma.hpp#L1337).
- Remove C++11 requirement on a few CUTLASS 2.x API header files. All CUTLASS files now require C++17.
- Fixes to greatly reduce build warnings.
- Updates and bugfixes from the community (thanks!)

## [3.4.1](https://github.com/NVIDIA/cutlass/releases/tag/v3.4.1) (2024-02-14)
15 changes: 14 additions & 1 deletion CMakeLists.txt
@@ -38,7 +38,7 @@ else()
endif()

message(STATUS "CMake Version: ${CMAKE_VERSION}")
-set(IMPLICIT_CMAKE_CXX_STANDARD OFF CACHE BOOL "Do not explicitly specify -std=c++11 if set")
+set(IMPLICIT_CMAKE_CXX_STANDARD OFF CACHE BOOL "Do not explicitly specify -std=c++17 if set")

# To reduce duplicate version locations, parse the version out of the
# main versions.h file and reuse it here.
@@ -332,6 +332,18 @@ endif()



# Warnings-as-error exceptions and warning suppressions for Clang builds
if (CMAKE_CXX_COMPILER_ID MATCHES "Clang")
  set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-error=implicit-int-conversion ")
  list(APPEND CUTLASS_CUDA_NVCC_FLAGS "-Wno-error=implicit-int-conversion" )
  set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-error=pass-failed ")
  list(APPEND CUTLASS_CUDA_NVCC_FLAGS "-Wno-error=pass-failed" )
  set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-error=inconsistent-missing-override ")
  list(APPEND CUTLASS_CUDA_NVCC_FLAGS "-Wno-error=inconsistent-missing-override" )
  set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-sign-conversion ")
  list(APPEND CUTLASS_CUDA_NVCC_FLAGS "-Wno-sign-conversion" )
endif()
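(Note that each suppression is appended both to CMAKE_CXX_FLAGS, for host-only translation units, and to CUTLASS_CUDA_NVCC_FLAGS, presumably so NVCC-compiled units tolerate the same warnings when Clang is the host compiler.)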

if (NOT MSVC AND CUTLASS_NVCC_KEEP)
# MSVC flow handles caching already, but for other generators we handle it here.
set(CUTLASS_NVCC_KEEP_DIR ${CMAKE_CURRENT_BINARY_DIR}/tmp CACHE PATH "Location to store NVCC scratch files")
@@ -357,6 +369,7 @@ if (CUTLASS_ENABLE_OPENMP_TESTS)
message(WARNING "CUTLASS_ENABLE_OPENMP_TESTS set but OpenMP not found.")
endif()
endif()

if(UNIX)
list(APPEND CUTLASS_CUDA_NVCC_FLAGS -Xcompiler=-Wconversion)
list(APPEND CUTLASS_CUDA_NVCC_FLAGS -Xcompiler=-fno-strict-aliasing)
11 changes: 7 additions & 4 deletions README.md
@@ -2,7 +2,7 @@

# CUTLASS 3.5

-_CUTLASS 3.5 - March 2024_
+_CUTLASS 3.5 - April 2024_

CUTLASS is a collection of CUDA C++ template abstractions for implementing
high-performance matrix-matrix multiplication (GEMM) and related computations at all levels
@@ -45,18 +45,21 @@ In addition to GEMMs, CUTLASS implements high-performance convolution via the im

CUTLASS 3.5 is an update to CUTLASS adding:

-- Implicit GEMM Convolutions targeting Hopper SM90A via WGMMA + [TMA im2col](./include/cute/atom/copy_traits_sm90_im2col.hpp)
+- Implicit GEMM Convolutions targeting Hopper SM90A via WGMMA + [TMA im2col](./include/cute/atom/copy_traits_sm90_im2col.hpp).
+ Native implementation in CUTLASS 3.x using CuTe, mirroring the [same design hierarchy as that of GEMMs](./media/docs/gemm_api_3x.md).
+ Support for 1D, 2D, and 3D convolutions in a [rank-agnostic fashion](./include/cutlass/conv/convnd_problem_shape.hpp).
-  + Support for [Fprop](./test/unit/conv/device_3x/fprop/sm90_conv3d_fprop_implicit_gemm_s8_s8_s32_tensorop_s32.cu), [Dgrad](./test/unit/conv/device_3x/dgrad/sm90_conv2d_dgrad_implicit_gemm_f16_f16_f32_tensorop_f16.cu), and [Wgrad](./test/unit/conv/device_3x/wgrad/sm90_conv1d_wgrad_implicit_gemm_f16_f16_f32_tensorop_f16.cu) algorithms
+  + Support for [Fprop](./test/unit/conv/device_3x/fprop/sm90_conv3d_fprop_implicit_gemm_s8_s8_s32_tensorop_s32.cu), [Dgrad](./test/unit/conv/device_3x/dgrad/sm90_conv2d_dgrad_implicit_gemm_f16_f16_f32_tensorop_f16.cu), and [Wgrad](./test/unit/conv/device_3x/wgrad/sm90_conv1d_wgrad_implicit_gemm_f16_f16_f32_tensorop_f16.cu) algorithms.
+ [CUTLASS profiler support](./python/cutlass_library/conv3x_emitter.py) for 2D and 3D convolutions implemented via the 3.x API.
  + NOTE: this is a beta release. Further updates to CUTLASS will include major performance improvements, feature enablement, and possible breaking changes to the API until the 3.7 release. Your feedback is welcome on the design!
- Support for [Ada (SM89) FP8 tensor cores via the 2.x API](./examples/58_ada_fp8_gemm/ada_fp8_gemm.cu). Requires CUDA 12.4 or newer.
-- [Ampere gather/scatter convolution example](./examples/59_ampere_gather_scatter_gemm/README.md) in CuTe and CUTLASS 3.x
+- [Ampere gather/scatter convolution example](./examples/59_ampere_gather_scatter_gemm/README.md) in CuTe and CUTLASS 3.x.
  + Showcasing how custom kernels can be written and optimized using CUTLASS 3.x and CuTe, and the general strategy for implementing convolutions as specializations of GETTs.
  + Implementation of a coarse-grained sparse gather/scatter kernel achieving peak performance on Ampere class tensor cores.
- 32x and 16x tile sizes are added to CUTLASS 2.x to improve the performance of narrow-tall and wide-short matrices.
- Updates to CuTe documentation for [`cute::Tensor<>`](./media/docs/cute/03_tensor.md), [MMA atoms](./media/docs/cute/0t_mma_atom.md), and an overhauled [CuTe GEMM tutorial series](./examples/cute/tutorial).
- Extensions to CuTe to support [L2 prefetching](./include/cute/algorithm/prefetch.hpp) and [TMA store+reductions](./include/cute/arch/copy_sm90_tma.hpp#L1337).
- Remove C++11 requirement on a few CUTLASS 2.x API header files. All CUTLASS files now require C++17.
- Fixes to greatly reduce build warnings.
- Updates and bugfixes from the community (thanks!)

Minimum requirements:
@@ -265,6 +265,10 @@ constexpr int NumStages = 3;
// Which iterator algorithm to use: Analytic or Optimized
static cutlass::conv::IteratorAlgorithm const IteratorAlgorithm = cutlass::conv::IteratorAlgorithm::kOptimized;

// Is the output packed or strided?
// Use kStrided if the output is strided
static cutlass::conv::StrideSupport const OutputStride = cutlass::conv::StrideSupport::kUnity;

// The epilogue part of the kernel
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
  ElementOutput, // Data type of output matrix.
@@ -289,7 +293,8 @@ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop<
  SwizzleThreadBlock,
  NumStages,
  cutlass::arch::OpMultiplyAdd,
-  IteratorAlgorithm
+  IteratorAlgorithm,
+  OutputStride
>::Kernel;
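If the output is strided rather than packed, the same kernel builder is instantiated with the strided specialization instead — a one-line sketch, assuming CUTLASS 2.x's kStrided enumerator:

// Hypothetical alternative to the declaration above (not part of this diff):
static cutlass::conv::StrideSupport const OutputStride = cutlass::conv::StrideSupport::kStrided;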

// Type of the actual kernel
@@ -36,7 +36,7 @@ implicitly to tf32 inside the GEMM kernel which means no change is needed to acc
fp32 data by using NVIDIA Ampere architecture.
We can use the tf32 mode of tensor core to emulate a fast accurate SGEMM kernel which is accelerated
using Ampere Tensor Cores (see include/cutlass/gemm/warp/mma_tensor_op_fast_f32.h).
The trick is very simple:
a x b = (a_big + a_small) x (b_big + b_small) = a_big x b_big + a_big x b_small + a_small x b_big
@@ -45,11 +45,11 @@ The trick is very simple
a_small x b_small is discarded because it is too small.
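As a minimal standalone sketch of the split (my illustration, not part of this diff): tf32 keeps 10 explicit mantissa bits, so truncating the low 13 mantissa bits of an fp32 value mimics its precision, and the residual is exactly representable.

#include <cstdint>
#include <cstdio>
#include <cstring>

// Truncate a float to roughly tf32 precision by zeroing the low 13 mantissa bits (illustrative).
float to_tf32(float x) {
  std::uint32_t bits;
  std::memcpy(&bits, &x, sizeof(bits));
  bits &= 0xFFFFE000u;
  std::memcpy(&x, &bits, sizeof(x));
  return x;
}

int main() {
  float a = 1.2345678f, b = 7.6543210f;
  float a_big = to_tf32(a), a_small = a - a_big; // a == a_big + a_small exactly
  float b_big = to_tf32(b), b_small = b - b_big;
  // Three reduced-precision products; a_small x b_small is dropped as negligible.
  float approx = a_big * b_big + a_big * b_small + a_small * b_big;
  std::printf("fp32 product: %.9f  3-term approx: %.9f\n", double(a) * double(b), double(approx));
  return 0;
}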
This example demonstrates usage of this kernel, along with accuracy measurements w.r.t. actual FP32
results (SGEMM using SIMT) and against FP64 results (DGEMM).

To enable this feature, the only change needed is to replace the default OpMultiplyAdd with
OpMultiplyAddFastF32.

We now have several different flavors of sgemm in the profiler for Ampere. Here are the differences
@@ -97,14 +97,14 @@ struct Result {
  double l2_norm_fp32_vs_fp64;

  // ctor
  Result(
    int m, int n, int k,
    double runtime_ms, double gflops,
    double l2_norm_3xtf32_vs_fp64,
    double l2_norm_1xtf32_vs_fp64,
    double l2_norm_fp32_vs_fp64) :
    m(m), n(n), k(k),
    runtime_ms(runtime_ms), gflops(gflops),
    l2_norm_3xtf32_vs_fp64(l2_norm_3xtf32_vs_fp64),
    l2_norm_1xtf32_vs_fp64(l2_norm_1xtf32_vs_fp64),
    l2_norm_fp32_vs_fp64(l2_norm_fp32_vs_fp64) {}
@@ -147,7 +147,7 @@ struct Options {
  int iterations;
  int seed;
  bool benchmark;

  Options():
    help(false),
    problem_size({3456, 4096, 4096}),
@@ -190,7 +190,7 @@ struct Options {

    cmd.get_cmd_line_argument("alpha", alpha);
    cmd.get_cmd_line_argument("beta", beta);

    cmd.get_cmd_line_argument("iterations", iterations);
    cmd.get_cmd_line_argument("seed", seed);
    cmd.get_cmd_line_argument("rand_mode", rand_mode);
@@ -227,9 +227,9 @@ struct Options {
  /// Compute performance in GFLOP/s
  double gflops(double runtime_s) const {

    // Number of real-valued multiply-adds
    int64_t fmas = problem_size.product();

    // Two flops per multiply-add
    return 2.0 * double(fmas) / double(1.0e9) / runtime_s;
  }
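As a concrete check of this formula (my arithmetic, not from the example): for the default 3456 x 4096 x 4096 problem, fmas = 3456 * 4096 * 4096 ≈ 5.8e10, so a 10 ms run reports roughly 2 * 5.8e10 / 1e9 / 0.01 ≈ 11,600 GFLOP/s.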
@@ -272,10 +272,10 @@ using EpilogueOp = cutlass::epilogue::thread::LinearCombination<

// Number of pipelines you want to use
constexpr int NumStages = 3;
// Alignment
constexpr int Alignment = 4;

//
// Gemm Operators (Gemm_3xTF32, Gemm_1xTF32, GEMM_F32, GEMM_F64)
//

@@ -296,7 +296,7 @@ using Gemm_3xTF32 = cutlass::gemm::device::Gemm<
  EpilogueOp,
  SwizzleThreadBlock,
  NumStages,
  Alignment,
  Alignment,
  false,
  cutlass::arch::OpMultiplyAddFastF32>;
@@ -318,7 +318,7 @@ using Gemm_1xTF32 = cutlass::gemm::device::Gemm<
  EpilogueOp,
  SwizzleThreadBlock,
  NumStages,
  Alignment,
  Alignment,
  false,
  cutlass::arch::OpMultiplyAdd>;
@@ -356,7 +356,7 @@ bool run(Options &options) {
  cutlass::HostTensor<float, LayoutInputA> tensor_a_F32(problem_size.mk()); // <- Create matrix A with dimensions M x K
  cutlass::HostTensor<float, LayoutInputB> tensor_b_F32(problem_size.kn()); // <- Create matrix B with dimensions K x N
  cutlass::HostTensor<float, LayoutOutput> tensor_c_F32(problem_size.mn()); // <- Create matrix C with dimensions M x N
  cutlass::HostTensor<float, LayoutOutput> tensor_d_F32(problem_size.mn()); // <- Create matrix D with dimensions M x N

  if (options.rand_mode == "uniform") {
    const float min = -1;
@@ -397,7 +397,7 @@ bool run(Options &options) {
  }
  cutlass::reference::host::TensorFill(
    tensor_d_F32.host_view()); // <- fill matrix D on host with zeros

  // Copy data from host to GPU
  tensor_a_F32.sync_device();
  tensor_b_F32.sync_device();
@@ -411,7 +411,7 @@ bool run(Options &options) {
  cutlass::HostTensor<double, LayoutInputA> tensor_a_F64(problem_size.mk()); // <- Create matrix A with dimensions M x K
  cutlass::HostTensor<double, LayoutInputB> tensor_b_F64(problem_size.kn()); // <- Create matrix B with dimensions K x N
  cutlass::HostTensor<double, LayoutOutput> tensor_c_F64(problem_size.mn()); // <- Create matrix C with dimensions M x N

  // Gemm output (D) for GEMM_F64
  cutlass::HostTensor<double, LayoutOutput> tensor_d_F64(problem_size.mn()); // <- Create matrix D with dimensions M x N
  // Gemm output (D) for GEMM_3xTF32
@@ -426,7 +426,7 @@ bool run(Options &options) {
  cutlass::reference::host::TensorCopy(tensor_d_F64.host_view(), tensor_d_F32.host_view());
  cutlass::reference::host::TensorCopy(tensor_d_3xTF32.host_view(), tensor_d_F32.host_view());
  cutlass::reference::host::TensorCopy(tensor_d_1xTF32.host_view(), tensor_d_F32.host_view());

  // Copy data from host to GPU
  tensor_a_F64.sync_device();
  tensor_b_F64.sync_device();
@@ -464,7 +464,7 @@ bool run(Options &options) {
  // Instantiate CUTLASS kernel depending on templates
  Gemm_3xTF32 gemm_op_3xTF32;

  // Check whether the problem size is supported
  cutlass::Status status_3xtf32 = gemm_op_3xTF32.can_implement(arguments_3xtf32);
  CUTLASS_CHECK(status_3xtf32);

@@ -568,7 +568,7 @@ bool run(Options &options) {
  // Instantiate CUTLASS kernel depending on templates
  Gemm_1xTF32 gemm_op_1xtf32;

  // Check whether the problem size is supported
  cutlass::Status status_1xtf32 = gemm_op_1xtf32.can_implement(arguments_1xtf32);
  CUTLASS_CHECK(status_1xtf32);

@@ -627,7 +627,7 @@ bool run(Options &options) {
  tensor_d_F32.sync_host();

  ////////////////////////////////////////////////////////////////////////////////
  /////// Compute l2 norms
  ////////////////////////////////////////////////////////////////////////////////

  // l2 norm 3xTF32 vs F64
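The elided code computes these norms with CUTLASS reference utilities; as a standalone host-side sketch of what a normalized L2 error means (the helper below and its names are mine, not the example's):

#include <cmath>
#include <cstddef>
#include <vector>

// Normalized L2 error: ||test - ref||_2 / ||ref||_2 (illustrative helper).
double normalized_l2(const std::vector<double> &test, const std::vector<double> &ref) {
  double err = 0.0, norm = 0.0;
  for (std::size_t i = 0; i < ref.size(); ++i) {
    double d = test[i] - ref[i];
    err += d * d;            // accumulate squared elementwise error
    norm += ref[i] * ref[i]; // accumulate squared reference magnitude
  }
  return std::sqrt(err) / std::sqrt(norm);
}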
@@ -664,7 +664,7 @@ bool run(Options &options) {
  std::cout << "GFLOPs: " << result.gflops << std::endl;
  std::cout << "Normalized L2 norm of" << std::endl;
  std::cout.precision(8);
  std::cout << std::scientific
            << " - 3xTF32 error with FP64 reference : " << result.l2_norm_3xtf32_vs_fp64 << std::endl
            << " - 1xTF32 error with FP64 reference : " << result.l2_norm_1xtf32_vs_fp64 << std::endl
            << " - FP32 error with FP64 reference   : " << result.l2_norm_fp32_vs_fp64 << std::endl;
@@ -673,11 +673,11 @@ bool run(Options &options) {
}

int main(int argc, const char **argv) {

  bool notSupported = false;

  // Ampere Tensor Core operations exposed with mma.sync and ldmatrix are first available
  // in CUDA 11.0.
  //
  // CUTLASS must be compiled with CUDA 11.0 Toolkit to run these examples.
  if (!(__CUDACC_VER_MAJOR__ >= 11)) {
@@ -690,7 +690,7 @@ int main(int argc, const char **argv) {
  cudaError_t error = cudaGetDeviceProperties(&props, 0);
  if (error != cudaSuccess) {
    std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl;
-    return false;
+    return -1;
  }
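(The one-line fix above matters because returning false from main converts to 0, i.e. success, so the error path previously reported success; returning -1 signals failure.)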

  if (!((props.major * 10 + props.minor) >= 80)) {
@@ -716,17 +716,17 @@

  if (options.benchmark) {
    for (int k = 4; k <= 65536; k *= 2) {

      options.problem_size[2] = k;

      printf("Gemm problem size: %d x %d x %d\n", \
        options.problem_size.m(), options.problem_size.n(), options.problem_size.k());

      if (!options.valid()) {
        std::cerr << "Invalid problem." << std::endl;
        return -1;
      }

      result &= run(options);
    }
  } else {