
Make float32_qk_product and float32_logits apply during inference #1225

Open · wants to merge 3 commits into main

Conversation

philip-essential
Contributor

Description

Some models are sensitive to the precision of the QK product and the attention logits. Right now, gemma and gemma2 upcast these to fp32 while other models do not.

When we implemented qwen2-72b, we observed degraded performance on downstream evals (~40% success vs 70% success on a particular eval) compared to another implementation. We only observed this with dot_product attention.

We traced this to the softmax following the qk product, where bf16 precision was not sufficient. We tried to duplicate gemma's configuration to upcast these to fp32, but that upcasting is disabled during inference.

This PR removes the condition that upcasting only applies for training, and it lifts those flags into model configuration. The latter is optional, but we've found it more convenient to make the decision at that level. This would also allow you to supply different values at training and inference time, replicating the previous behavior.

Notably, these flags only apply to dot_product attention, which is used during decoding but usually avoided during training. If I'm reading this right, the flash attention implementation forces them to fp32 (here and, implicitly, here). Arguably that suggests the default for these flags should be True. However, in very basic profiling, these flags increased attention costs by approximately 50%, so I'm not sure that's justified.
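For context, here is a minimal sketch of what these two flags control in dot_product attention. This is illustrative only, not MaxText's actual apply_attention_dot; the tensor layout, einsum specs, and function name are assumptions, and scaling/masking are omitted.

```python
import jax.numpy as jnp
from jax import nn


def dot_product_attention_sketch(query, key, value,
                                 float32_qk_product=True,
                                 float32_logits=True):
  """Illustrative sketch; assumes a [batch, seq, heads, head_dim] layout."""
  if float32_qk_product:
    # Upcast the inputs so the QK contraction produces fp32 scores.
    query = query.astype(jnp.float32)
    key = key.astype(jnp.float32)
  logits = jnp.einsum("bqhd,bkhd->bhqk", query, key)
  if float32_logits:
    # Keep the softmax in fp32; bf16 here was not precise enough in our case.
    logits = logits.astype(jnp.float32)
  probs = nn.softmax(logits, axis=-1).astype(value.dtype)
  return jnp.einsum("bhqk,bkhd->bqhd", probs, value)
```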

Tests

These changes fixed our downstream issues, as described above.

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed.

Collaborator

@gobbleturk gobbleturk left a comment

Looks great! I would like to keep the model yml files limited to core properties of the model - we follow this pattern at least for the TPU yml files.

@@ -192,6 +192,10 @@ final_logits_soft_cap: 0.0
 use_post_attn_norm: False
 use_post_ffw_norm: False

+# In dot_product attention, whether to upcast the qk product and attention logits to fp32
Collaborator

nit: Can we move these next to similar options in line 123?

Contributor Author

Will do

Collaborator

I'd prefer that the quantization settings are not part of the /models - the precision isn't a property of the model; the user can still run the gemma model with different precision settings.

Contributor Author

Makes sense, I'll remove these from the model configs

@@ -477,7 +477,7 @@ def apply_attention_dot(
     """Apply Attention."""
     validate_compute_axis_order(self.compute_axis_order)
     # Casting qk_product and softmaxt computation for float32 for model stability.
-    if model_mode == common_types.MODEL_MODE_TRAIN and self.float32_qk_product:
+    if self.float32_qk_product:
Collaborator

I think you have to set precision as well for float32 to actually take effect (https://jax.readthedocs.io/en/latest/jax.lax.html#jax.lax.Precision) - we have this option in maxtext:

matmul_precision: "default"
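As a concrete (purely illustrative, shapes made up) example of this knob, and of how it differs from simply upcasting the inputs:

```python
import jax
import jax.numpy as jnp

q = jnp.ones((1, 128, 8, 64), dtype=jnp.bfloat16)
k = jnp.ones((1, 128, 8, 64), dtype=jnp.bfloat16)

# Default precision: on TPU the multiplies may still run in reduced precision,
# even if the inputs are upcast to fp32 beforehand.
scores_default = jnp.einsum("bqhd,bkhd->bhqk", q, k)

# Precision.HIGHEST requests full-precision flops; preferred_element_type
# controls the accumulation/output dtype.
scores_fp32 = jnp.einsum(
    "bqhd,bkhd->bhqk", q, k,
    precision=jax.lax.Precision.HIGHEST,
    preferred_element_type=jnp.float32,
)
```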

Contributor Author

That's a good point which I hadn't considered. I suppose there are two moving parts here: are the flops performed in bf16 or fp32, and are the results accumulated in bf16 or fp32. I think these are generally controlled with the precision and preferred_element_type arguments, respectively.

It appears that on default precision the flops happen in bf16 even when one or more of the inputs are in fp32. However, the accumulation can still happen in fp32, and that seems to have been enough to solve our particular problem. In particular, the compiler seems to recognize that even though the Python code says to upcast to fp32, it can elide that cast because it is going to do the computation in bf16 anyway; it still produces an fp32 output, though.
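One way to sanity-check this, sketched here with made-up shapes (not part of this PR), is to dump the compiled HLO and look at the dot's result type:

```python
import jax
import jax.numpy as jnp


def qk(q, k, upcast):
  if upcast:
    q = q.astype(jnp.float32)
    k = k.astype(jnp.float32)
  return jnp.einsum("bqhd,bkhd->bhqk", q, k)


q = jnp.zeros((1, 16, 2, 8), dtype=jnp.bfloat16)
k = jnp.zeros((1, 16, 2, 8), dtype=jnp.bfloat16)

# With upcast=True the dot's output type in the optimized HLO should be f32,
# even if the compiler keeps the multiplies in bf16 under default precision.
print(jax.jit(qk, static_argnums=2).lower(q, k, True).compile().as_text())
```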

This is the qk product with float32_qk_product=False:
[Screenshot: compiled qk_product graph with float32_qk_product=False]

And this is with float32_qk_product=True (note the output type is now f32):
[Screenshot: compiled qk_product graph with float32_qk_product=True]

I'm not 100% confident in my interpretation of those graphs, but this would explain why it takes longer even without changing the precision parameter.

Separately, it looks like matmul_precision consistently gets routed into DenseGeneral usages, but not into the raw einsums used in qk_product and wv_product. When I change matmul_precision in the config it does not affect the runtime of those operations, but if I add it explicitly to the einsums then the wv_product does take longer, which makes sense. Is that something we should fix just by adding those arguments to the einsums?
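In case it helps the discussion, here is a hypothetical sketch of that fix; the helper names and plumbing are assumptions, not the existing MaxText code:

```python
import jax
import jax.numpy as jnp

# Hypothetical: map the matmul_precision config string to a lax.Precision value
# and thread it into the raw attention einsums, mirroring what DenseGeneral gets.
_PRECISION = {
    "default": jax.lax.Precision.DEFAULT,
    "high": jax.lax.Precision.HIGH,
    "highest": jax.lax.Precision.HIGHEST,
}


def qk_product(query, key, matmul_precision="default"):
  return jnp.einsum("bqhd,bkhd->bhqk", query, key,
                    precision=_PRECISION[matmul_precision])


def wv_product(probs, value, matmul_precision="default"):
  return jnp.einsum("bhqk,bkhd->bqhd", probs, value,
                    precision=_PRECISION[matmul_precision])
```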
