Skip to content

Commit

Permalink
Updates to fused epilogue (#383)
Browse files Browse the repository at this point in the history
* Enhancements and fixes to fused GEMM and Convolution epilogue.
* Need to explicitly list cudart as unit test library dependency.
  • Loading branch information
kerrmudgeon authored Dec 17, 2021
1 parent 4e666e1 commit ec4f7e5
Show file tree
Hide file tree
Showing 24 changed files with 371 additions and 192 deletions.
12 changes: 11 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -197,9 +197,19 @@ list(APPEND CUTLASS_CUDA_NVCC_FLAGS -DCUTLASS_TEST_LEVEL=${CUTLASS_TEST_LEVEL})
list(APPEND CUTLASS_CUDA_CLANG_FLAGS -DCUTLASS_TEST_LEVEL=${CUTLASS_TEST_LEVEL})

if (CUTLASS_TEST_ENABLE_CACHED_RESULTS)
list(APPEND CUTLASS_CUDA_NVCC_FLAGS -DCUTLASS_TEST_ENABLE_CACHED_RESULTS=1)
message(STATUS "Enable caching of reference results in conv unit tests")
list(APPEND CUTLASS_CUDA_NVCC_FLAGS -DCUTLASS_TEST_ENABLE_CACHED_RESULTS=1)
endif()


set(CUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED ON CACHE BOOL "Enable/Disable rigorous conv problem sizes in conv unit tests")

if (CUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED)
message(STATUS "Enable rigorous conv problem sizes in conv unit tests")
list(APPEND CUTLASS_CUDA_NVCC_FLAGS -DCUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED=1)
endif()


#
# CUDA 10.1 introduces "mma" in PTX performing collective matrix multiply operations.
#
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ int main() {
return -1;
}

if (!((props.major * 10 + props.minor) >= 80)) {
if (props.major * 10 + props.minor < 80) {
std::cerr << "Ampere Tensor Core operations must be run on a machine with compute capability at least 80."
<< std::endl;
notSupported = true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ int run() {
tensor_b.device_ref(),
tensor_c_bias.device_ref(),
tensor_ref_d.device_ref(),
alpha, 0
alpha, ElementComputeEpilogue(0)
);

// Wait for kernels to finish
Expand Down
1 change: 1 addition & 0 deletions examples/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ function(cutlass_example_add_executable NAME)
PRIVATE
CUTLASS
cutlass_tools_util_includes
$<$<BOOL:${CUTLASS_ENABLE_CUBLAS}>:nvidia::cublas>
)

target_include_directories(
Expand Down
2 changes: 1 addition & 1 deletion include/cutlass/arch/wmma_sm75.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ struct Wmma<
FragmentB const &B,
FragmentC const &C) const {
nvcuda::wmma::mma_sync(D, A, B, C);

}

#else
Expand Down Expand Up @@ -186,7 +187,6 @@ struct Wmma<
FragmentA const &A,
FragmentB const &B,
FragmentC const &C) const {

nvcuda::wmma::bmma_sync(D, A, B, C, nvcuda::wmma::experimental::bmmaBitOpXOR,
nvcuda::wmma::experimental::bmmaAccumulateOpPOPC);
}
Expand Down
6 changes: 6 additions & 0 deletions include/cutlass/cutlass.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,12 @@ static char const* cutlassGetStatusString(cutlass::Status status) {

////////////////////////////////////////////////////////////////////////////////////////////////////


#ifndef CUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED
#define CUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED 0
#endif


// CUDA 10.1 introduces the mma instruction
#if !defined(CUTLASS_ENABLE_TENSOR_CORE_MMA)
#define CUTLASS_ENABLE_TENSOR_CORE_MMA 0
Expand Down
2 changes: 2 additions & 0 deletions include/cutlass/epilogue/thread/activation.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ struct Identity {
/// ReLu operator - propagates NaNs
template <typename T>
struct ReLu {
static const bool kIsHeavy=false;
CUTLASS_HOST_DEVICE
T operator()(T const & threshold, T value) const {
if (value < threshold) {
Expand All @@ -76,6 +77,7 @@ struct ReLu {

template <typename T, int N>
struct ReLu<Array<T, N>> {
static const bool kIsHeavy=false;
CUTLASS_HOST_DEVICE
Array<T, N> operator()(T const & threshold, Array<T, N> const &frag) const {
Array<T, N> result;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -201,8 +201,8 @@ class LinearCombinationBiasElementwise {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kElementsPerAccess; ++i) {
ElementCompute z = binary_op(alpha_ * tmp_Accum[i] + beta_ * tmp_C[i], V[i]);
result_Z[i] = z;
result_T[i] = skip_elementwise_ ? z : elementwise_op(z);
result_T[i] = z;
result_Z[i] = skip_elementwise_ ? z : elementwise_op(z);
}

NumericArrayConverter<ElementZ, ElementCompute, kElementsPerAccess> convert_z;
Expand Down Expand Up @@ -230,8 +230,8 @@ class LinearCombinationBiasElementwise {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kElementsPerAccess; ++i) {
ElementCompute z = binary_op(alpha_ * tmp_Accum[i], V[i]);
result_Z[i] = z;
result_T[i] = skip_elementwise_ ? z : elementwise_op(z);
result_T[i] = z;
result_Z[i] = skip_elementwise_ ? z : elementwise_op(z);
}

NumericArrayConverter<ElementZ, ElementCompute, kElementsPerAccess> convert_z;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,7 @@ class EpilogueWithBroadcast :
/// Debug printing
CUTLASS_DEVICE
static void print() {
#if 0
printf("BroadcastDetail {\n");
printf(
" kColumnsPerThread: %d\nkRowsPerThread: %d\n,kThreadCount: %d\nkThreadsPerRow: %d\n"
Expand All @@ -321,6 +322,7 @@ class EpilogueWithBroadcast :
StorageShape::kCount
);
printf("};\n");
#endif
}
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,7 @@ class EpilogueWithReduction :
/// Debug printing
CUTLASS_DEVICE
static void print() {
#if 0
printf("ReductionDetail {\n");
printf(
" kElementsPerAccess:%d\nkColumnsPerThread: %d\nkRowsPerThread: %d\n,kThreadCount: %d\nkThreadsPerRow: %d\n"
Expand All @@ -228,6 +229,7 @@ class EpilogueWithReduction :
StorageShape::kCount
);
printf("};\n");
#endif
}
};

Expand Down
9 changes: 7 additions & 2 deletions include/cutlass/tfloat32.h
Original file line number Diff line number Diff line change
Expand Up @@ -363,8 +363,13 @@ tfloat32_t operator+(tfloat32_t const& lhs, tfloat32_t const& rhs) {

CUTLASS_HOST_DEVICE
tfloat32_t operator-(tfloat32_t const& lhs) {
float x = -reinterpret_cast<float const &>(lhs);
return *reinterpret_cast<tfloat32_t *>(&x);
union u_tff32 {
float val_f32;
tfloat32_t val_tf;
CUTLASS_HOST_DEVICE u_tff32() : val_f32(0) { }
};
union u_tff32 x; x.val_f32 = -reinterpret_cast<float const &>(lhs);
return x.val_tf;
}

CUTLASS_HOST_DEVICE
Expand Down
1 change: 1 addition & 0 deletions test/unit/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ target_link_libraries(
cutlass_tools_util_includes
$<$<BOOL:${CUTLASS_ENABLE_CUBLAS}>:nvidia::cublas>
gtest
cudart
)

cutlass_add_library(
Expand Down
9 changes: 9 additions & 0 deletions test/unit/common/cutlass_unit_test.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,22 @@
#pragma nv_diag_warning boolean_controlling_expr_is_constant
#pragma warning( disable : 4503)

#include <cstdlib>
#include <string>
/////////////////////////////////////////////////////////////////////////////////////////////////

/// Sets flags for Unit test
void FilterArchitecture();

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Reads environment variable `CUTLASS_UNIT_TEST_PROBLEM_COUNT` to control the number and order
/// of problem sizes run by CUTLASS unit tests
int CutlassUnitTestProblemCount();

/////////////////////////////////////////////////////////////////////////////////////////////////


// active test macro
#define CUTLASS_TEST_LEVEL_ACTIVE(LEVEL,NAME_STATIC,NAME_DYNAMIC,...) \
TEST(NAME_STATIC,L##LEVEL##_##NAME_DYNAMIC) __VA_ARGS__
Expand Down
11 changes: 11 additions & 0 deletions test/unit/common/filter_architecture.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,3 +91,14 @@ void FilterArchitecture() {
}

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Reads the environment variable `CUTLASS_UNIT_TEST_PROBLEM_COUNT` and returns its
/// value as an int. Returns 0 when the variable is unset or does not parse as an
/// integer. (Per the declaration's comment, the value controls the number and order
/// of problem sizes run by CUTLASS unit tests; presumably 0 selects the default
/// behavior — confirm against callers.)
int CutlassUnitTestProblemCount() {
  if (const char* problem_count = std::getenv("CUTLASS_UNIT_TEST_PROBLEM_COUNT")) {

    // std::strtol never throws: a non-numeric value yields 0, falling back to the
    // default behavior instead of terminating the whole test binary with an
    // uncaught std::invalid_argument/std::out_of_range the way std::stoi would.
    return static_cast<int>(std::strtol(problem_count, nullptr, 10));
  }

  return 0;
}

/////////////////////////////////////////////////////////////////////////////////////////////////
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@

#include "conv2d_testbed.h"


////////////////////////////////////////////////////////////////////////////////
TEST(SM50_Device_Conv2d_Dgrad_Analytic_ImplicitGemm_cf32nhwc_cf32nhwc_cf32nhwc_simt_f32,
64x64_8x2_32x64x8) {
Expand Down
2 changes: 1 addition & 1 deletion test/unit/conv/device/conv2d_fprop_with_broadcast_sm75.cu
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ TEST(SM75_Device_Conv2d_Fprop_With_Broadcast_Analytic_ImplicitGemm_f16nhwc_f16nh
cutlass::half_t,
cutlass::half_t,
8,
cutlass::epilogue::thread::GELU_taylor<float>
cutlass::epilogue::thread::ReLu<float>
>;

/// Device-level Conv2d instance
Expand Down
2 changes: 0 additions & 2 deletions test/unit/conv/device/conv2d_problems.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,6 @@
#include "cutlass/conv/convolution.h"
#include "cutlass/conv/conv2d_problem_size.h"

#define CUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED 1

namespace test {
namespace conv {
namespace device {
Expand Down
Loading

0 comments on commit ec4f7e5

Please sign in to comment.