Add int8 to gemm w/ addmatrix #3040

Open

alexbaden wants to merge 3 commits into main from alex/block_adds_benchmark

Conversation

@alexbaden (Contributor) commented Dec 18, 2024

Update the gemm addmatrix benchmark to support int8 inputs as well as bfloat16.

The int8 benchmark is pretty slow - not because Triton performance is bad (it is at least on par with bfloat16), but because PyTorch does not support int8 matmul on GPU, so we have to do the reference matmul on the CPU. This makes the benchmark something like 20x slower. To fix that, I changed the PyTorch accuracy check to only run for a few shapes instead of all of them - I tried to pick shapes that I thought were representative of different cases, but am open to suggestions. Now the benchmark runs in a reasonable time.
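
For illustration, a minimal sketch of the kind of CPU-side reference check described above (the helper name cpu_int8_reference is made up for this sketch, not taken from the PR; the widening to int32 mirrors the int32 accumulator used for int8 in the diff below):

import torch

def cpu_int8_reference(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # int8 matmul is not supported on the GPU in PyTorch, so compute the
    # reference on the CPU, widening to int32 to match the int32 accumulator.
    return torch.matmul(a.cpu().to(torch.int32), b.cpu().to(torch.int32))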

A few open items need to be addressed:

  • for int8 we want a separate geomean (i.e. a geomean for bfloat16 and a geomean for int8). What's the best way to keep int8 and bfloat16 separate? I could introduce an environment variable and run the benchmark twice - once with only bfloat16, once with only int8. Open to other suggestions.
  • for the onednn comparison, bfloat16 is fine, but GPU matmul w/ int8 is not supported. I don't think we want to run the comparison vs CPU (it takes too long and gives us no info), so I might need to introduce the env variable anyway to do one run with bfloat16 w/ onednn and another run with int8 w/out onednn.

cc #3014

@Egor-Krivov (Contributor) commented Dec 18, 2024

I think we should treat them as 2 separate benchmarks in terms of reporting (so we have 2 report lines: --benchmark gemm-postop-addmatrix and --benchmark gemm-postop-addmatrix-int8). Then all geomeans will work as intended. Otherwise we'll have to introduce some geomean groups in our database.

Then we either run the benchmark script twice with different dtypes to generate 2 separate report files (I'd prefer this), or modify our report script to add some filtering capability.

About onednn and int8 support: do we want to measure onednn? If not, and we run it only for validation, maybe we could run it in another precision, like fp32 or bf16, just to validate the output.

@Egor-Krivov (Contributor) commented:

My only issue with this PR right now is that all charts and geomeans for the addmatrix benchmark will be discontinued due to the new parameters. Hence my suggestion to introduce a separate benchmark for int8.

@vlad-penkin linked an issue Dec 18, 2024 that may be closed by this pull request
add int8 to addmatrix benchmark 2/?

add int8 to addmatrix benchmark 3/3
@alexbaden force-pushed the alex/block_adds_benchmark branch from 164fd00 to 3cc0b6e on January 6, 2025 15:29
@alexbaden changed the title from "Add int8 to gemm w/ addmatrix and consider onednn provider" to "Add int8 to gemm w/ addmatrix" on Jan 6, 2025
@alexbaden (Contributor, Author) commented

I removed the onednn-related bits because for onednn we only measure kernel time, and the add step is not fused into the main gemm kernel, so the comparison would not be appropriate.

I modified the int8 code to run only as a separate benchmark step controllable by an environment variable, so the existing bfloat16 timing should not be affected. The default run mode is bfloat16, with an optional int8 mode or an all-dtypes mode for local runs.
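
A minimal sketch of this kind of environment-variable switch (the variable name GEMM_ADDMATRIX_DTYPE is an assumption for illustration; only the dtypes() helper name appears in the diff below):

import os
import torch

def dtypes():
    # Assumed switch: default runs bfloat16 only, "int8" runs int8 only,
    # and "all" runs both dtypes (intended for local runs).
    mode = os.getenv('GEMM_ADDMATRIX_DTYPE', 'bfloat16')
    if mode == 'all':
        return [torch.bfloat16, torch.int8]
    return [torch.int8] if mode == 'int8' else [torch.bfloat16]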

@alexbaden marked this pull request as ready for review on January 6, 2025 15:30
@@ -117,7 +130,7 @@ def matmul_kernel_with_block_pointers_batched(
stride_az: tl.constexpr, stride_am: tl.constexpr, stride_ak: tl.constexpr, #
stride_bz: tl.constexpr, stride_bk: tl.constexpr, stride_bn: tl.constexpr, #
stride_cz: tl.constexpr, stride_cm: tl.constexpr, stride_cn: tl.constexpr, #
stride_dz: tl.constexpr, stride_dm: tl.constexpr, stride_dn: tl.constexpr,
stride_dz: tl.constexpr, stride_dm: tl.constexpr, stride_dn: tl.constexpr, ACCUMULATOR_DTYPE: tl.constexpr,
Contributor commented:

[nit]

Suggested change
stride_dz: tl.constexpr, stride_dm: tl.constexpr, stride_dn: tl.constexpr, ACCUMULATOR_DTYPE: tl.constexpr,
stride_dz: tl.constexpr, stride_dm: tl.constexpr, stride_dn: tl.constexpr, #
ACCUMULATOR_DTYPE: tl.constexpr,
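
For context, a sketch of how a constexpr accumulator dtype can be used inside a Triton kernel (illustrative only; the kernel and block-size names here are not taken from this PR). The caller would pass e.g. ACCUMULATOR_DTYPE=tl.float32 for bfloat16 inputs or tl.int32 for int8 inputs:

import triton
import triton.language as tl

@triton.jit
def block_dot(a_ptr, b_ptr, c_ptr,
              BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
              ACCUMULATOR_DTYPE: tl.constexpr):
    # A constexpr accumulator dtype lets the same kernel accumulate in float32
    # for bfloat16 inputs and in int32 for int8 inputs.
    offs_m = tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)
    a = tl.load(a_ptr + offs_m[:, None] * BLOCK_K + offs_k[None, :])
    b = tl.load(b_ptr + offs_k[:, None] * BLOCK_N + offs_n[None, :])
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACCUMULATOR_DTYPE)
    acc += tl.dot(a, b)
    tl.store(c_ptr + offs_m[:, None] * BLOCK_N + offs_n[None, :], acc)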

],
x_vals=[[1, 1024 * i, 1024 * i, 1024 * i, dtype] for i in [1, 2, 4, 8] for dtype in dtypes()] + #
[[*shape, dtype]
for shape in [[1, 1, 5120, 13824], #
Contributor commented:

[nit] easier to comment out shapes when debugging

Suggested change
for shape in [[1, 1, 5120, 13824], #
for shape in [ #
[1, 1, 5120, 13824], #

[4, 32768, 4096, 128], #
[32, 4096, 4096, 128], #
[4096, 8, 128, 16384], #
[4096, 8, 16384, 128]]
Contributor commented:

[nit] easier to comment out shapes when debugging

Suggested change
[4096, 8, 16384, 128]]
[4096, 8, 16384, 128] #
]

@@ -247,29 +262,42 @@ def matmul(a, b, d, c):
# name for the plot. Used also as a file name for saving the plot.
args={},
))
def benchmark(B, M, N, K, provider):
def benchmark(B, M, N, K, dtype, provider):
res_dtype = torch.float32 if dtype is torch.bfloat16 else torch.int32
Contributor commented:

to be consistent with the code above?

Suggested change
res_dtype = torch.float32 if dtype is torch.bfloat16 else torch.int32
res_dtype = torch.float32 if a.dtype.is_floating_point else torch.int32

benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')
if dtype.is_floating_point or [B, M, N, K] in [[1, 1024, 1024, 1024], [1, 2048, 2048, 2048],
[1, 512, 8192, 32768], [4, 32768, 4096, 128]]:
# torch int8 matmul on GPU is not supported. only check a few int8 shapes to reduce runtime
Contributor commented:

Should we have an env var for all benchmarks to control whether we verify the result?
I don't think we should skip checking correctness for some shapes.
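
A minimal sketch of the env-var gate proposed here (the variable name VERIFY_RESULTS is illustrative, not something this PR defines; benchmark_suit.assert_close, triton_fn, torch_fn, and rtol are the names already used in the benchmark code above):

import os

# Assumed opt-out switch: verification stays on by default.
VERIFY_RESULTS = os.getenv('VERIFY_RESULTS', '1') == '1'

if VERIFY_RESULTS:
    benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=1e-4, rtol=rtol, err_msg='triton to torch')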

Development

Successfully merging this pull request may close these issues.

Add onednn to gemm benchmarks
3 participants