Skip to content

Commit

Permalink
fix(kernel): MatMulInteger 算法与 onnxruntime 对齐
Browse files Browse the repository at this point in the history
Signed-off-by: YdrMaster <[email protected]>
  • Loading branch information
YdrMaster committed Jan 31, 2024
1 parent e9c316f commit bca7336
Showing 1 changed file with 72 additions and 64 deletions.
136 changes: 72 additions & 64 deletions src/04kernel/src/kernels/mat_mul_integer/cublas_kernel.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "../../utilities/cuda/cublas_context.hh"
#include "cublas_kernel.hh"
#include <cub/cub.cuh>
#include <thrust/execution_policy.h>
#include <thrust/tabulate.h>
#include <thrust/transform.h>
Expand All @@ -8,26 +9,24 @@ namespace refactor::kernel {
using namespace runtime;
using namespace cublas;

// NOTE(review): this span appears to be a commit diff whose +/- markers were lost,
// so two variants of `sub` are interleaved. Presumably the int8_t-returning
// declaration with its two specializations is the removed version and the
// float-returning template is the added one — confirm against the repository.
//
// int8_t-returning variant: the primary template is only declared here; the
// behavior comes from the explicit specializations that follow.
template<class T> __device__ __forceinline__ static int8_t sub(T, T);
// int8_t operands: plain a - b, implicitly converted to the int8_t return type.
template<> __device__ __forceinline__ int8_t sub<int8_t>(int8_t a, int8_t b) { return a - b; }
// uint8_t operands: both are cast to int16_t before subtracting, the difference is
// capped at 127 (no lower bound is applied), then cast to int8_t.
template<> __device__ __forceinline__ int8_t sub<uint8_t>(uint8_t a, uint8_t b) {
constexpr static int16_t MAX = 127;
return static_cast<int8_t>(CUB_MIN(MAX, static_cast<int16_t>(a) - static_cast<int16_t>(b)));
// float-returning variant: converts both operands to float and returns a - b,
// without any capping or narrowing of the difference.
template<class T> __device__ __forceinline__ static float sub(T a, T b) {
return static_cast<float>(a) - static_cast<float>(b);
}

// Unary functor that subtracts one zero-point value, read through `zp`, from
// each element by calling `sub`.
// NOTE(review): two `operator()` signature lines appear back to back (int8_t and
// float return types). This looks like a diff with its +/- markers stripped;
// presumably only the float-returning line belongs to the committed version —
// confirm against the repository.
template<class T>
struct MatMulIntegerZPFunctorScalar {
// Pointer to the zero point; it is dereferenced inside the __device__ operator().
T const *zp;

__device__ int8_t operator()(T x) const noexcept {
__device__ float operator()(T x) const noexcept {
return sub(x, *zp);
}
};

template<class T>
static void applyZeroPointScalar(
size_t size, int8_t *dst, void const *src_, void const *zp_) {
size_t size, void *dst_, void const *src_, void const *zp_) {

auto dst = reinterpret_cast<float *>(dst_);
auto src = reinterpret_cast<T const *>(src_),
zp = reinterpret_cast<T const *>(zp_);
thrust::transform(thrust::device,
Expand All @@ -40,7 +39,7 @@ namespace refactor::kernel {
dim_t m, n, a, b, c;
T const *src, *zp;

__device__ int8_t operator()(size_t idx) const noexcept {
__device__ float operator()(size_t idx) const noexcept {
auto
k = idx % n,
j = idx / n % m,
Expand All @@ -52,7 +51,9 @@ namespace refactor::kernel {
template<class T>
static void applyZeroPointA(
dim_t b, dim_t m, dim_t n,
int8_t *dst, void const *src_, void const *zp_) {
void *dst_, void const *src_, void const *zp_) {

auto dst = reinterpret_cast<float *>(dst_);
thrust::tabulate(thrust::device,
dst, dst + b * m * n,
MatMulIntegerZPFunctor<T>{
Expand All @@ -69,8 +70,9 @@ namespace refactor::kernel {
template<class T>
static void applyZeroPointB(
dim_t b, dim_t m, dim_t n,
int8_t *dst, void const *src_, void const *zp_) {
void *dst_, void const *src_, void const *zp_) {

auto dst = reinterpret_cast<float *>(dst_);
thrust::tabulate(thrust::device,
dst, dst + b * m * n,
MatMulIntegerZPFunctor<T>{
Expand All @@ -84,87 +86,91 @@ namespace refactor::kernel {
});
}

/// Unary device functor that narrows an unsigned 8-bit value to int8_t,
/// capping it at 127 so the result never wraps to a negative number.
struct MatMulIntegerCastFunctor {
    __device__ int8_t operator()(uint8_t x) const noexcept {
        // Take the smaller of 127 and the input first, then narrow.
        auto const capped = CUB_MIN(127, x);
        return static_cast<int8_t>(capped);
    }
};

// Element-wise conversion of `size` elements read from `src_` into the
// destination buffer, using thrust::transform with the thrust::device policy.
// NOTE(review): two variants are interleaved in this span (the diff markers
// appear to be lost):
//   - (size, int8_t *dst, src_): reads uint8_t elements and applies
//     MatMulIntegerCastFunctor;
//   - (dst_, src_, size): reads `From` elements, writes `To` elements and applies
//     cub::CastOp<To>.
// The argument order differs between the two. Presumably the templated To/From
// variant is the committed one — confirm against the repository.
template<class To, class From>
static void applyCast(
size_t size, int8_t *dst, void const *src_) {
void *dst_, void const *src_, size_t size) {

auto src = reinterpret_cast<uint8_t const *>(src_);
auto dst = reinterpret_cast<To *>(dst_);
auto src = reinterpret_cast<From const *>(src_);
thrust::transform(thrust::device,
src, src + size,
dst, MatMulIntegerCastFunctor{});
dst, cub::CastOp<To>{});
}

auto MatMulIntegerCublas::lower(Resources &res) const noexcept -> RoutineWorkspace {

size_t workspace = 0;
if (info.a.withZeroPoint || !info.a.signed_) {
workspace += info.batch() * info.m * info.k;
}
if (info.b.withZeroPoint || !info.b.signed_) {
workspace += info.batch() * info.k * info.n;
}
workspace += info.batch() * info.m * info.k * sizeof(float);
workspace += info.batch() * info.k * info.n * sizeof(float);
workspace += info.batch() * info.m * info.n * sizeof(float);

auto routine = [info = info](Resources &res, void *workspace, void const *const *inputs, void *const *outputs) {
auto workspacePtr = reinterpret_cast<int8_t *>(workspace);
auto a = reinterpret_cast<int8_t const *>(inputs[0]),
b = reinterpret_cast<int8_t const *>(inputs[1]);
auto y = reinterpret_cast<int32_t *>(outputs[0]);
float *a, *b, *y;

if (auto meta = info.a; meta.withZeroPoint) {
{
auto size = info.batch() * info.m * info.k;
auto input = inputs[0];
auto zp = inputs[2];
if (meta.scalar) {
if (meta.signed_) {
applyZeroPointScalar<int8_t>(size, workspacePtr, a, zp);
a = reinterpret_cast<float *>(workspacePtr);
workspacePtr += size * sizeof(float);

if (auto meta = info.a; meta.withZeroPoint) {
if (meta.scalar) {
if (meta.signed_) {
applyZeroPointScalar<int8_t>(size, a, input, zp);
} else {
applyZeroPointScalar<uint8_t>(size, a, input, zp);
}
} else {
applyZeroPointScalar<uint8_t>(size, workspacePtr, a, zp);
if (meta.signed_) {
applyZeroPointA<int8_t>(info.batch(), info.m, info.k, a, input, zp);
} else {
applyZeroPointA<uint8_t>(info.batch(), info.m, info.k, a, input, zp);
}
}
} else {
if (meta.signed_) {
applyZeroPointA<int8_t>(info.batch(), info.m, info.k, workspacePtr, a, zp);
applyCast<float, int8_t>(a, input, size);
} else {
applyZeroPointA<uint8_t>(info.batch(), info.m, info.k, workspacePtr, a, zp);
applyCast<float, uint8_t>(a, input, size);
}
}
a = workspacePtr;
workspacePtr += size;
} else if (!meta.signed_) {
auto size = info.batch() * info.m * info.k;
applyCast(size, workspacePtr, a);
a = workspacePtr;
workspacePtr += size;
}
if (auto meta = info.b; meta.withZeroPoint) {
{
auto size = info.batch() * info.k * info.n;
auto input = inputs[1];
auto zp = inputs[3];
if (meta.scalar) {
if (meta.signed_) {
applyZeroPointScalar<int8_t>(size, workspacePtr, b, zp);
b = reinterpret_cast<float *>(workspacePtr);
workspacePtr += size * sizeof(float);

// Fills the float buffer `b` from input B: when B has a zero point it is
// subtracted (scalar or non-scalar path, signed or unsigned element type);
// otherwise the elements are only converted to float via applyCast.
// NOTE(review): removed and added diff lines appear interleaved in this span;
// the calls that still pass `workspacePtr` presumably belong to the removed
// version — confirm against the repository.
if (auto meta = info.b; meta.withZeroPoint) {
if (meta.scalar) {
if (meta.signed_) {
applyZeroPointScalar<int8_t>(size, b, input, zp);
} else {
applyZeroPointScalar<uint8_t>(size, b, input, zp);
}
} else {
applyZeroPointScalar<uint8_t>(size, workspacePtr, b, zp);
// NOTE(review): possible copy-paste defect. This is the B operand, whose
// buffer is sized batch * k * n just above, yet the two calls below pass
// (info.m, info.k) — the same extents the A branch uses — and call
// applyZeroPointA. An applyZeroPointB helper exists in this file; presumably
// these should be applyZeroPointB<...>(info.batch(), info.k, info.n, b, input, zp).
// With m * k > k * n the current calls would write past the B region — confirm.
if (meta.signed_) {
applyZeroPointA<int8_t>(info.batch(), info.m, info.k, b, input, zp);
} else {
applyZeroPointA<uint8_t>(info.batch(), info.m, info.k, b, input, zp);
}
}
} else {
if (meta.signed_) {
applyZeroPointA<int8_t>(info.batch(), info.k, info.n, workspacePtr, b, zp);
applyCast<float, int8_t>(b, input, size);
} else {
applyZeroPointA<uint8_t>(info.batch(), info.k, info.n, workspacePtr, b, zp);
applyCast<float, uint8_t>(b, input, size);
}
}
b = workspacePtr;
} else if (!meta.signed_) {
auto size = info.batch() * info.k * info.n;
applyCast(size, workspacePtr, b);
b = workspacePtr;
}
y = reinterpret_cast<float *>(workspacePtr);

auto handle = res.fetchOrStore<CublasContext>()->handle;
int32_t alpha = 1,
beta = 0;
float alpha = 1,
beta = 0;
auto m = info.m,
n = info.n,
k = info.k;
Expand All @@ -183,10 +189,10 @@ namespace refactor::kernel {
CUBLAS_OP_N, CUBLAS_OP_N,
n, m, k,
&alpha,
b + strideB * offset[1], CUDA_R_8I, ldb,
a + strideA * offset[0], CUDA_R_8I, lda,
&beta, y + strideY * i, CUDA_R_32I, n,
CUDA_R_32I, CUBLAS_GEMM_DEFAULT));
b + strideB * offset[1], CUDA_R_32F, ldb,
a + strideA * offset[0], CUDA_R_32F, lda,
&beta, y + strideY * i, CUDA_R_32F, n,
CUDA_R_32F, CUBLAS_GEMM_DEFAULT));
}
} else {

Expand All @@ -195,12 +201,14 @@ namespace refactor::kernel {
CUBLAS_OP_N, CUBLAS_OP_N,
n, m, k,
&alpha,
b, CUDA_R_8I, ldb, strideB,
a, CUDA_R_8I, lda, strideA,
&beta, y, CUDA_R_32I, n,
b, CUDA_R_32F, ldb, strideB,
a, CUDA_R_32F, lda, strideA,
&beta, y, CUDA_R_32F, n,
strideY, info.batch(),
CUDA_R_32I, CUBLAS_GEMM_DEFAULT));
CUDA_R_32F, CUBLAS_GEMM_DEFAULT));
}

applyCast<int32_t, float>(outputs[0], y, info.batch() * info.m * info.n);
};

res.fetchOrStore<CublasContext>();
Expand Down

0 comments on commit bca7336

Please sign in to comment.