Commit 742695e
Co-authored-by: Ehud Sharlin <[email protected]>
1 parent 5e17aa6 · 1 changed file with 6 additions and 1 deletion.
blogs/artificial-intelligence/mamba2-flash-attention-kernels/README.md

@@ -90,7 +90,12 @@ We baked in the following AMD-specific optimizations when writing the Mamba-2 ba
 Similar to FA2, we achieve speedups on the Mamba2 backward kernel of 4%, 5%, and 6% for sequence lengths of 2k, 4k, and 8k, respectively, on MI300X compared to the H100. Cache thrashing, data movement cost, and SM utilization are all significantly improved. With the Mamba2 forward and backward kernels and the Flash Attention 2 backward kernel in hand, pure-SSM and hybrid attention/SSM models are trainable on MI300X hardware, and can achieve higher FLOPs per dollar than is possible on NVIDIA H100 systems.
-## Future Work
+## Summary
+In this blog post we outlined Zyphra's vision of training transformers and hybrid models at a lower cost. We explained how Zyphra is realizing this vision by exploiting the superior hardware specifications of the AMD Instinct MI300X accelerators, using ROCm to optimize the Mamba2 and Flash Attention v2 kernels that power Zyphra's hybrid models.
+As future work, Zyphra plans to extend the attention kernel and portions of the Mamba2 kernel to fp8 precision, and to enable fine-grained overlap of tensor-parallel communication with computation within the Mamba2, Attention, and MLP blocks. Both optimizations are critical to Zyphra's training pipeline.
Check failure (GitHub Actions / Documentation / Markdown) on line 99 in blogs/artificial-intelligence/mamba2-flash-attention-kernels/README.md: Multiple consecutive blank lines.
 As future work, we plan to extend the attention kernel and portions of the Mamba2 kernel to fp8 precision, and to enable fine-grained overlap of tensor-parallel communication with computation within the Mamba2, Attention, and MLP blocks. Both optimizations are critical to Zyphra's training pipeline.
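
For readers who want to reproduce the kind of kernel-level comparison quoted in the diff (4-6% backward-pass speedups at 2k/4k/8k sequence lengths), the sketch below shows one common way to time a backward pass in PyTorch with device events; the same code runs on both ROCm (MI300X) and CUDA (H100) builds, since ROCm is exposed through the `torch.cuda` namespace. `Mamba2Proxy` is a hypothetical stand-in module for illustration only; the actual Mamba2 kernel is not part of this commit.

```python
# Minimal sketch: timing forward + backward with torch.cuda events.
# Runs unchanged on ROCm (MI300X) and CUDA (H100) PyTorch builds.
# Mamba2Proxy is a hypothetical stand-in, NOT Zyphra's Mamba2 kernel.
import torch
import torch.nn as nn


class Mamba2Proxy(nn.Module):
    """Placeholder sequence-mixing block used only to exercise autograd."""

    def __init__(self, dim: int):
        super().__init__()
        self.in_proj = nn.Linear(dim, dim)
        self.out_proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.out_proj(torch.nn.functional.silu(self.in_proj(x)))


def time_fwd_bwd(seq_len: int, dim: int = 2048, iters: int = 20) -> float:
    device = "cuda"  # also the device string on ROCm builds
    model = Mamba2Proxy(dim).to(device, dtype=torch.bfloat16)
    x = torch.randn(1, seq_len, dim, device=device, dtype=torch.bfloat16,
                    requires_grad=True)
    # Warm-up so one-time kernel compilation/caching is excluded from timings.
    for _ in range(3):
        model(x).sum().backward()
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        model(x).sum().backward()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # milliseconds per iteration


if __name__ == "__main__":
    for seq_len in (2048, 4096, 8192):  # the 2k/4k/8k settings from the post
        print(f"seq_len={seq_len}: {time_fwd_bwd(seq_len):.2f} ms (fwd+bwd)")
```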
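The fp8 item in the future-work paragraph refers to the 8-bit floating-point formats (e4m3/e5m2) supported in hardware by both MI300X and H100. As a minimal illustration of the number format only, and not of the planned kernel changes (which the commit does not describe), the snippet below performs per-tensor e4m3 quantization and dequantization with stock PyTorch (2.1 or newer).

```python
# Minimal sketch of per-tensor fp8 (e4m3) quantization in PyTorch >= 2.1.
# Illustrates the number format only; NOT the fused fp8 kernels planned
# as future work in the post.
import torch


def quantize_e4m3(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Scale x into the representable e4m3 range and cast to fp8."""
    fp8_max = torch.finfo(torch.float8_e4m3fn).max  # 448.0
    scale = x.abs().amax().clamp(min=1e-12) / fp8_max
    return (x / scale).to(torch.float8_e4m3fn), scale


def dequantize(x_fp8: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """Cast back to fp32 and undo the per-tensor scale."""
    return x_fp8.to(torch.float32) * scale


x = torch.randn(4096, 4096)
x_fp8, scale = quantize_e4m3(x)
err = (dequantize(x_fp8, scale) - x).abs().mean()
print(f"storage: {x_fp8.element_size()} byte/elem, "
      f"mean abs error: {err.item():.2e}")
```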