Commit

[LayerNorm] Switch from CUDA to Triton implementation
tridao committed Jan 5, 2024
1 parent f5b308e commit abbc131
Showing 6 changed files with 82 additions and 143 deletions.
4 changes: 4 additions & 0 deletions csrc/layer_norm/README.md
@@ -14,3 +14,7 @@ This extension has only been tested on A100s.
 ```sh
 cd csrc/layer_norm && pip install .
 ```
+
+As of 2024-01-05, this extension is no longer used in the FlashAttention repo.
+We've instead switched to a Triton-based
+[implementation](https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/triton/layer_norm.py).
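
For orientation, here is a minimal usage sketch of the Triton entry point the README now points to. It relies only on the call pattern visible in this commit (layer_norm_fn with weight, bias, and eps); the LayerNorm module, tensor shapes, and dtypes below are illustrative assumptions, and a CUDA GPU with Triton installed is assumed.

```python
# Minimal sketch based on the bert.py call sites in this commit.
# Assumes a CUDA GPU with Triton installed; shapes and dtypes are illustrative.
import torch
import torch.nn as nn

from flash_attn.ops.triton.layer_norm import layer_norm_fn

ln = nn.LayerNorm(768, device="cuda", dtype=torch.float16)
x = torch.randn(4096, 768, device="cuda", dtype=torch.float16)  # (tokens, hidden)

# Plain LayerNorm, mirroring the BertPooler / BertEmbeddings calls below.
out = layer_norm_fn(x, ln.weight, ln.bias, eps=ln.eps)
print(out.shape)  # torch.Size([4096, 768])
```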
21 changes: 11 additions & 10 deletions flash_attn/models/bert.py
@@ -40,9 +40,10 @@
 FusedDense = None

 try:
-    from flash_attn.ops.layer_norm import dropout_add_layer_norm, layer_norm
+    from flash_attn.ops.triton.layer_norm import layer_norm_fn
 except ImportError:
-    dropout_add_layer_norm, layer_norm = None, None
+    layer_norm_fn = None
+

 try:
     from flash_attn.losses.cross_entropy import CrossEntropyLoss
@@ -237,8 +238,8 @@ def __init__(self, config):
         if fused_bias_fc and FusedDense is None:
             raise ImportError("fused_dense is not installed")
         self.fused_dropout_add_ln = getattr(config, "fused_dropout_add_ln", False)
-        if self.fused_dropout_add_ln and layer_norm is None:
-            raise ImportError("dropout_add_layer_norm is not installed")
+        if self.fused_dropout_add_ln and layer_norm_fn is None:
+            raise ImportError("Triton is not installed")
         linear_cls = nn.Linear if not fused_bias_fc else FusedDense
         self.dense = linear_cls(config.hidden_size, config.hidden_size)
         approximate = (
@@ -255,8 +256,8 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         if not self.fused_dropout_add_ln:
             hidden_states = self.layer_norm(hidden_states)
         else:
-            hidden_states = layer_norm(
-                hidden_states, self.layer_norm.weight, self.layer_norm.bias, self.layer_norm.eps
+            hidden_states = layer_norm_fn(
+                hidden_states, self.layer_norm.weight, self.layer_norm.bias, eps=self.layer_norm.eps
             )
         return hidden_states

@@ -345,8 +346,8 @@ def __init__(self, config: BertConfig, add_pooling_layer=True):
                 config.vocab_size % self.pad_vocab_size_multiple
             )
         self.fused_dropout_add_ln = getattr(config, "fused_dropout_add_ln", False)
-        if self.fused_dropout_add_ln and layer_norm is None:
-            raise ImportError("dropout_add_layer_norm is not installed")
+        if self.fused_dropout_add_ln and layer_norm_fn is None:
+            raise ImportError("Triton is not installed")
         assert config.hidden_act in ["gelu", "gelu_new", "gelu_fast", "gelu_pytorch_tanh"]

         self.embeddings = BertEmbeddings(
@@ -384,8 +385,8 @@ def forward(
         if not self.fused_dropout_add_ln:
             hidden_states = self.emb_ln(hidden_states)
         else:
-            hidden_states = layer_norm(
-                hidden_states, self.emb_ln.weight, self.emb_ln.bias, self.emb_ln.eps
+            hidden_states = layer_norm_fn(
+                hidden_states, self.emb_ln.weight, self.emb_ln.bias, eps=self.emb_ln.eps
             )
         hidden_states = self.emb_drop(hidden_states)

81 changes: 22 additions & 59 deletions flash_attn/models/gpt.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2023, Tri Dao.
+# Copyright (c) 2024, Tri Dao.

 import logging
 import math
@@ -47,29 +47,14 @@
 ColumnParallelLinear = None

 try:
-    from flash_attn.ops.layer_norm import dropout_add_layer_norm
-except ImportError:
-    dropout_add_layer_norm = None
-
-try:
-    from flash_attn.ops.layer_norm import dropout_add_layer_norm_parallel_residual
-except ImportError:
-    dropout_add_layer_norm_parallel_residual = None
-
-try:
-    from flash_attn.ops.rms_norm import RMSNorm, dropout_add_rms_norm
-except ImportError:
-    RMSNorm, dropout_add_rms_norm = None, None
-
-try:
-    from flash_attn.ops.rms_norm import dropout_add_rms_norm_parallel_residual
+    from flash_attn.ops.triton.mlp import FusedDenseSqreluDense
 except ImportError:
-    dropout_add_rms_norm_parallel_residual = None
+    FusedDenseSqreluDense = None

 try:
-    from flash_attn.ops.triton.mlp import FusedDenseSqreluDense
+    from flash_attn.ops.triton.layer_norm import layer_norm_fn, RMSNorm
 except ImportError:
-    FusedDenseSqreluDense = None
+    layer_norm_fn, RMSNorm = None, None

 logger = logging.getLogger(__name__)

@@ -481,13 +466,15 @@ def __init__(self, config: GPT2Config, process_group=None, device=None, dtype=No
                 for i in range(config.num_hidden_layers)
             ]
         )
+        rotary_emb_fraction = getattr(config, "rotary_emb_fraction", 0.0)
+        if rotary_emb_fraction > 0.0:  # Tie all the RotaryEmbedding modules to share the same cos/sin cache
+            for layer in self.layers[1:]:
+                layer.mixer.rotary_emb = self.layers[0].mixer.rotary_emb

         self.fused_dropout_add_ln = getattr(config, "fused_dropout_add_ln", False)
         if self.fused_dropout_add_ln:
-            if (not self.parallel_block and dropout_add_layer_norm is None) or (
-                self.parallel_block and dropout_add_layer_norm_parallel_residual is None
-            ):
-                raise ImportError("dropout_layer_norm is not installed")
+            if layer_norm_fn is None:
+                raise ImportError("Triton is not installed")
         if self.prenorm:
             self.drop_f = nn.Dropout(config.resid_pdrop)
             norm_cls = nn.LayerNorm if not use_rms_norm else RMSNorm
@@ -571,41 +558,17 @@ def forward(self, input_ids, position_ids=None, inference_params=None):
                 hidden_states = self.ln_f(residual.to(dtype=self.ln_f.weight.dtype))
             else:
                 # Set prenorm=False here since we don't need the residual
-                if not self.parallel_block:
-                    fused_add_norm_fn = (
-                        dropout_add_rms_norm
-                        if isinstance(self.ln_f, RMSNorm)
-                        else dropout_add_layer_norm
-                    )
-                    hidden_states = fused_add_norm_fn(
-                        hidden_states,
-                        residual,
-                        self.ln_f.weight,
-                        self.ln_f.bias,
-                        self.drop_f.p if self.training else 0.0,
-                        self.ln_f.eps,
-                        prenorm=False,
-                        residual_in_fp32=self.residual_in_fp32,
-                    )
-                else:
-                    fused_add_norm_fn = (
-                        dropout_add_rms_norm_parallel_residual
-                        if isinstance(self.ln_f, RMSNorm)
-                        else dropout_add_layer_norm_parallel_residual
-                    )
-                    hidden_states, _ = fused_add_norm_fn(
-                        hidden_states,
-                        hidden_states2,
-                        residual,
-                        self.ln_f.weight,
-                        self.ln_f.bias,
-                        None,
-                        None,
-                        self.drop_f.p if self.training else 0.0,
-                        self.ln_f.eps,
-                        prenorm=False,
-                        residual_in_fp32=self.residual_in_fp32,
-                    )
+                hidden_states = layer_norm_fn(
+                    hidden_states,
+                    self.ln_f.weight,
+                    self.ln_f.bias,
+                    residual=residual,
+                    x1=None if not self.parallel_block else hidden_states2,
+                    eps=self.ln_f.eps,
+                    dropout_p=self.drop_f.p if self.training else 0.0,
+                    prenorm=False,
+                    is_rms_norm=isinstance(self.ln_f, RMSNorm)
+                )
         return hidden_states

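
As a side note on the gpt.py hunk above: the single layer_norm_fn call now covers the residual add, dropout, and (RMS)LayerNorm that previously required four specialized CUDA functions. The sketch below is illustrative only; the keyword arguments come from this diff, while the tensors, hidden size, and eval-time setting are made-up assumptions.

```python
# Hedged sketch of the fused residual-add + dropout + norm path used in
# GPTModel.forward above. Assumes a CUDA GPU with Triton installed; the
# tensors and sizes are stand-ins, not repo code.
import torch
import torch.nn as nn

from flash_attn.ops.triton.layer_norm import layer_norm_fn, RMSNorm

hidden_size = 1024
ln_f = nn.LayerNorm(hidden_size, device="cuda", dtype=torch.float16)
hidden_states = torch.randn(8192, hidden_size, device="cuda", dtype=torch.float16)
residual = torch.randn_like(hidden_states)

out = layer_norm_fn(
    hidden_states,
    ln_f.weight,
    ln_f.bias,
    residual=residual,                      # fused residual add
    eps=ln_f.eps,
    dropout_p=0.0,                          # eval-time: no dropout
    prenorm=False,                          # only the normalized output is needed
    is_rms_norm=isinstance(ln_f, RMSNorm),  # False for nn.LayerNorm
)
```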
17 changes: 8 additions & 9 deletions flash_attn/models/vit.py
@@ -20,9 +20,9 @@
 from flash_attn.modules.mlp import FusedMLP, Mlp

 try:
-    from flash_attn.ops.layer_norm import dropout_add_layer_norm
+    from flash_attn.ops.triton.layer_norm import layer_norm_fn
 except ImportError:
-    dropout_add_layer_norm = None
+    layer_norm_fn = None


 def create_mixer_cls(
@@ -229,8 +229,8 @@ def __init__(
         self.norm = norm_layer(embed_dim)

         self.fused_dropout_add_ln = fused_dropout_add_ln
-        if self.fused_dropout_add_ln and dropout_add_layer_norm is None:
-            raise ImportError("dropout_add_layer_norm is not installed")
+        if self.fused_dropout_add_ln and layer_norm_fn is None:
+            raise ImportError("Triton is not installed")

         # Classifier Head
         self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
@@ -302,16 +302,15 @@ def forward_features(self, x, all_tokens=True):
                     )
                 )
             # Set prenorm=False here since we don't need to the residual
-            hidden_states = dropout_add_layer_norm(
+            hidden_states = layer_norm_fn(
                 hidden_states,
-                residual,
                 self.norm.weight,
                 self.norm.bias,
-                self.dropout.p if self.training else 0.0,
-                self.norm.eps,
+                residual=residual,
+                eps=self.norm.eps,
+                dropout_p=self.dropout.p if self.training else 0.0,
                 rowscale=rowscale,
                 prenorm=False,
-                residual_in_fp32=True,
             )
         return hidden_states
