kernel.cu
#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cutlass/gemm/device/gemm.h>

torch::Tensor int4MatmulCUDA(const torch::Tensor &A, const torch::Tensor &B) {
  torch::checkAllSameGPU("int4Matmul", {{A, "A", 0}, {B, "B", 1}});
  auto M = A.size(0);
  auto N = B.size(0);
  auto K = A.size(1) * 2; // = B.size(1) * 2; two 4-bit values are packed per byte along the columns
  auto C = torch::empty({M, N}, torch::dtype(torch::kInt32).device(A.device()));

  using Gemm = cutlass::gemm::device::Gemm<
      cutlass::int4b_t,                // ElementA
      cutlass::layout::RowMajor,       // LayoutA
      cutlass::int4b_t,                // ElementB
      cutlass::layout::ColumnMajor,    // LayoutB
      int32_t,                         // ElementOutput
      cutlass::layout::RowMajor,       // LayoutOutput
      int32_t,                         // ElementAccumulator
      cutlass::arch::OpClassTensorOp,  // tag indicating Tensor Cores
      cutlass::arch::Sm75              // tag indicating target GPU compute architecture
      >;

  Gemm gemmOp;
  using GemmCoord = cutlass::gemm::GemmCoord;
  typename Gemm::Arguments arguments{
      {static_cast<GemmCoord::Index>(M), static_cast<GemmCoord::Index>(N),
       static_cast<GemmCoord::Index>(K)},
      {(cutlass::int4b_t *)A.data_ptr<uint8_t>(), K},  // A: M x K int4, leading dimension K
      {(cutlass::int4b_t *)B.data_ptr<uint8_t>(), K},  // B: N x K int4 (column-major K x N), leading dimension K
      {C.data_ptr<int32_t>(), N},                      // C: epilogue source operand, leading dimension N
      {C.data_ptr<int32_t>(), N},                      // D: output, leading dimension N
      {1, 0}};                                         // epilogue: alpha = 1, beta = 0

  auto status = gemmOp(arguments);
  TORCH_CHECK(status == cutlass::Status::kSuccess,
              cutlassGetStatusString(status));
  return C;
}
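
// Packing note (illustrative; not part of the original file): int4MatmulCUDA
// expects A and B as uint8 CUDA tensors of shape (rows, K / 2), each byte
// holding two signed 4-bit values. The helper below is a hypothetical
// host-side packer, written under the assumption that the lower-indexed
// element sits in the low nibble of each byte (CUTLASS's sub-byte packing
// convention); the name packInt4 is an assumption and does not appear in the
// original source.
//
// torch::Tensor packInt4(const torch::Tensor &x) {
//   // x: int8 tensor of shape (rows, K) with values already clamped to [-8, 7].
//   auto lo = x.slice(/*dim=*/1, /*start=*/0, /*end=*/x.size(1), /*step=*/2);
//   auto hi = x.slice(/*dim=*/1, /*start=*/1, /*end=*/x.size(1), /*step=*/2);
//   // Keep the low nibble of each element, move the odd-indexed element into
//   // the high nibble, and combine the pair into one byte.
//   auto lo4 = lo.bitwise_and(0x0F).to(torch::kUInt8);
//   auto hi4 = hi.bitwise_and(0x0F).to(torch::kUInt8).bitwise_left_shift(4);
//   return lo4.bitwise_or(hi4).contiguous(); // shape (rows, K / 2), dtype uint8
// }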

torch::Tensor int8MatmulCUDA(const torch::Tensor &A, const torch::Tensor &B) {
  torch::checkAllSameGPU("int8Matmul", {{A, "A", 0}, {B, "B", 1}});
  auto M = A.size(0);
  auto N = B.size(0);
  auto K = A.size(1); // = B.size(1)
  auto C = torch::empty({M, N}, torch::dtype(torch::kInt32).device(A.device()));

  using Gemm = cutlass::gemm::device::Gemm<
      int8_t,                          // ElementA
      cutlass::layout::RowMajor,       // LayoutA
      int8_t,                          // ElementB
      cutlass::layout::ColumnMajor,    // LayoutB
      int32_t,                         // ElementOutput
      cutlass::layout::RowMajor,       // LayoutOutput
      int32_t,                         // ElementAccumulator
      cutlass::arch::OpClassTensorOp,  // tag indicating Tensor Cores
      cutlass::arch::Sm75              // tag indicating target GPU compute architecture
      >;

  Gemm gemmOp;
  using GemmCoord = cutlass::gemm::GemmCoord;
  typename Gemm::Arguments arguments{
      {static_cast<GemmCoord::Index>(M), static_cast<GemmCoord::Index>(N),
       static_cast<GemmCoord::Index>(K)},
      {A.data_ptr<int8_t>(), K},   // A: M x K int8, leading dimension K
      {B.data_ptr<int8_t>(), K},   // B: N x K int8 (column-major K x N), leading dimension K
      {C.data_ptr<int32_t>(), N},  // C: epilogue source operand, leading dimension N
      {C.data_ptr<int32_t>(), N},  // D: output, leading dimension N
      {1, 0}};                     // epilogue: alpha = 1, beta = 0

  auto status = gemmOp(arguments);
  TORCH_CHECK(status == cutlass::Status::kSuccess,
              cutlassGetStatusString(status));
  return C;
}
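
// ---------------------------------------------------------------------------
// Binding sketch (hypothetical; the original repository presumably registers
// these kernels in a separate bindings file). The module macro and the
// exported names "int4Matmul" / "int8Matmul" below are assumptions made for
// illustration only:
//
// PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
//   m.def("int4Matmul", &int4MatmulCUDA,
//         "int4 GEMM via CUTLASS: (M x K, packed) x (N x K, packed)^T -> M x N int32");
//   m.def("int8Matmul", &int8MatmulCUDA,
//         "int8 GEMM via CUTLASS: (M x K) x (N x K)^T -> M x N int32");
// }
// ---------------------------------------------------------------------------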