
Add support for unpadded shapes in Matmul1D w/ gather_in0 #16627

Merged: 8 commits into main from avora/mm_pad, Jan 24, 2025

Conversation

avoraTT
Contributor

@avoraTT avoraTT commented Jan 10, 2025

Ticket

#16626

Problem description

In the current use case of Matmul1D with gather_in0 in the Llama models, the activations and weights need to be padded. This results in significant overhead.

What's changed

  • Added support to skip the part of in0_block_w that is padding
  • Pad Kt and Nt in the host code for gather_in0 (see the sketch below)
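A minimal sketch of the host-side padding, assuming hypothetical names (pad_gather_in0_dims, round_up, and PaddedDims are illustrative, not the actual tt-metal implementation):

#include <cstdint>

// Round x up to the next multiple of m.
static uint32_t round_up(uint32_t x, uint32_t m) { return ((x + m - 1) / m) * m; }

struct PaddedDims {
    uint32_t Kt_pad;       // padded inner dim, in tiles
    uint32_t Nt_pad;       // padded output width, in tiles
    uint32_t in0_block_w;  // per-core block width = Kt_pad / num_cores
};

// Pad the tile counts Kt and Nt so every core gets an equal-width shard;
// the compute kernel can then skip the padding tiles in the last block.
static PaddedDims pad_gather_in0_dims(uint32_t Kt, uint32_t Nt, uint32_t num_cores) {
    PaddedDims d{};
    d.Kt_pad = round_up(Kt, num_cores);
    d.Nt_pad = round_up(Nt, num_cores);
    d.in0_block_w = d.Kt_pad / num_cores;
    return d;
}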

Checklist

  • Post commit CI passes (https://github.com/tenstorrent/tt-metal/actions/runs/12893880800)
  • New/Existing tests provide coverage for changes (https://github.com/tenstorrent/tt-metal/actions/runs/12893883783)

@avoraTT avoraTT added the metal tt-metal issue label Jan 10, 2025
@avoraTT avoraTT self-assigned this Jan 10, 2025
@avoraTT avoraTT marked this pull request as ready for review January 10, 2025 21:38
@avoraTT avoraTT force-pushed the avora/mm_pad branch 3 times, most recently from 414e20d to a7aa7f3 Compare January 13, 2025 13:12
@avoraTT avoraTT requested a review from yugaoTT January 13, 2025 14:01
@avoraTT avoraTT force-pushed the avora/mm_pad branch 2 times, most recently from bc32be0 to 1d004e5 Compare January 21, 2025 14:34
Contributor

@yugaoTT yugaoTT left a comment

overall looks good!

@avoraTT avoraTT force-pushed the avora/mm_pad branch 2 times, most recently from 72e3729 to 52628da Compare January 21, 2025 18:55
Comment on lines +1749 to +1753
/* Inner dim padding */
// Shard width in elements divided by tile width gives tiles per core;
// multiplying by num_cores yields the padded inner dim Kt_pad.
const uint32_t Kt_pad = in0_buffer->shard_spec().shape()[1] / in0_tile.get_tile_shape()[1] * num_cores;
// Override the passed-in in0_block_w with the padded per-core block width.
in0_block_w = Kt_pad / num_cores;

// Number of block iterations for the compute kernel.
uint32_t num_blocks = Kt_pad / in0_block_w;
Contributor

Should just use the passed-in in0_block_w and derive num_blocks from the shard spec

Contributor Author

in0_block_w also needs to be re-derived from the padded value

Copy link
Contributor

This code is just setting in0_block_w to the shard width. Isn't that what you are passing in?

const uint32_t Kt_pad = in0_buffer->shard_spec().shape()[1] / in0_tile.get_tile_shape()[1] * num_cores;
in0_block_w = Kt_pad / num_cores;

num_blocks is basically unused in this function

Copy link
Contributor Author

in0_block_w has to match the unpadded K value, or else the other matmul validation fails; that's why it needs to be updated to the padded value here.

num_blocks is passed into the compute kernel to determine the number of block iterations.
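As a hedged numeric illustration of why the padded value matters (the numbers are hypothetical):

// Hypothetical numbers, for illustration only.
uint32_t Kt = 10, num_cores = 4;
uint32_t Kt_pad = (Kt + num_cores - 1) / num_cores * num_cores;  // = 12
uint32_t in0_block_w = Kt_pad / num_cores;                       // = 3 (padded per-core width)
uint32_t num_blocks = Kt_pad / in0_block_w;                      // = 4 compute-kernel block iterations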

@avoraTT avoraTT force-pushed the avora/mm_pad branch 2 times, most recently from f453ac9 to 3f13274 Compare January 23, 2025 15:10
@avoraTT avoraTT requested a review from TT-BrianLiu January 23, 2025 15:36
@avoraTT avoraTT merged commit c8b0fa8 into main Jan 24, 2025
190 checks passed
@avoraTT avoraTT deleted the avora/mm_pad branch January 24, 2025 15:37
patrickroberts pushed a commit that referenced this pull request Jan 25, 2025