Commit

remove dependency of flash_attn when use_flash_attn is set to false (#20)

Co-authored-by: geruijun <[email protected]>
sallyjunjun authored Mar 8, 2024
1 parent fb6a587 commit 0dcc0e9
Showing 13 changed files with 827 additions and 146 deletions.
346 changes: 319 additions & 27 deletions internlm/model/embedding.py

Large diffs are not rendered by default.
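The embedding.py changes themselves are not rendered above. As a hedged illustration only of what a flash_attn-free rotary path has to do, here is the standard pure-torch application of rotary position embeddings; the function name, tensor layout, and cos/sin shapes are assumptions, and this is not the code added by the commit.

import torch

def apply_rotary_torch(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # Rotate-half rotary embedding in plain torch (illustrative sketch only).
    # x: (seqlen, ..., headdim); cos/sin: (seqlen, headdim // 2).
    d = x.shape[-1] // 2
    x1, x2 = x[..., :d], x[..., d:]
    # Broadcast cos/sin across the middle (batch/head) dimensions.
    shape = (x.shape[0],) + (1,) * (x.dim() - 2) + (d,)
    cos, sin = cos.view(shape), sin.view(shape)
    return torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)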

111 changes: 104 additions & 7 deletions internlm/model/linear.py
@@ -4,17 +4,18 @@
from typing import Callable, Optional

import torch
from flash_attn.ops.fused_dense import ColumnParallelLinear, RowParallelLinear
from flash_attn.utils.distributed import all_reduce, reduce_scatter
from torch import nn
from torch.distributed import ProcessGroup

from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.model.utils import (
Silu,
all_reduce,
fused_dense_func,
isp_fused_dense_func,
megatron_fused_dense_func,
reduce_scatter,
)
from internlm.utils.logger import get_logger

@@ -202,7 +203,47 @@ def forward(self, input): # pylint: disable=W0622
)


class ColumnParallelLinearTorch(ColumnParallelLinear):
class ColumnParallelLinearTorch(nn.Linear):
"""
ColumnParallelLinearTorch.
Args:
in_features (int): size of each input sample
out_features (int): size of each output sample
process_group (Optional[torch.distributed.ProcessGroup]): The group of the current device for `parallel_mode`.
bias (bool): Whether the bias is needed for linears. True by default. But it is typically set to False
in the config.
sequence_parallel (bool): If sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:
we do an all_gather of x before doing the matmul.
If not, then the input is already gathered.
device (Optional[Union[str, torch.device]]): The device to be used.
dtype (Optional[torch.dtype]): The type of data.
weight_scale (int): For training stability. 1 by default.
"""

def __init__(
self,
in_features: int,
out_features: int,
process_group: ProcessGroup,
bias: bool = True,
sequence_parallel=True,
multiple_of=1,
device=None,
dtype=None,
) -> None:
world_size = torch.distributed.get_world_size(process_group)
if out_features % multiple_of:
raise ValueError(f"out_features ({out_features}) must be a multiple of {multiple_of}")
multiple = out_features // multiple_of
# We want to split @multiple across world_size, but it could be an uneven split
div = multiple // world_size
mod = multiple % world_size
# The first @mod ranks get @div + 1 copies, the rest get @div copies
local_multiple = div + int(torch.distributed.get_rank(process_group) < mod)
super().__init__(in_features, local_multiple * multiple_of, bias=bias, device=device, dtype=dtype)
self.process_group = process_group
self.sequence_parallel = sequence_parallel

def forward(self, x, gather_dim=0):
# If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:
# we do an all_gather of x before doing the matmul.
@@ -217,7 +258,11 @@ def forward(self, x, gather_dim=0)
)


class MegatronColumnParallelLinearTorch(ColumnParallelLinear):
class MegatronColumnParallelLinearTorch(ColumnParallelLinearTorch):
"""
MegatronColumnParallelLinearTorch
"""

def forward(self, x, gather_dim=0):
# If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:
# we do an all_gather of x before doing the matmul.
@@ -232,7 +277,55 @@ def forward(self, x, gather_dim=0)
)


class RowParallelLinearTorch(RowParallelLinear):
class RowParallelLinearTorch(nn.Linear):
"""
RowParallelLinearTorch.
Args:
in_features (int): size of each input sample
out_features (int): size of each output sample
process_group (Optional[torch.distributed.ProcessGroup]): The group of the current device for `parallel_mode`.
bias (bool): Whether the bias is needed for linears. True by default. But it is typically set to False
in the config.
sequence_parallel (bool): If sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:
we do an all_gather of x before doing the matmul.
If not, then the input is already gathered.
device (Optional[Union[str, torch.device]]): The device to be used.
dtype (Optional[torch.dtype]): The type of data.
weight_scale (int): For training stability. 1 by default.
"""

def __init__(
self,
in_features: int,
out_features: int,
process_group: ProcessGroup,
bias: bool = True,
sequence_parallel=True,
multiple_of=1,
device=None,
dtype=None,
) -> None:
world_size = torch.distributed.get_world_size(process_group)
rank = torch.distributed.get_rank(process_group)
if in_features % multiple_of:
raise ValueError(f"in_features ({in_features}) must be a multiple of {multiple_of}")
multiple = in_features // multiple_of
# We want to split @multiple across world_size, but it could be an uneven split
div = multiple // world_size
mod = multiple % world_size
# The first @mod ranks get @div + 1 copies, the rest get @div copies
local_multiple = div + int(torch.distributed.get_rank(process_group) < mod)
# Only rank 0 will have bias
super().__init__(
local_multiple * multiple_of,
out_features,
bias=bias and rank == 0,
device=device,
dtype=dtype,
)
self.process_group = process_group
self.sequence_parallel = sequence_parallel

def forward(self, x):
"""
We're doing Tensor Parallel with sequence parallelism: we do the matmul and then
@@ -243,7 +336,11 @@ def forward(self, x)
return reduce_fn(out, self.process_group)


class MegatronRowParallelLinearTorch(RowParallelLinear):
class MegatronRowParallelLinearTorch(RowParallelLinearTorch):
"""
MegatronRowParallelLinearTorch.
"""

def forward(self, x):
"""
We're doing Tensor Parallel with sequence parallelism: we do the matmul and then
@@ -405,7 +502,7 @@ def __init__(
)


class ISPLinear(ColumnParallelLinear):
class ISPLinear(ColumnParallelLinearTorch):
"""
Linear class for isp tensor parallel mode.
"""
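To make the shard arithmetic in the new ColumnParallelLinearTorch / RowParallelLinearTorch constructors concrete, here is a minimal standalone sketch of the same computation; the helper name and the example values are illustrative and not part of the commit.

def local_features(features: int, world_size: int, rank: int, multiple_of: int = 1) -> int:
    # Split `features` (grouped into blocks of `multiple_of`) across `world_size` ranks;
    # when the split is uneven, the first `mod` ranks each take one extra block.
    if features % multiple_of:
        raise ValueError(f"features ({features}) must be a multiple of {multiple_of}")
    multiple = features // multiple_of
    div, mod = divmod(multiple, world_size)
    return (div + int(rank < mod)) * multiple_of

# Example: 10 blocks over 4 ranks -> local shards of 3, 3, 2, 2 blocks.
assert [local_features(10, 4, r) for r in range(4)] == [3, 3, 2, 2]
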
7 changes: 5 additions & 2 deletions internlm/model/loss.py
@@ -1,7 +1,6 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

from flash_attn.losses.cross_entropy import CrossEntropyLoss as FlashCrossEntropyLoss
from torch import nn

from internlm.core.context import ParallelMode
@@ -24,7 +23,11 @@ def __init__(self, parallel_output=True, label_smoothing=0):
label_smoothing = 0
self.label_smoothing = label_smoothing

if parallel_output:
if gpc.config.model.use_flash_attn and parallel_output:
from flash_attn.losses.cross_entropy import (
CrossEntropyLoss as FlashCrossEntropyLoss,
)

self.loss_fn = FlashCrossEntropyLoss(
reduction="mean",
inplace_backward=True,
15 changes: 11 additions & 4 deletions internlm/model/metrics.py
@@ -1,7 +1,7 @@
from typing import Callable, List, Optional

import torch
from flash_attn.losses.cross_entropy import CrossEntropyLoss as FlashCrossEntropyLoss
from torch import nn
from torch_scatter import scatter

from internlm.core.context import ParallelMode
@@ -210,9 +210,16 @@ def __init__(self, device, dp_pg, dataset_types: List[str] = None) -> None:
self.ds_loss = torch.zeros(self.total_type_count, dtype=torch.float, device=device)
self.ds_token_num = torch.zeros(self.total_type_count, dtype=torch.float, device=device)

self.loss_fn = FlashCrossEntropyLoss(
reduction="none", inplace_backward=True, process_group=gpc.get_group(ParallelMode.TENSOR)
)
if gpc.config.model.use_flash_attn:
from flash_attn.losses.cross_entropy import (
CrossEntropyLoss as FlashCrossEntropyLoss,
)

self.loss_fn = FlashCrossEntropyLoss(
reduction="none", inplace_backward=True, process_group=gpc.get_group(ParallelMode.TENSOR)
)
else:
self.loss_fn = nn.CrossEntropyLoss(reduction="none")

def update(self, logits, labels, type_ids=None):
with torch.no_grad():
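loss.py and metrics.py above now share the same pattern: flash_attn is imported lazily, and only when the config enables it; otherwise the torch loss is used. A minimal sketch of that pattern, with a plain boolean standing in for gpc.config.model.use_flash_attn and build_loss_fn as an illustrative name:

from torch import nn

def build_loss_fn(use_flash_attn: bool, process_group=None, reduction: str = "none"):
    if use_flash_attn:
        # Deferred import: flash_attn is only required when the flag is on.
        from flash_attn.losses.cross_entropy import (
            CrossEntropyLoss as FlashCrossEntropyLoss,
        )

        return FlashCrossEntropyLoss(
            reduction=reduction, inplace_backward=True, process_group=process_group
        )
    return nn.CrossEntropyLoss(reduction=reduction)
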
10 changes: 6 additions & 4 deletions internlm/model/modeling_internlm.py
@@ -5,8 +5,6 @@
from typing import Optional

import torch
from flash_attn.modules.embedding import ParallelGPT2Embeddings
from flash_attn.modules.mlp import ParallelFusedMLP
from torch import nn

from internlm.core.context import ParallelMode
@@ -120,7 +118,7 @@ def __init__(
self.norm1 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
self.norm2 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)

if use_swiglu:
if use_swiglu or not use_flash_attn:
mlp_cls = get_mlp_cls(self.tp_mode)
self.mlp = mlp_cls(
hidden_size,
@@ -132,6 +130,8 @@
dtype=dtype,
)
else:
from flash_attn.modules.mlp import ParallelFusedMLP

self.mlp = ParallelFusedMLP(
hidden_size,
int(hidden_size * mlp_ratio),
@@ -312,9 +312,11 @@ def __init__(
else:
head_cls = ScaleColumnParallelLinear
if first:
if embed_split_hidden:
if embed_split_hidden or not use_flash_attn:
self.embedding = Embedding1D(num_embeddings=vocab_size, embedding_dim=hidden_size)
else:
from flash_attn.modules.embedding import ParallelGPT2Embeddings

self.embedding = ParallelGPT2Embeddings(
embed_dim=hidden_size,
vocab_size=vocab_size,
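The two new conditions in modeling_internlm.py follow the same idea: the flash_attn-backed modules (ParallelGPT2Embeddings, ParallelFusedMLP) are chosen, and imported, only when use_flash_attn is set; otherwise the torch-based classes are kept. A hedged sketch of the embedding selection, with torch_embedding_cls standing in for Embedding1D:

def pick_embedding_cls(embed_split_hidden: bool, use_flash_attn: bool, torch_embedding_cls):
    if embed_split_hidden or not use_flash_attn:
        # Torch-only path (Embedding1D in the real code).
        return torch_embedding_cls
    # Lazy import keeps flash_attn optional when this branch is never taken.
    from flash_attn.modules.embedding import ParallelGPT2Embeddings
    return ParallelGPT2Embeddings
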
50 changes: 32 additions & 18 deletions internlm/model/modeling_internlm2.py
@@ -5,17 +5,6 @@
import torch
import torch.nn.functional as F
from einops import rearrange
from flash_attn import flash_attn_varlen_kvpacked_func
from flash_attn.modules.embedding import ParallelGPT2Embeddings
from flash_attn.modules.mha import (
CrossAttention,
FlashCrossAttention,
FlashSelfAttention,
SelfAttention,
_update_kv_cache,
)
from flash_attn.modules.mlp import ParallelFusedMLP
from flash_attn.ops.layer_norm import dropout_add_layer_norm
from torch import nn

from internlm.core.context import ParallelMode
@@ -37,7 +26,12 @@
get_linear_cls,
get_mlp_cls,
)
from internlm.model.multi_head_attention import DistributedAttention
from internlm.model.multi_head_attention import (
CrossAttention,
DistributedAttention,
SelfAttention,
_update_kv_cache,
)
from internlm.model.utils import (
gather_forward_split_backward,
split_forward_gather_backward,
@@ -157,6 +151,10 @@ def __init__(
**factory_kwargs,
)

if use_flash_attn:
from flash_attn import flash_attn_varlen_kvpacked_func
from flash_attn.modules.mha import FlashCrossAttention, FlashSelfAttention

inner_attn_cls = FlashSelfAttention if use_flash_attn else SelfAttention
inner_cross_attn_cls = FlashCrossAttention if use_flash_attn else CrossAttention
self.inner_attn = inner_attn_cls(causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout)
@@ -168,7 +166,7 @@
self.inner_cross_attn_softmax_scale = softmax_scale
self.inner_cross_attn_dropout = dropout

self.attn = flash_attn_varlen_kvpacked_func
self.attn = flash_attn_varlen_kvpacked_func if use_flash_attn else SelfAttention
if self.tp_mode == "isp":
self.attn = DistributedAttention(self.attn, sequence_process_group=sequence_process_group)

@@ -254,15 +252,21 @@ def _forward(self, x, seqlen=None, inference_params=None, **kwargs): # pylint:
else:
q = q.squeeze(1)
k = k.squeeze(1)
cu_seqlens = kwargs.get("cu_seqlens", None)
max_seqlen = kwargs.get("max_seqlen", None)
q = self.rotary_emb._single_forward(
q,
inference_params.sequence_len_offset * torch.ones(q.size(0), dtype=torch.int, device=q.device)
- empties,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
).unsqueeze(1)
k = self.rotary_emb._single_forward(
k,
inference_params.sequence_len_offset * torch.ones(k.size(0), dtype=torch.int, device=k.device)
- empties,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
).unsqueeze(1)
else:
raise NotImplementedError(
@@ -300,6 +304,8 @@ def _forward(self, x, seqlen=None, inference_params=None, **kwargs): # pylint:

if hasattr(inference_params, "attention_mask") and inference_params.attention_mask is not None:
assert self.use_flash_attn is True
from flash_attn import flash_attn_varlen_kvpacked_func

if inference_params.sequence_len_offset == 0: # First entrance, attnmask (bs*seqlen*seqlen)
attn_mask = inference_params.attention_mask[:, None, ...]
attn_mask = torch.logical_or(
@@ -425,8 +431,10 @@ def _packed_forward(self, x, inference_params=None, **kwargs):
k = torch.cat([k[..., ::2], k[..., 1::2]], dim=-1)

indexes = kwargs.pop("indexes")
q = self.rotary_emb._single_forward(q, indexes=indexes)
k = self.rotary_emb._single_forward(k, indexes=indexes)
cu_seqlens = kwargs.pop("cu_seqlens")
max_seqlen = kwargs.pop("max_seqlen")
q = self.rotary_emb._single_forward(q, indexes=indexes, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)
k = self.rotary_emb._single_forward(k, indexes=indexes, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)

if inference_params is None:
kv = torch.concat([k.unsqueeze(1), v.unsqueeze(1)], dim=1)
@@ -581,12 +589,14 @@ def __init__(
else:
self.attention_norm = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
self.ffn_norm = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon)
if self.fused_dropout_add_ln:
if self.fused_dropout_add_ln and self.use_flash_attn:
from flash_attn.ops.layer_norm import dropout_add_layer_norm

assert dropout_add_layer_norm is not None, "dropout_add_ln is not installed"
assert isinstance(self.attention_norm, nn.LayerNorm) and isinstance(self.dropout1, nn.Dropout)

sequence_parallel = gpc.config.parallel.get("sequence_parallel", False)
if use_swiglu:
if use_swiglu or not gpc.config.model.use_flash_attn:
ffn = get_mlp_cls(self.tp_mode)
self.feed_forward = ffn(
hidden_size,
Expand All @@ -598,6 +608,8 @@ def __init__(
dtype=dtype,
)
else:
from flash_attn.modules.mlp import ParallelFusedMLP

self.feed_forward = ParallelFusedMLP(
hidden_size,
int(hidden_size * mlp_ratio),
@@ -857,9 +869,11 @@ def __init__(
sequence_parallel = gpc.config.parallel.get("sequence_parallel", False)

if first:
if embed_split_hidden:
if embed_split_hidden or not gpc.config.model.use_flash_attn:
self.tok_embeddings = Embedding1D(num_embeddings=vocab_size, embedding_dim=hidden_size)
else:
from flash_attn.modules.embedding import ParallelGPT2Embeddings

self.tok_embeddings = ParallelGPT2Embeddings(
embed_dim=hidden_size,
vocab_size=vocab_size,
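modeling_internlm2.py applies the same selection to attention: the flash_attn classes and the varlen kernel are imported only inside the use_flash_attn branch, while the torch-side SelfAttention and CrossAttention now come from internlm.model.multi_head_attention. A hedged sketch of that selection, with the torch classes passed in as stand-ins:

def pick_attention_impls(use_flash_attn: bool, torch_self_attn, torch_cross_attn):
    if use_flash_attn:
        # Lazy imports: flash_attn is only touched on this branch.
        from flash_attn import flash_attn_varlen_kvpacked_func
        from flash_attn.modules.mha import FlashCrossAttention, FlashSelfAttention

        return FlashSelfAttention, FlashCrossAttention, flash_attn_varlen_kvpacked_func
    # Torch fallback, mirroring `self.attn = ... if use_flash_attn else SelfAttention` above.
    return torch_self_attn, torch_cross_attn, torch_self_attn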