Make float32_qk_product and float32_logits apply during inference #1225
base: main
Conversation
Looks great! I would like to keep the model yml files to core properties of the model - we follow this pattern at least for the TPU yml files.
MaxText/configs/base.yml
Outdated
@@ -192,6 +192,10 @@ final_logits_soft_cap: 0.0
 use_post_attn_norm: False
 use_post_ffw_norm: False

+# In dot_product attention, whether to upcast the qk product and attention logits to fp32
nit: Can we move these next to similar options in line 123?
Will do
MaxText/configs/models/gemma-2b.yml
Outdated
I'd prefer that the quantization settings are not part of the /models - the precision isn't a property of the model; the user can still run the gemma model with different precision settings.
Makes sense, I'll remove these from the model configs
@@ -477,7 +477,7 @@ def apply_attention_dot(
 """Apply Attention."""
 validate_compute_axis_order(self.compute_axis_order)
 # Casting qk_product and softmaxt computation for float32 for model stability.
-if model_mode == common_types.MODEL_MODE_TRAIN and self.float32_qk_product:
+if self.float32_qk_product:
I think you have to set precision as well for float32 to actually take effect (https://jax.readthedocs.io/en/latest/jax.lax.html#jax.lax.Precision) - we have this option in maxtext:
maxtext/MaxText/configs/base.yml
Line 79 in d33821f
matmul_precision: "default"
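For illustration, a minimal sketch of that point (plain jnp calls, not MaxText code): with default precision, fp32 inputs alone do not necessarily force full-precision matmul passes on TPU; the precision argument is what requests them.

```python
import jax
import jax.numpy as jnp

a = jnp.ones((256, 256), dtype=jnp.float32)
b = jnp.ones((256, 256), dtype=jnp.float32)

# Default precision: on TPU the multiplies may still run as bf16 passes.
out_default = jnp.matmul(a, b)

# HIGHEST precision: requests the full-precision matmul algorithm, at extra cost.
out_fp32 = jnp.matmul(a, b, precision=jax.lax.Precision.HIGHEST)
```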
That's a good point which I hadn't considered. I suppose there are two moving parts here: whether the flops are performed in bf16 or fp32, and whether the results are accumulated in bf16 or fp32. I think these are generally controlled with the precision and preferred_element_type arguments, respectively.
It appears that with default precision the flops happen in bf16 even when one or more of the inputs are in fp32. However, the accumulation can still happen in fp32, and that seems to have been enough to solve our particular problem. In particular, the compiler seems to recognize that even though the Python says to upcast to fp32, it can elide that cast because it is going to do the computation in bf16 anyway; however, it still outputs fp32.
[Compiler graph screenshots: the qk product with float32_qk_product=False, and the same product with float32_qk_product=True; note the output type in the second case is now f32.]
I'm not 100% confident in my interpretation of those graphs, but this would explain why it takes longer even without changing the precision parameter.
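As a toy illustration of the two knobs described above (plain jnp, not the MaxText kernels): precision controls whether the multiplies run in bf16 or a higher-precision algorithm, while preferred_element_type controls the dtype used for accumulation and output.

```python
import jax
import jax.numpy as jnp

q = jnp.ones((2, 16, 8, 64), dtype=jnp.bfloat16)
k = jnp.ones((2, 16, 8, 64), dtype=jnp.bfloat16)

# bf16 multiplies, but accumulate and emit fp32: roughly what the upcast appears
# to buy once the compiler elides the explicit casts.
logits_acc32 = jnp.einsum(
    "bqhd,bkhd->bhqk", q, k, preferred_element_type=jnp.float32)

# Additionally request the highest-precision multiply algorithm, which costs more.
logits_full32 = jnp.einsum(
    "bqhd,bkhd->bhqk", q, k,
    precision=jax.lax.Precision.HIGHEST,
    preferred_element_type=jnp.float32)
```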
Separately, it looks like matmul_precision consistently gets routed into DenseGeneral usages, but not into the raw einsums used in qk_product and wv_product. When I change matmul_precision in the config it does not affect the runtime of those operations, but if I add it explicitly to the einsums then the wv_product does take longer, which makes sense. Is that something we should fix just by adding those arguments to the einsums? Something like the following (the einsum subscripts and argument names are illustrative, not the actual MaxText signatures):
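```python
import jax
import jax.numpy as jnp

def qk_product(query, key, precision=jax.lax.Precision.DEFAULT):
  # query: [b, q, heads, d], key: [b, k, heads, d] -> logits: [b, heads, q, k],
  # with the configured matmul precision threaded through instead of the default.
  return jnp.einsum("bqhd,bkhd->bhqk", query, key, precision=precision)

def wv_product(attn_weights, value, precision=jax.lax.Precision.DEFAULT):
  # attn_weights: [b, heads, q, k], value: [b, k, heads, d] -> [b, q, heads, d].
  return jnp.einsum("bhqk,bkhd->bqhd", attn_weights, value, precision=precision)
```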
Description
Some models are sensitive to the precision of the QK product and the attention logits. Right now, gemma and gemma2 upcast these to fp32 while other models do not.
When we implemented qwen2-72b, we observed degraded performance on downstream evals (~40% success vs 70% success on a particular eval) compared to another implementation. We only observed this with dot_product attention.
We traced this to the softmax following the qk product, where bf16 precision was not sufficient. We tried to duplicate gemma's configuration to upcast these to fp32, but this is disabled for inference.
This PR removes the condition that upcasting only applies for training, and it lifts those flags into model configuration. The latter is optional, but we've found it more convenient to make the decision at that level. This would also allow you to supply different values at training and inference time, replicating the previous behavior.
Notably, these flags only apply to dot_product attention, which is used during decoding but usually avoided during training. If I'm reading this right, the flash attention implementation forces them to fp32 (here and, implicitly, here). Arguably that suggests the default for these flags should be True. However, in very basic profiling, these flags increased attention costs by approximately 50%, so I'm not sure that's justified.
Tests
These changes fixed our downstream issues, as described above.
Checklist
Before submitting this PR, please make sure (put X in square brackets):