Add int8 to gemm w/ addmatrix #3040
base: main
Conversation
I think we treat them as 2 separate benchmarks in terms of reporting (so we have 2 separate lines in the reports). Then we either run the benchmark script twice with different dtypes to generate 2 separate report files (I'd prefer this), or modify our report script to add some filtering capability. About onednn and int8 support: do we want to measure onednn? If not, and we run it only for validation, maybe we could run it in another precision, like fp32 or bf16, just for validation of the output.
My only issue with this PR right now is that all charts and GeoMeans for the addmatrix benchmark will be discontinued due to the new parameters. Hence my suggestion to introduce a separate benchmark for int8.
Force-pushed from 164fd00 to 3cc0b6e (commits: add int8 to addmatrix benchmark 2/?, add int8 to addmatrix benchmark 3/3).
I removed the onednn-related bits because for onednn we only measure kernel time, and the add step is not fused into the main gemm kernel, so the comparison would not be appropriate. I modified the int8 code to only run as a separate benchmark step controllable by an environment variable, so the existing bfloat16 time should not be affected. The default run mode is bfloat16, with an optional int8 mode or an all-dtypes mode for local runs.
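(For reference, a minimal sketch of what such an env-var-controlled dtype selection could look like. The variable name BENCHMARK_DTYPE and its accepted values are my assumptions; the comment above only says the mode is environment-variable controlled, with bfloat16 as the default and int8 / all-dtypes modes for local runs.)

import os

import torch


def dtypes():
    # BENCHMARK_DTYPE is a hypothetical switch: bfloat16 by default, with
    # int8-only and all-dtypes modes intended for local runs.
    mode = os.getenv('BENCHMARK_DTYPE', 'bfloat16')
    if mode == 'all':
        return [torch.bfloat16, torch.int8]
    if mode == 'int8':
        return [torch.int8]
    return [torch.bfloat16]

With a selector like this, the default CI run keeps producing exactly the existing bfloat16 series, and int8 becomes an opt-in second series.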
@@ -117,7 +130,7 @@ def matmul_kernel_with_block_pointers_batched(
         stride_az: tl.constexpr, stride_am: tl.constexpr, stride_ak: tl.constexpr,  #
         stride_bz: tl.constexpr, stride_bk: tl.constexpr, stride_bn: tl.constexpr,  #
         stride_cz: tl.constexpr, stride_cm: tl.constexpr, stride_cn: tl.constexpr,  #
-        stride_dz: tl.constexpr, stride_dm: tl.constexpr, stride_dn: tl.constexpr,
+        stride_dz: tl.constexpr, stride_dm: tl.constexpr, stride_dn: tl.constexpr, ACCUMULATOR_DTYPE: tl.constexpr,
[nit]
Suggested change:
-        stride_dz: tl.constexpr, stride_dm: tl.constexpr, stride_dn: tl.constexpr, ACCUMULATOR_DTYPE: tl.constexpr,
+        stride_dz: tl.constexpr, stride_dm: tl.constexpr, stride_dn: tl.constexpr,  #
+        ACCUMULATOR_DTYPE: tl.constexpr,
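(For context, a sketch of how a constexpr accumulator dtype is typically threaded into a Triton kernel. This is illustrative only, not the PR's actual kernel; the helper name and block parameters are mine.)

import torch
import triton
import triton.language as tl


def accumulator_dtype(dtype: torch.dtype):
    # Host side: bfloat16 inputs accumulate in fp32, int8 inputs in int32.
    return tl.float32 if dtype.is_floating_point else tl.int32


@triton.jit
def acc_dtype_sketch(BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, ACCUMULATOR_DTYPE: tl.constexpr):
    # Kernel side: the constexpr picks the element type of the accumulator
    # tile that dot-product partial results are summed into.
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACCUMULATOR_DTYPE)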
        ],
        x_vals=[[1, 1024 * i, 1024 * i, 1024 * i, dtype] for i in [1, 2, 4, 8] for dtype in dtypes()] +  #
        [[*shape, dtype]
         for shape in [[1, 1, 5120, 13824],  #
[nit] easier to comment out shapes when debugging
Suggested change:
-         for shape in [[1, 1, 5120, 13824],  #
+         for shape in [  #
+             [1, 1, 5120, 13824],  #
                       [4, 32768, 4096, 128],  #
                       [32, 4096, 4096, 128],  #
                       [4096, 8, 128, 16384],  #
                       [4096, 8, 16384, 128]]
[nit] easier to comment out shapes when debugging
Suggested change:
-             [4096, 8, 16384, 128]]
+             [4096, 8, 16384, 128]  #
+         ]
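(As background for the chart-continuity concern above: with the dtype appended, every x_vals entry becomes a 5-element [B, M, N, K, dtype] row, so dtype is a new reporting dimension. A small illustration, assuming dtypes() returns both types:)

import torch

# Square-shape portion of x_vals with the dtype appended as a 5th element.
x_vals = [[1, 1024 * i, 1024 * i, 1024 * i, dtype]
          for i in [1, 2, 4, 8]
          for dtype in (torch.bfloat16, torch.int8)]
# x_vals[0] == [1, 1024, 1024, 1024, torch.bfloat16]
# x_vals[1] == [1, 1024, 1024, 1024, torch.int8]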
@@ -247,29 +262,42 @@ def matmul(a, b, d, c):
         # name for the plot. Used also as a file name for saving the plot.
         args={},
     ))
-def benchmark(B, M, N, K, provider):
+def benchmark(B, M, N, K, dtype, provider):
+    res_dtype = torch.float32 if dtype is torch.bfloat16 else torch.int32
to be consistent with the code above?
Suggested change:
-    res_dtype = torch.float32 if dtype is torch.bfloat16 else torch.int32
+    res_dtype = torch.float32 if a.dtype.is_floating_point else torch.int32
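(Either spelling relies on torch.dtype.is_floating_point; a tiny check of the behaviour being assumed here:)

import torch


def result_dtype(dtype: torch.dtype) -> torch.dtype:
    # bfloat16 is a floating-point dtype, int8 is not, so this picks fp32
    # results for bf16 inputs and int32 results for int8 inputs.
    return torch.float32 if dtype.is_floating_point else torch.int32


assert result_dtype(torch.bfloat16) is torch.float32
assert result_dtype(torch.int8) is torch.int32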
    benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
    if dtype.is_floating_point or [B, M, N, K] in [[1, 1024, 1024, 1024], [1, 2048, 2048, 2048],
                                                   [1, 512, 8192, 32768], [4, 32768, 4096, 128]]:
        # torch int8 matmul on GPU is not supported. only check a few int8 shapes to reduce runtime
Should we have an env var for all benchmarks to control whether we verify the result?
I don't think we should skip checking correctness for some shapes.
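(If the env-var idea is taken up, a minimal sketch of what a suite-wide switch could look like. VERIFY_RESULTS is a hypothetical name, not something the benchmark suite currently defines, and verification defaults to on so CI keeps validating.)

import os

import torch


def should_verify() -> bool:
    # Hypothetical switch: verification stays on unless explicitly disabled.
    return os.getenv('VERIFY_RESULTS', '1') != '0'


def maybe_assert_close(result: torch.Tensor, reference: torch.Tensor):
    # Skip the (potentially slow) reference comparison only when asked to.
    if should_verify():
        torch.testing.assert_close(result, reference, atol=1e-4, rtol=1e-3)


maybe_assert_close(torch.zeros(4), torch.zeros(4))  # trivially passes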
Update the gemm addmatrix benchmark to support int8 inputs as well as bfloat16.
The int8 benchmark is pretty slow, not because Triton performance is bad (it is at least on par with bfloat16) but because PyTorch does not support int8 matmul on GPU, so we have to do the reference matmul on the CPU. This makes the benchmark something like 20x slower. To fix that, I changed the PyTorch accuracy check to only run for a few shapes instead of all of them; I tried to pick shapes that I thought were representative of different cases but am open to suggestions. Now the benchmark runs in a reasonable time.
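(For the shapes that are still checked, the reference has to be computed off the GPU. A hedged sketch of one way to build such a reference; this is my reading of the workaround, not the PR's exact code, and it uses fp64 accumulation on the CPU, which is exact for int8 inputs at these sizes.)

import torch


def int8_reference(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # torch.matmul does not support int8 on the GPU, so compute on the CPU.
    # fp64 accumulation is exact here: |int8 * int8| <= 2**14 and K <= 32768,
    # so every partial sum stays far below 2**53.
    out = torch.matmul(a.cpu().to(torch.float64), b.cpu().to(torch.float64))
    return out.to(torch.int32).to(a.device)


# Example with small random int8 operands.
a = torch.randint(-128, 127, (64, 32), dtype=torch.int8)
b = torch.randint(-128, 127, (32, 48), dtype=torch.int8)
ref = int8_reference(a, b)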
A few open items need to be addressed:
cc #3014