diff --git a/src/04kernel/src/kernels/mat_mul_integer/cublas_kernel.cu b/src/04kernel/src/kernels/mat_mul_integer/cublas_kernel.cu
index b958bd8d..0d617b06 100644
--- a/src/04kernel/src/kernels/mat_mul_integer/cublas_kernel.cu
+++ b/src/04kernel/src/kernels/mat_mul_integer/cublas_kernel.cu
@@ -1,5 +1,6 @@
 #include "../../utilities/cuda/cublas_context.hh"
 #include "cublas_kernel.hh"
+#include <cub/thread/thread_operators.cuh>
 #include <thrust/execution_policy.h>
 #include <thrust/tabulate.h>
 #include <thrust/transform.h>
@@ -8,26 +9,24 @@ namespace refactor::kernel {
     using namespace runtime;
     using namespace cublas;
 
-    template<class T> __device__ __forceinline__ static int8_t sub(T, T);
-    template<> __device__ __forceinline__ int8_t sub(int8_t a, int8_t b) { return a - b; }
-    template<> __device__ __forceinline__ int8_t sub(uint8_t a, uint8_t b) {
-        constexpr static int16_t MAX = 127;
-        return static_cast<int8_t>(CUB_MIN(MAX, static_cast<int16_t>(a) - static_cast<int16_t>(b)));
+    template<class T> __device__ __forceinline__ static float sub(T a, T b) {
+        return static_cast<float>(a) - static_cast<float>(b);
     }
 
     template<class T> struct MatMulIntegerZPFunctorScalar {
         T const *zp;
 
-        __device__ int8_t operator()(T x) const noexcept {
+        __device__ float operator()(T x) const noexcept {
             return sub(x, *zp);
         }
     };
 
     template<class T> static void applyZeroPointScalar(
-        size_t size, int8_t *dst, void const *src_, void const *zp_) {
+        size_t size, void *dst_, void const *src_, void const *zp_) {
 
+        auto dst = reinterpret_cast<float *>(dst_);
         auto src = reinterpret_cast<T const *>(src_),
              zp = reinterpret_cast<T const *>(zp_);
         thrust::transform(thrust::device,
@@ -40,7 +39,7 @@
         dim_t m, n, a, b, c;
         T const *src, *zp;
 
-        __device__ int8_t operator()(size_t idx) const noexcept {
+        __device__ float operator()(size_t idx) const noexcept {
             auto k = idx % n,
                  j = idx / n % m,
@@ -52,7 +51,9 @@
     template<class T> static void applyZeroPointA(
         dim_t b, dim_t m, dim_t n,
-        int8_t *dst, void const *src_, void const *zp_) {
+        void *dst_, void const *src_, void const *zp_) {
+
+        auto dst = reinterpret_cast<float *>(dst_);
         thrust::tabulate(thrust::device,
                          dst, dst + b * m * n,
                          MatMulIntegerZPFunctor<T>{
@@ -69,8 +70,9 @@
     template<class T> static void applyZeroPointB(
         dim_t b, dim_t m, dim_t n,
-        int8_t *dst, void const *src_, void const *zp_) {
+        void *dst_, void const *src_, void const *zp_) {
 
+        auto dst = reinterpret_cast<float *>(dst_);
         thrust::tabulate(thrust::device,
                          dst, dst + b * m * n,
                          MatMulIntegerZPFunctor<T>{
@@ -84,87 +86,91 @@
         });
     }
 
-    struct MatMulIntegerCastFunctor {
-        __device__ int8_t operator()(uint8_t x) const noexcept {
-            return static_cast<int8_t>(CUB_MIN(127, x));
-        }
-    };
-
+    template<class TI, class TO>
     static void applyCast(
-        size_t size, int8_t *dst, void const *src_) {
+        void *dst_, void const *src_, size_t size) {
 
-        auto src = reinterpret_cast<uint8_t const *>(src_);
+        auto dst = reinterpret_cast<TO *>(dst_);
+        auto src = reinterpret_cast<TI const *>(src_);
         thrust::transform(thrust::device,
                           src, src + size,
-                          dst, MatMulIntegerCastFunctor{});
+                          dst, cub::CastOp<TO>{});
     }
 
     auto MatMulIntegerCublas::lower(Resources &res) const noexcept -> RoutineWorkspace {
         size_t workspace = 0;
-        if (info.a.withZeroPoint || !info.a.signed_) {
-            workspace += info.batch() * info.m * info.k;
-        }
-        if (info.b.withZeroPoint || !info.b.signed_) {
-            workspace += info.batch() * info.k * info.n;
-        }
+        workspace += info.batch() * info.m * info.k * sizeof(float);
+        workspace += info.batch() * info.k * info.n * sizeof(float);
+        workspace += info.batch() * info.m * info.n * sizeof(float);
 
         auto routine = [info = info](Resources &res, void *workspace, void const *const *inputs, void *const *outputs) {
             auto workspacePtr = reinterpret_cast<int8_t *>(workspace);
-            auto a = reinterpret_cast<int8_t const *>(inputs[0]),
-                 b = reinterpret_cast<int8_t const *>(inputs[1]);
-            auto y = reinterpret_cast<int32_t *>(outputs[0]);
+            float *a, *b, *y;
 
-            if (auto meta = info.a; meta.withZeroPoint) {
+            {
                 auto size = info.batch() * info.m * info.k;
+                auto input = inputs[0];
                 auto zp = inputs[2];
-                if (meta.scalar) {
-                    if (meta.signed_) {
-                        applyZeroPointScalar<int8_t>(size, workspacePtr, a, zp);
+                a = reinterpret_cast<float *>(workspacePtr);
+                workspacePtr += size * sizeof(float);
+
+                if (auto meta = info.a; meta.withZeroPoint) {
+                    if (meta.scalar) {
+                        if (meta.signed_) {
+                            applyZeroPointScalar<int8_t>(size, a, input, zp);
+                        } else {
+                            applyZeroPointScalar<uint8_t>(size, a, input, zp);
+                        }
                     } else {
-                        applyZeroPointScalar<uint8_t>(size, workspacePtr, a, zp);
+                        if (meta.signed_) {
+                            applyZeroPointA<int8_t>(info.batch(), info.m, info.k, a, input, zp);
+                        } else {
+                            applyZeroPointA<uint8_t>(info.batch(), info.m, info.k, a, input, zp);
+                        }
                     }
                 } else {
                     if (meta.signed_) {
-                        applyZeroPointA<int8_t>(info.batch(), info.m, info.k, workspacePtr, a, zp);
+                        applyCast<int8_t, float>(a, input, size);
                     } else {
-                        applyZeroPointA<uint8_t>(info.batch(), info.m, info.k, workspacePtr, a, zp);
+                        applyCast<uint8_t, float>(a, input, size);
                    }
                 }
-                a = workspacePtr;
-                workspacePtr += size;
-            } else if (!meta.signed_) {
-                auto size = info.batch() * info.m * info.k;
-                applyCast(size, workspacePtr, a);
-                a = workspacePtr;
-                workspacePtr += size;
             }
 
-            if (auto meta = info.b; meta.withZeroPoint) {
+            {
                 auto size = info.batch() * info.k * info.n;
+                auto input = inputs[1];
                 auto zp = inputs[3];
-                if (meta.scalar) {
-                    if (meta.signed_) {
-                        applyZeroPointScalar<int8_t>(size, workspacePtr, b, zp);
+                b = reinterpret_cast<float *>(workspacePtr);
+                workspacePtr += size * sizeof(float);
+
+                if (auto meta = info.b; meta.withZeroPoint) {
+                    if (meta.scalar) {
+                        if (meta.signed_) {
+                            applyZeroPointScalar<int8_t>(size, b, input, zp);
+                        } else {
+                            applyZeroPointScalar<uint8_t>(size, b, input, zp);
+                        }
                     } else {
-                        applyZeroPointScalar<uint8_t>(size, workspacePtr, b, zp);
+                        if (meta.signed_) {
+                            applyZeroPointB<int8_t>(info.batch(), info.k, info.n, b, input, zp);
+                        } else {
+                            applyZeroPointB<uint8_t>(info.batch(), info.k, info.n, b, input, zp);
+                        }
                     }
                 } else {
                     if (meta.signed_) {
-                        applyZeroPointA<int8_t>(info.batch(), info.k, info.n, workspacePtr, b, zp);
+                        applyCast<int8_t, float>(b, input, size);
                     } else {
-                        applyZeroPointA<uint8_t>(info.batch(), info.k, info.n, workspacePtr, b, zp);
+                        applyCast<uint8_t, float>(b, input, size);
                     }
                 }
-                b = workspacePtr;
-            } else if (!meta.signed_) {
-                auto size = info.batch() * info.k * info.n;
-                applyCast(size, workspacePtr, b);
-                b = workspacePtr;
             }
+            y = reinterpret_cast<float *>(workspacePtr);
 
             auto handle = res.fetchOrStore<CublasContext>()->handle;
-            int32_t alpha = 1,
-                    beta = 0;
+            float alpha = 1,
+                  beta = 0;
             auto m = info.m,
                  n = info.n,
                  k = info.k;
@@ -183,10 +189,10 @@ namespace refactor::kernel {
                         CUBLAS_OP_N, CUBLAS_OP_N,
                         n, m, k, &alpha,
-                        b + strideB * offset[1], CUDA_R_8I, ldb,
-                        a + strideA * offset[0], CUDA_R_8I, lda,
-                        &beta, y + strideY * i, CUDA_R_32I, n,
-                        CUDA_R_32I, CUBLAS_GEMM_DEFAULT));
+                        b + strideB * offset[1], CUDA_R_32F, ldb,
+                        a + strideA * offset[0], CUDA_R_32F, lda,
+                        &beta, y + strideY * i, CUDA_R_32F, n,
+                        CUDA_R_32F, CUBLAS_GEMM_DEFAULT));
                 }
             } else {
@@ -195,12 +201,14 @@
                     CUBLAS_OP_N, CUBLAS_OP_N,
                     n, m, k, &alpha,
-                    b, CUDA_R_8I, ldb, strideB,
-                    a, CUDA_R_8I, lda, strideA,
-                    &beta, y, CUDA_R_32I, n,
+                    b, CUDA_R_32F, ldb, strideB,
+                    a, CUDA_R_32F, lda, strideA,
+                    &beta, y, CUDA_R_32F, n,
                     strideY, info.batch(),
-                    CUDA_R_32I, CUBLAS_GEMM_DEFAULT));
+                    CUDA_R_32F, CUBLAS_GEMM_DEFAULT));
             }
+
+            applyCast<float, int32_t>(outputs[0], y, info.batch() * info.m * info.n);
         };
 
         res.fetchOrStore<CublasContext>();
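--
Reviewer note (not part of the patch): the reworked kernel dequantizes both
int8/uint8 operands into a float workspace (subtracting the zero point where
one is supplied), runs the GEMM in CUDA_R_32F, and truncates the float result
back into the int32 output. Presumably this is exact as long as accumulated
dot products stay within float's 24-bit mantissa, i.e. roughly k * 255 * 255
< 2^24. The standalone sketch below mirrors the two thrust passes around the
GEMM; every name in it (SubZeroPoint, hostQ, ...) is illustrative, not taken
from the repository.

    // sketch.cu -- minimal illustration of the patch's new data flow:
    // quantized uint8 -> float (zero point subtracted) -> int32.
    #include <cub/thread/thread_operators.cuh>
    #include <thrust/device_vector.h>
    #include <thrust/execution_policy.h>
    #include <thrust/transform.h>
    #include <cstdint>
    #include <cstdio>
    #include <vector>

    struct SubZeroPoint {
        uint8_t const *zp; // scalar zero point, resident in device memory
        __device__ float operator()(uint8_t x) const noexcept {
            // Subtracting in float cannot saturate, which is why the patch
            // can drop the old CUB_MIN-based int8 clamping.
            return static_cast<float>(x) - static_cast<float>(*zp);
        }
    };

    int main() {
        std::vector<uint8_t> hostQ{130, 128, 131, 125};
        thrust::device_vector<uint8_t> quantized(hostQ.begin(), hostQ.end());
        thrust::device_vector<uint8_t> zeroPoint(1, 128);
        thrust::device_vector<float> dequant(quantized.size());

        // Same pattern as applyZeroPointScalar<uint8_t>: widen, subtract zp.
        thrust::transform(thrust::device,
                          quantized.begin(), quantized.end(),
                          dequant.begin(),
                          SubZeroPoint{thrust::raw_pointer_cast(zeroPoint.data())});

        // Same pattern as the final applyCast<float, int32_t>:
        // float workspace -> int32 output buffer.
        thrust::device_vector<int32_t> out(dequant.size());
        thrust::transform(thrust::device,
                          dequant.begin(), dequant.end(),
                          out.begin(), cub::CastOp<int32_t>{});

        for (size_t i = 0; i < out.size(); ++i)
            std::printf("%d ", static_cast<int32_t>(out[i])); // prints: 2 0 3 -3
        std::printf("\n");
        return 0;
    }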