diff --git a/csrc/layer_norm/README.md b/csrc/layer_norm/README.md
index 934043e08..5914c9964 100644
--- a/csrc/layer_norm/README.md
+++ b/csrc/layer_norm/README.md
@@ -14,3 +14,7 @@ This extension has only been tested on A100s.
 ```sh
 cd csrc/layer_norm && pip install .
 ```
+
+As of 2024-01-05, this extension is no longer used in the FlashAttention repo.
+We've instead switched to a Triton-based
+[implementation](https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/triton/layer_norm.py).
diff --git a/flash_attn/models/bert.py b/flash_attn/models/bert.py
index 4aaafdf01..33d693520 100644
--- a/flash_attn/models/bert.py
+++ b/flash_attn/models/bert.py
@@ -40,9 +40,10 @@
     FusedDense = None
 
 try:
-    from flash_attn.ops.layer_norm import dropout_add_layer_norm, layer_norm
+    from flash_attn.ops.triton.layer_norm import layer_norm_fn
 except ImportError:
-    dropout_add_layer_norm, layer_norm = None, None
+    layer_norm_fn = None
+
 
 try:
     from flash_attn.losses.cross_entropy import CrossEntropyLoss
@@ -237,8 +238,8 @@ def __init__(self, config):
         if fused_bias_fc and FusedDense is None:
             raise ImportError("fused_dense is not installed")
         self.fused_dropout_add_ln = getattr(config, "fused_dropout_add_ln", False)
-        if self.fused_dropout_add_ln and layer_norm is None:
-            raise ImportError("dropout_add_layer_norm is not installed")
+        if self.fused_dropout_add_ln and layer_norm_fn is None:
+            raise ImportError("Triton is not installed")
         linear_cls = nn.Linear if not fused_bias_fc else FusedDense
         self.dense = linear_cls(config.hidden_size, config.hidden_size)
         approximate = (
@@ -255,8 +256,8 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         if not self.fused_dropout_add_ln:
             hidden_states = self.layer_norm(hidden_states)
         else:
-            hidden_states = layer_norm(
-                hidden_states, self.layer_norm.weight, self.layer_norm.bias, self.layer_norm.eps
+            hidden_states = layer_norm_fn(
+                hidden_states, self.layer_norm.weight, self.layer_norm.bias, eps=self.layer_norm.eps
             )
         return hidden_states
 
@@ -345,8 +346,8 @@ def __init__(self, config: BertConfig, add_pooling_layer=True):
                 config.vocab_size % self.pad_vocab_size_multiple
             )
         self.fused_dropout_add_ln = getattr(config, "fused_dropout_add_ln", False)
-        if self.fused_dropout_add_ln and layer_norm is None:
-            raise ImportError("dropout_add_layer_norm is not installed")
+        if self.fused_dropout_add_ln and layer_norm_fn is None:
+            raise ImportError("Triton is not installed")
         assert config.hidden_act in ["gelu", "gelu_new", "gelu_fast", "gelu_pytorch_tanh"]
 
         self.embeddings = BertEmbeddings(
@@ -384,8 +385,8 @@ def forward(
         if not self.fused_dropout_add_ln:
             hidden_states = self.emb_ln(hidden_states)
         else:
-            hidden_states = layer_norm(
-                hidden_states, self.emb_ln.weight, self.emb_ln.bias, self.emb_ln.eps
+            hidden_states = layer_norm_fn(
+                hidden_states, self.emb_ln.weight, self.emb_ln.bias, eps=self.emb_ln.eps
             )
         hidden_states = self.emb_drop(hidden_states)
 
diff --git a/flash_attn/models/gpt.py b/flash_attn/models/gpt.py
index 135213f53..6d4b6b188 100644
--- a/flash_attn/models/gpt.py
+++ b/flash_attn/models/gpt.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2023, Tri Dao.
+# Copyright (c) 2024, Tri Dao.
 
 import logging
 import math
@@ -47,29 +47,14 @@
     ColumnParallelLinear = None
 
 try:
-    from flash_attn.ops.layer_norm import dropout_add_layer_norm
-except ImportError:
-    dropout_add_layer_norm = None
-
-try:
-    from flash_attn.ops.layer_norm import dropout_add_layer_norm_parallel_residual
-except ImportError:
-    dropout_add_layer_norm_parallel_residual = None
-
-try:
-    from flash_attn.ops.rms_norm import RMSNorm, dropout_add_rms_norm
-except ImportError:
-    RMSNorm, dropout_add_rms_norm = None, None
-
-try:
-    from flash_attn.ops.rms_norm import dropout_add_rms_norm_parallel_residual
+    from flash_attn.ops.triton.mlp import FusedDenseSqreluDense
 except ImportError:
-    dropout_add_rms_norm_parallel_residual = None
+    FusedDenseSqreluDense = None
 
 try:
-    from flash_attn.ops.triton.mlp import FusedDenseSqreluDense
+    from flash_attn.ops.triton.layer_norm import layer_norm_fn, RMSNorm
 except ImportError:
-    FusedDenseSqreluDense = None
+    layer_norm_fn, RMSNorm = None, None
 
 logger = logging.getLogger(__name__)
 
@@ -481,13 +466,15 @@ def __init__(self, config: GPT2Config, process_group=None, device=None, dtype=No
                 for i in range(config.num_hidden_layers)
             ]
         )
+        rotary_emb_fraction = getattr(config, "rotary_emb_fraction", 0.0)
+        if rotary_emb_fraction > 0.0:  # Tie all the RotaryEmbedding modules to share the same cos/sin cache
+            for layer in self.layers[1:]:
+                layer.mixer.rotary_emb = self.layers[0].mixer.rotary_emb
 
         self.fused_dropout_add_ln = getattr(config, "fused_dropout_add_ln", False)
         if self.fused_dropout_add_ln:
-            if (not self.parallel_block and dropout_add_layer_norm is None) or (
-                self.parallel_block and dropout_add_layer_norm_parallel_residual is None
-            ):
-                raise ImportError("dropout_layer_norm is not installed")
+            if layer_norm_fn is None:
+                raise ImportError("Triton is not installed")
         if self.prenorm:
             self.drop_f = nn.Dropout(config.resid_pdrop)
             norm_cls = nn.LayerNorm if not use_rms_norm else RMSNorm
@@ -571,41 +558,17 @@ def forward(self, input_ids, position_ids=None, inference_params=None):
                 hidden_states = self.ln_f(residual.to(dtype=self.ln_f.weight.dtype))
             else:
                 # Set prenorm=False here since we don't need the residual
-                if not self.parallel_block:
-                    fused_add_norm_fn = (
-                        dropout_add_rms_norm
-                        if isinstance(self.ln_f, RMSNorm)
-                        else dropout_add_layer_norm
-                    )
-                    hidden_states = fused_add_norm_fn(
-                        hidden_states,
-                        residual,
-                        self.ln_f.weight,
-                        self.ln_f.bias,
-                        self.drop_f.p if self.training else 0.0,
-                        self.ln_f.eps,
-                        prenorm=False,
-                        residual_in_fp32=self.residual_in_fp32,
-                    )
-                else:
-                    fused_add_norm_fn = (
-                        dropout_add_rms_norm_parallel_residual
-                        if isinstance(self.ln_f, RMSNorm)
-                        else dropout_add_layer_norm_parallel_residual
-                    )
-                    hidden_states, _ = fused_add_norm_fn(
-                        hidden_states,
-                        hidden_states2,
-                        residual,
-                        self.ln_f.weight,
-                        self.ln_f.bias,
-                        None,
-                        None,
-                        self.drop_f.p if self.training else 0.0,
-                        self.ln_f.eps,
-                        prenorm=False,
-                        residual_in_fp32=self.residual_in_fp32,
-                    )
+                hidden_states = layer_norm_fn(
+                    hidden_states,
+                    self.ln_f.weight,
+                    self.ln_f.bias,
+                    residual=residual,
+                    x1=None if not self.parallel_block else hidden_states2,
+                    eps=self.ln_f.eps,
+                    dropout_p=self.drop_f.p if self.training else 0.0,
+                    prenorm=False,
+                    is_rms_norm=isinstance(self.ln_f, RMSNorm)
+                )
         return hidden_states
 
 
diff --git a/flash_attn/models/vit.py b/flash_attn/models/vit.py
index d1267ba7d..4602fd741 100644
--- a/flash_attn/models/vit.py
+++ b/flash_attn/models/vit.py
@@ -20,9 +20,9 @@
 from flash_attn.modules.mlp import FusedMLP, Mlp
 
 try:
-    from flash_attn.ops.layer_norm import dropout_add_layer_norm
+    from flash_attn.ops.triton.layer_norm import layer_norm_fn
 except ImportError:
-    dropout_add_layer_norm = None
+    layer_norm_fn = None
 
 
 def create_mixer_cls(
@@ -229,8 +229,8 @@ def __init__(
         self.norm = norm_layer(embed_dim)
 
         self.fused_dropout_add_ln = fused_dropout_add_ln
-        if self.fused_dropout_add_ln and dropout_add_layer_norm is None:
-            raise ImportError("dropout_add_layer_norm is not installed")
+        if self.fused_dropout_add_ln and layer_norm_fn is None:
+            raise ImportError("Triton is not installed")
 
         # Classifier Head
         self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
@@ -302,16 +302,15 @@ def forward_features(self, x, all_tokens=True):
                     )
                 )
             # Set prenorm=False here since we don't need to the residual
-            hidden_states = dropout_add_layer_norm(
+            hidden_states = layer_norm_fn(
                 hidden_states,
-                residual,
                 self.norm.weight,
                 self.norm.bias,
-                self.dropout.p if self.training else 0.0,
-                self.norm.eps,
+                residual=residual,
+                eps=self.norm.eps,
+                dropout_p=self.dropout.p if self.training else 0.0,
                 rowscale=rowscale,
                 prenorm=False,
-                residual_in_fp32=True,
             )
         return hidden_states
 
diff --git a/flash_attn/modules/block.py b/flash_attn/modules/block.py
index c8907c7c6..be8e8b864 100644
--- a/flash_attn/modules/block.py
+++ b/flash_attn/modules/block.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2022, Tri Dao.
+# Copyright (c) 2024, Tri Dao.
 
 from functools import partial
 from typing import Optional
@@ -13,24 +13,9 @@
 from flash_attn.modules.mlp import Mlp
 
 try:
-    from flash_attn.ops.layer_norm import dropout_add_layer_norm
+    from flash_attn.ops.triton.layer_norm import layer_norm_fn, RMSNorm
 except ImportError:
-    dropout_add_layer_norm = None
-
-try:
-    from flash_attn.ops.layer_norm import dropout_add_layer_norm_parallel_residual
-except ImportError:
-    dropout_add_layer_norm_parallel_residual = None
-
-try:
-    from flash_attn.ops.rms_norm import RMSNorm, dropout_add_rms_norm
-except ImportError:
-    RMSNorm, dropout_add_rms_norm = None, None
-
-try:
-    from flash_attn.ops.rms_norm import dropout_add_rms_norm_parallel_residual
-except ImportError:
-    dropout_add_rms_norm_parallel_residual = None
+    layer_norm_fn, RMSNorm = None, None
 
 
 class Block(nn.Module):
@@ -91,8 +76,7 @@ def __init__(
             self.norm2 = norm_cls(dim)
 
         if self.fused_dropout_add_ln:
-            assert dropout_add_layer_norm is not None, "dropout_layer_norm is not installed"
-            assert dropout_add_rms_norm is not None, "dropout_layer_norm is not installed"
+            assert layer_norm_fn is not None, "Triton is not installed"
             assert isinstance(self.norm1, (nn.LayerNorm, RMSNorm)) and isinstance(
                 self.dropout1, nn.Dropout
             )
@@ -137,11 +121,6 @@ def forward(
                 before applying the query projection. Useful for e.g., ViT where we
                 only care about the CLS token in the last layer.
""" - fused_add_norm_fn = ( - dropout_add_rms_norm - if RMSNorm and isinstance(self.norm1, RMSNorm) - else dropout_add_layer_norm - ) if self.prenorm: if not self.fused_dropout_add_ln: dropped = self.drop_path1(self.dropout1(hidden_states)) @@ -160,16 +139,17 @@ def forward( dtype=hidden_states.dtype, ) ) - hidden_states, residual = fused_add_norm_fn( + hidden_states, residual = layer_norm_fn( hidden_states, - residual, self.norm1.weight, self.norm1.bias, - self.dropout1.p if self.training else 0.0, - self.norm1.eps, + residual=residual, + eps=self.norm1.eps, + dropout_p=self.dropout1.p if self.training else 0.0, rowscale=rowscale1, prenorm=True, residual_in_fp32=self.residual_in_fp32, + is_rms_norm=isinstance(self.norm1, RMSNorm) ) if mixer_kwargs is None: mixer_kwargs = {} @@ -196,16 +176,17 @@ def forward( dtype=hidden_states.dtype, ) ) - hidden_states, residual = fused_add_norm_fn( + hidden_states, residual = layer_norm_fn( hidden_states, - residual, self.norm2.weight, self.norm2.bias, - self.dropout2.p if self.training else 0.0, - self.norm2.eps, + residual=residual, + eps=self.norm2.eps, + dropout_p=self.dropout2.p if self.training else 0.0, rowscale=rowscale2, prenorm=True, residual_in_fp32=self.residual_in_fp32, + is_rms_norm=isinstance(self.norm2, RMSNorm) ) hidden_states = self.mlp(hidden_states) return hidden_states, residual @@ -231,15 +212,16 @@ def forward( mixer_out.shape[:-1], device=mixer_out.device, dtype=mixer_out.dtype ) ) - hidden_states = fused_add_norm_fn( + hidden_states = layer_norm_fn( mixer_out, - hidden_states, self.norm1.weight, self.norm1.bias, - self.dropout1.p if self.training else 0.0, - self.norm1.eps, + residual=hidden_states, + eps=self.norm1.eps, + dropout_p=self.dropout1.p if self.training else 0.0, rowscale=rowscale1, prenorm=False, + is_rms_norm=isinstance(self.norm1, RMSNorm) ) if not isinstance(self.mlp, nn.Identity): mlp_out = self.mlp(hidden_states) @@ -260,15 +242,16 @@ def forward( mlp_out.shape[:-1], device=mlp_out.device, dtype=mlp_out.dtype ) ) - hidden_states = fused_add_norm_fn( + hidden_states = layer_norm_fn( mlp_out, - hidden_states, self.norm2.weight, self.norm2.bias, - self.dropout2.p if self.training else 0.0, - self.norm2.eps, + residual=hidden_states, + eps=self.norm2.eps, + dropout_p=self.dropout2.p if self.training else 0.0, rowscale=rowscale2, prenorm=False, + is_rms_norm=isinstance(self.norm2, RMSNorm) ) return hidden_states @@ -320,12 +303,7 @@ def __init__( self.norm2 = norm_cls(dim) if self.fused_dropout_add_ln: - assert ( - dropout_add_layer_norm_parallel_residual is not None - ), "dropout_layer_norm is not installed" - assert ( - dropout_add_rms_norm_parallel_residual is not None - ), "dropout_layer_norm is not installed" + assert layer_norm_fn is not None, "Triton is not installed" assert isinstance(self.norm1, (nn.LayerNorm, RMSNorm)) and isinstance( self.dropout1, nn.Dropout ) @@ -370,11 +348,6 @@ def forward( """ # TODO: Ideally we should only do the allgather / allreduce once for # the Linear to MLP & Attention - fused_add_norm_fn = ( - dropout_add_rms_norm_parallel_residual - if isinstance(self.norm1, RMSNorm) - else dropout_add_layer_norm_parallel_residual - ) if not self.fused_dropout_add_ln: dropped1 = self.dropout1(hidden_states1) # For the very 1st block, we only want 1 dropout, not two different dropouts @@ -399,21 +372,24 @@ def forward( weight2, bias2 = ( (self.norm2.weight, self.norm2.bias) if not self.tied_norm else (None, None) ) - hidden_states1, hidden_states2, residual = fused_add_norm_fn( + 
hidden_states1, *rest, residual = layer_norm_fn( hidden_states1, - hidden_states2, - residual, self.norm1.weight, self.norm1.bias, - weight2, - bias2, - self.dropout1.p if self.training else 0.0, - self.norm1.eps, + residual=residual, + x1=hidden_states2, + weight1=weight2, + bias1=bias2, + eps=self.norm1.eps, + dropout_p=self.dropout1.p if self.training else 0.0, prenorm=True, residual_in_fp32=self.residual_in_fp32, + is_rms_norm=isinstance(self.norm1, RMSNorm) ) if self.tied_norm: hidden_states2 = hidden_states1 + else: + hidden_states2, = rest if mixer_kwargs is None: mixer_kwargs = {} hidden_states1 = self.mixer(hidden_states1, **mixer_kwargs) diff --git a/training/Dockerfile b/training/Dockerfile index 4de0d3125..fe7d12acd 100644 --- a/training/Dockerfile +++ b/training/Dockerfile @@ -87,9 +87,5 @@ RUN pip install git+https://github.com/mlcommons/logging.git@2.1.0 # Install FlashAttention RUN pip install flash-attn==2.4.2 -# Install CUDA extensions for fused dense, layer norm -RUN git clone https://github.com/HazyResearch/flash-attention \ - && cd flash-attention && git checkout v2.4.2 \ - && cd csrc/layer_norm && pip install . && cd ../../ \ - && cd csrc/fused_dense_lib && pip install . && cd ../../ \ - && cd .. && rm -rf flash-attention +# Install CUDA extensions for fused dense +RUN pip install git+https://github.com/HazyResearch/flash-attention@v2.4.2#subdirectory=csrc/fused_dense_lib
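
For reference, a minimal sketch of how a call site migrates from the removed `dropout_add_layer_norm` / `dropout_add_rms_norm` CUDA-extension API to the Triton `layer_norm_fn` used throughout this patch. The keyword names (`residual`, `x1`, `weight1`, `bias1`, `eps`, `dropout_p`, `rowscale`, `prenorm`, `residual_in_fp32`, `is_rms_norm`) are taken from the call sites above; the helper name `fused_dropout_add_norm` and the tensor shapes in the usage check are illustrative only, and the full signature lives in `flash_attn/ops/triton/layer_norm.py`.

```python
# Sketch only (not part of the patch): migrating a fused dropout + add + norm
# call site from the old CUDA extension to the Triton layer_norm_fn.
import torch
import torch.nn as nn

try:
    from flash_attn.ops.triton.layer_norm import layer_norm_fn, RMSNorm
except ImportError:  # Triton (or a recent flash-attn) is not installed
    layer_norm_fn, RMSNorm = None, None


def fused_dropout_add_norm(hidden_states, residual, norm, dropout, training):
    # Old API:  dropout_add_layer_norm(x, residual, weight, bias, dropout_p, eps, prenorm=True, ...)
    # New API:  a single entry point handles LayerNorm and RMSNorm via is_rms_norm.
    assert layer_norm_fn is not None, "Triton is not installed"
    return layer_norm_fn(
        hidden_states,
        norm.weight,
        norm.bias,
        residual=residual,                         # was the 2nd positional arg
        eps=norm.eps,                              # was positional
        dropout_p=dropout.p if training else 0.0,  # was positional
        prenorm=True,                              # also return the updated residual
        residual_in_fp32=True,
        is_rms_norm=isinstance(norm, RMSNorm),     # replaces the dropout_add_rms_norm variant
    )


if torch.cuda.is_available() and layer_norm_fn is not None:
    norm = nn.LayerNorm(1024).to("cuda", torch.float16)
    drop = nn.Dropout(0.1)
    x = torch.randn(2, 512, 1024, device="cuda", dtype=torch.float16)
    res = torch.randn_like(x)
    out, new_residual = fused_dropout_add_norm(x, res, norm, drop, training=False)
    print(out.shape, new_residual.shape)
```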