"""Usage example for the tri_rmsnorm fused RMS normalization Triton kernels."""

import torch

from tri_rmsnorm.kernel.rms_normalization_kernel import (
    _rms_norm_fwd_fused,
    _rms_norm_bwd_dx_fused,
)


class RMSNormFunctionCustomKernel(torch.autograd.Function):
    """Autograd wrapper around the fused forward/backward RMSNorm kernels."""

    @staticmethod
    def forward(ctx, x, weight, bias, eps):
        M, N = x.shape
        y = torch.empty_like(x)
        # Per-row reciprocal standard deviation, written by the forward kernel.
        rstd = torch.empty(M, dtype=torch.float32, device=x.device)
        # Launch one Triton program per row of x.
        _rms_norm_fwd_fused[(M,)](
            x, y, weight, bias, rstd, x.stride(0), N, eps, BLOCK_SIZE=1024
        )
        ctx.save_for_backward(x, weight, bias, rstd)
        ctx.eps = eps
        ctx.N = N
        return y

    @staticmethod
    def backward(ctx, dy):
        x, weight, bias, rstd = ctx.saved_tensors
        eps = ctx.eps
        N = ctx.N
        M = x.shape[0]
        dx = torch.empty_like(x)
        _dw = torch.empty_like(weight)
        _db = torch.empty_like(bias)
        # Lock buffer the kernel uses to synchronize accumulation of the
        # weight/bias gradients across row groups (two int32 slots per group).
        locks = torch.zeros(2 * 32, dtype=torch.int32, device=x.device)
        _rms_norm_bwd_dx_fused[(M,)](
            dx,
            dy,
            _dw,
            _db,
            x,
            weight,
            bias,
            rstd,
            locks,
            x.stride(0),
            N,
            eps,
            GROUP_SIZE_M=32,
            BLOCK_SIZE_N=1024,
        )
        # eps is a non-tensor argument, so it receives no gradient.
        return dx, _dw, _db, None


def test_rms_norm_custom_kernel():
    eps = 1e-5
    # Named `x` rather than `input` to avoid shadowing the Python built-in.
    x = torch.tensor([[0.1, -0.2] * 10] * 10, device="cuda", requires_grad=True)
    weights = torch.tensor([0.1] * 20, device="cuda", requires_grad=True)
    biases = torch.tensor([0.01] * 20, device="cuda", requires_grad=True)
    output = RMSNormFunctionCustomKernel.apply(x, weights, biases, eps)
    loss = output.mean()
    loss.backward()
    print("Grads X: ", x.grad)
    print("Grads W: ", weights.grad)
    print("Grads B: ", biases.grad)


test_rms_norm_custom_kernel()
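

# Hedged addition, not part of the original file: a minimal sanity-check
# sketch that compares the custom kernel's forward output against a plain
# PyTorch RMSNorm reference, y = x / sqrt(mean(x^2) + eps) * weight + bias.
# The function name `test_rms_norm_against_reference` is hypothetical.
def test_rms_norm_against_reference():
    eps = 1e-5
    x = torch.tensor([[0.1, -0.2] * 10] * 10, device="cuda")
    weight = torch.tensor([0.1] * 20, device="cuda")
    bias = torch.tensor([0.01] * 20, device="cuda")

    # Reference RMSNorm computed with standard PyTorch ops.
    rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    y_ref = x / rms * weight + bias

    y_kernel = RMSNormFunctionCustomKernel.apply(x, weight, bias, eps)
    print("Max abs diff vs. reference:", (y_kernel - y_ref).abs().max().item())


test_rms_norm_against_reference()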