Skip to content

Commit

Permalink
remove dependency of flash_attn when use_flash_attn is set to false
Browse files Browse the repository at this point in the history
  • Loading branch information
sallyjunjun committed Jan 24, 2024
1 parent cba90e6 commit 9b67f33
Show file tree
Hide file tree
Showing 11 changed files with 625 additions and 86 deletions.
183 changes: 168 additions & 15 deletions internlm/model/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,9 @@

from typing import Tuple

import rotary_emb
import torch
import torch.nn.functional as F
from einops import rearrange
from flash_attn.layers.rotary import ApplyRotaryEmb as LegacyApplyRotaryEmb
from flash_attn.layers.rotary import ApplyRotaryEmbQKV_ as LegacyApplyRotaryEmbQKV_
from torch import Tensor, nn

from internlm.core.context import ParallelMode
Expand Down Expand Up @@ -63,6 +60,22 @@ def forward(self, input_: Tensor) -> Tensor:
return output


def apply_rotary_torch(x1, x2, cos, sin, conj):
assert x1.device == x2.device == cos.device == sin.device, "All inputs must be on the same device"
assert x1.dtype == x2.dtype == cos.dtype == sin.dtype, "All inputs must have the same dtype"
assert x1.size() == x2.size(), "Input x1 and x2 must have the same sizes"
assert cos.size() == sin.size(), "Input cos and sin must have the same sizes"

if conj:
out1 = x1 * cos + x2 * sin
out2 = -x1 * sin + x2 * cos
else:
out1 = x1 * cos - x2 * sin
out2 = x1 * sin + x2 * cos

return out1, out2


class ApplyRotaryEmbQKV_(torch.autograd.Function):
"""
ApplyRotaryEmbQKV_
Expand All @@ -86,11 +99,23 @@ def forward(ctx, qkv, cos, sin, cos_k=None, sin_k=None):
sin_k = sin if sin_k is None else sin_k
assert sin.shape == cos_k.shape == sin_k.shape == (rotary_seqlen, rotary_dim // 2)
q1, q2 = qkv[:, 0, :, :rotary_dim].chunk(2, dim=-1)
rotary_emb.apply_rotary(q1, q2, rearrange(cos, "s d -> s 1 d"), rearrange(sin, "s d -> s 1 d"), q1, q2, False)
if gpc.config.model.use_flash_attn:
import rotary_emb

rotary_emb.apply_rotary(
q1, q2, rearrange(cos, "s d -> s 1 d"), rearrange(sin, "s d -> s 1 d"), q1, q2, False
)
else:
q1, q2 = apply_rotary_torch(q1, q2, rearrange(cos, "s d -> s 1 d"), rearrange(sin, "s d -> s 1 d"), False)
k1, k2 = qkv[:, 1, :, :rotary_dim].chunk(2, dim=-1)
rotary_emb.apply_rotary(
k1, k2, rearrange(cos_k, "s d -> s 1 d"), rearrange(sin_k, "s d -> s 1 d"), k1, k2, False
)
if gpc.config.model.use_flash_attn:
rotary_emb.apply_rotary(
k1, k2, rearrange(cos_k, "s d -> s 1 d"), rearrange(sin_k, "s d -> s 1 d"), k1, k2, False
)
else:
k1, k2 = apply_rotary_torch(
k1, k2, rearrange(cos_k, "s d -> s 1 d"), rearrange(sin_k, "s d -> s 1 d"), False
)
ctx.save_for_backward(cos, sin, cos_k, sin_k)
return qkv

Expand All @@ -100,19 +125,130 @@ def backward(ctx, dqkv):
rotary_dim = cos.shape[-1]
rotary_dim *= 2
dq1, dq2 = dqkv[:, 0, :, :rotary_dim].chunk(2, dim=-1)
rotary_emb.apply_rotary(
dq1, dq2, rearrange(cos, "s d -> s 1 d"), rearrange(sin, "s d -> s 1 d"), dq1, dq2, True
)
if gpc.config.model.use_flash_attn:
import rotary_emb

rotary_emb.apply_rotary(
dq1, dq2, rearrange(cos, "s d -> s 1 d"), rearrange(sin, "s d -> s 1 d"), dq1, dq2, True
)
else:
dq1, dq2 = apply_rotary_torch(
dq1, dq2, rearrange(cos, "s d -> s 1 d"), rearrange(sin, "s d -> s 1 d"), True
)
dk1, dk2 = dqkv[:, 1, :, :rotary_dim].chunk(2, dim=-1)
rotary_emb.apply_rotary(
dk1, dk2, rearrange(cos_k, "s d -> s 1 d"), rearrange(sin_k, "s d -> s 1 d"), dk1, dk2, True
)
if gpc.config.model.use_flash_attn:
rotary_emb.apply_rotary(
dk1, dk2, rearrange(cos_k, "s d -> s 1 d"), rearrange(sin_k, "s d -> s 1 d"), dk1, dk2, True
)
else:
dk1, dk2 = apply_rotary_torch(
dk1, dk2, rearrange(cos_k, "s d -> s 1 d"), rearrange(sin_k, "s d -> s 1 d"), True
)
return dqkv, None, None, None, None


class TorchApplyRotaryEmb(torch.autograd.Function):
"""
TorchApplyRotaryEmb
"""

@staticmethod
def forward(ctx, x, cos, sin, interleaved=False):
"""
x: (batch_size, seqlen, nheads, headdim)
cos, sin: (seqlen, rotary_dim / 2)
interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
of 1st half and 2nd half (GPT-NeoX style).
rotary_dim must be <= headdim
Apply rotary embedding to the first rotary_dim of x.
"""
_, seqlen, _, headdim = x.shape
rotary_seqlen, rotary_dim = cos.shape
rotary_dim *= 2
assert rotary_dim <= headdim
assert seqlen <= rotary_seqlen
assert sin.shape == (rotary_seqlen, rotary_dim // 2)
x_ro = x[..., :rotary_dim]
x1, x2 = x_ro.chunk(2, dim=-1) if not interleaved else (x_ro[..., ::2], x_ro[..., 1::2])
x1, x2 = apply_rotary_torch(
x1, x2, rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange(sin[:seqlen], "s d -> s 1 d"), False
)
ctx.save_for_backward(cos, sin)
ctx.interleaved = interleaved
return x

@staticmethod
def backward(ctx, do):
cos, sin = ctx.saved_tensors
_, seqlen, _, _ = do.shape
rotary_dim = cos.shape[-1]
rotary_dim *= 2
do_ro = do[..., :rotary_dim]
do1, do2 = do_ro.chunk(2, dim=-1) if not ctx.interleaved else (do_ro[..., ::2], do_ro[..., 1::2])
do1, do2 = apply_rotary_torch(
do1, do2, rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange(sin[:seqlen], "s d -> s 1 d"), True
)
return do, None, None, None, None


class TorchApplyRotaryEmbQKV_(torch.autograd.Function):
"""
TorchApplyRotaryEmbQKV_
"""

@staticmethod
def forward(ctx, qkv, cos, sin, cos_k=None, sin_k=None, interleaved=False):
"""
qkv: (batch_size, seqlen, 3, nheads, headdim)
cos, sin: (seqlen, rotary_dim / 2)
cos_k, sin_k: (seqlen, rotary_dim / 2), optional
interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of
1st half and 2nd half (GPT-NeoX style).
rotary_dim must be <= headdim
"""
_, seqlen, three, _, headdim = qkv.shape
assert three == 3
rotary_seqlen, rotary_dim = cos.shape
rotary_dim *= 2
assert rotary_dim <= headdim
assert seqlen <= rotary_seqlen
cos_k = cos if cos_k is None else cos_k
sin_k = sin if sin_k is None else sin_k
assert sin.shape == cos_k.shape == sin_k.shape == (rotary_seqlen, rotary_dim // 2)
q_ro = qkv[:, :, 0, :, :rotary_dim]
q1, q2 = q_ro.chunk(2, dim=-1) if not interleaved else (q_ro[..., ::2], q_ro[..., 1::2])
q1, q2 = apply_rotary_torch(
q1, q2, rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange(sin[:seqlen], "s d -> s 1 d"), False
)
k_ro = qkv[:, :, 1, :, :rotary_dim]
k1, k2 = k_ro.chunk(2, dim=-1) if not interleaved else (k_ro[..., ::2], k_ro[..., 1::2])
k1, k2 = apply_rotary_torch(
k1, k2, rearrange(cos_k[:seqlen], "s d -> s 1 d"), rearrange(sin_k[:seqlen], "s d -> s 1 d"), False
)
ctx.save_for_backward(cos, sin, cos_k, sin_k)
ctx.interleaved = interleaved
return qkv

@staticmethod
def backward(ctx, dqkv):
cos, sin, cos_k, sin_k = ctx.saved_tensors
_, seqlen, _, _, _ = dqkv.shape
rotary_dim = cos.shape[-1]
rotary_dim *= 2
dq_ro = dqkv[:, :, 0, :, :rotary_dim]
dq1, dq2 = dq_ro.chunk(2, dim=-1) if not ctx.interleaved else (dq_ro[..., ::2], dq_ro[..., 1::2])
dq1, dq2 = apply_rotary_torch(
dq1, dq2, rearrange(cos[:seqlen], "s d -> s 1 d"), rearrange(sin[:seqlen], "s d -> s 1 d"), True
)
dk_ro = dqkv[:, :, 1, :, :rotary_dim]
dk1, dk2 = dk_ro.chunk(2, dim=-1) if not ctx.interleaved else (dk_ro[..., ::2], dk_ro[..., 1::2])
dk1, dk2 = apply_rotary_torch(
dk1, dk2, rearrange(cos_k[:seqlen], "s d -> s 1 d"), rearrange(sin_k[:seqlen], "s d -> s 1 d"), True
)
return dqkv, None, None, None, None, None


apply_rotary_emb_qkv_ = ApplyRotaryEmbQKV_.apply
legacy_apply_rotary_embed_qkv = LegacyApplyRotaryEmbQKV_.apply
legacy_apply_rotary_embed = LegacyApplyRotaryEmb.apply


class RotaryEmbedding(torch.nn.Module):
Expand Down Expand Up @@ -202,12 +338,27 @@ def _forward(self, qkv: torch.Tensor, indexes=0) -> Tuple[torch.Tensor, torch.Te
self._sin_k_cached[indexes],
)

def _get_legacy_apply_rotary_functions(self):
if gpc.config.model.use_flash_attn:
from flash_attn.layers.rotary import ApplyRotaryEmb as LegacyApplyRotaryEmb
from flash_attn.layers.rotary import (
ApplyRotaryEmbQKV_ as LegacyApplyRotaryEmbQKV_,
)

legacy_apply_rotary_embed_qkv = LegacyApplyRotaryEmbQKV_.apply
legacy_apply_rotary_embed = LegacyApplyRotaryEmb.apply
else:
legacy_apply_rotary_embed_qkv = TorchApplyRotaryEmbQKV_.apply
legacy_apply_rotary_embed = TorchApplyRotaryEmb.apply
return legacy_apply_rotary_embed_qkv, legacy_apply_rotary_embed

def _eval_forward(self, qkv, seqlen_offset=0):
"""
seqlen_offset: can be used in generation where the qkv being passed in is only the last
token in the batch.
"""
self._update_cos_sin_cache(qkv, seqlen_offset + qkv.shape[1])
legacy_apply_rotary_embed_qkv, _ = self._get_legacy_apply_rotary_functions()
if self.scale is None:
return legacy_apply_rotary_embed_qkv(
qkv, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:]
Expand All @@ -225,12 +376,14 @@ def _single_forward(self, x, indexes=0):
assert self.scale is None
self._update_cos_sin_cache(x, indexes)
x = x[None, ...]
_, legacy_apply_rotary_embed = self._get_legacy_apply_rotary_functions()
ret = legacy_apply_rotary_embed(x, self._cos_cached[indexes], self._sin_cached[indexes]).squeeze(0)
return ret

def _single_eval_forward(self, x, seqlen_offset=0):
assert self.scale is None
self._update_cos_sin_cache(x, seqlen_offset + x.shape[1])
_, legacy_apply_rotary_embed = self._get_legacy_apply_rotary_functions()
return legacy_apply_rotary_embed(x, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:])


Expand Down
102 changes: 97 additions & 5 deletions internlm/model/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,17 @@
from typing import Optional

import torch
from flash_attn.ops.fused_dense import ColumnParallelLinear, RowParallelLinear
from flash_attn.utils.distributed import all_reduce, reduce_scatter
from torch import nn
from torch.distributed import ProcessGroup

from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.model.utils import Silu, fused_dense_func_torch
from internlm.model.utils import (
Silu,
all_reduce,
fused_dense_func_torch,
reduce_scatter,
)


class ScaleColumnParallelLinear(nn.Linear):
Expand Down Expand Up @@ -114,7 +118,47 @@ def forward(self, input): # pylint: disable=W0622
)


class ColumnParallelLinearTorch(ColumnParallelLinear):
class ColumnParallelLinearTorch(nn.Linear):
"""
ColumnParallelLinearTorch.
Args:
in_features (int): size of each input sample
out_features (int): size of each output sample
process_group (Optional[torch.distributed.ProcessGroup]): The group of the current device for `parallel_mode`.
bias (bool): Whether the bias is needed for linears. True by default. But it is typically set to False
in the config.
sequence_parallel (bool): If sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:
we do an all_gather of x before doing the matmul.
If not, then the input is already gathered.
device (Optional[Union[str, torch.device]]): The device will be used.
dtype (Optional[torch.dtype]): The type of data.
weight_scale (int): For training stability. 1 by default.
"""

def __init__(
self,
in_features: int,
out_features: int,
process_group: ProcessGroup,
bias: bool = True,
sequence_parallel=True,
multiple_of=1,
device=None,
dtype=None,
) -> None:
world_size = torch.distributed.get_world_size(process_group)
if out_features % multiple_of:
raise ValueError(f"out_features ({out_features}) must be a multiple of {multiple_of}")
multiple = out_features // multiple_of
# We want to split @multiple across world_size, but it could be an uneven split
div = multiple // world_size
mod = multiple % world_size
# The first @mod ranks get @div + 1 copies, the rest get @div copies
local_multiple = div + int(torch.distributed.get_rank(process_group) < mod)
super().__init__(in_features, local_multiple * multiple_of, bias=bias, device=device, dtype=dtype)
self.process_group = process_group
self.sequence_parallel = sequence_parallel

def forward(self, x):
# If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:
# we do an all_gather of x before doing the matmul.
Expand All @@ -125,7 +169,55 @@ def forward(self, x):
)


class RowParallelLinearTorch(RowParallelLinear):
class RowParallelLinearTorch(nn.Linear):
"""
RowParallelLinearTorch.
Args:
in_features (int): size of each input sample
out_features (int): size of each output sample
process_group (Optional[torch.distributed.ProcessGroup]): The group of the current device for `parallel_mode`.
bias (bool): Whether the bias is needed for linears. True by default. But it is typically set to False
in the config.
sequence_parallel (bool): If sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:
we do an all_gather of x before doing the matmul.
If not, then the input is already gathered.
device (Optional[Union[str, torch.device]]): The device will be used.
dtype (Optional[torch.dtype]): The type of data.
weight_scale (int): For training stability. 1 by default.
"""

def __init__(
self,
in_features: int,
out_features: int,
process_group: ProcessGroup,
bias: bool = True,
sequence_parallel=True,
multiple_of=1,
device=None,
dtype=None,
) -> None:
world_size = torch.distributed.get_world_size(process_group)
rank = torch.distributed.get_rank(process_group)
if in_features % multiple_of:
raise ValueError(f"in_features ({in_features}) must be a multiple of {multiple_of}")
multiple = in_features // multiple_of
# We want to split @multiple across world_size, but it could be an uneven split
div = multiple // world_size
mod = multiple % world_size
# The first @mod ranks get @div + 1 copies, the rest get @div copies
local_multiple = div + int(torch.distributed.get_rank(process_group) < mod)
# Only rank 0 will have bias
super().__init__(
local_multiple * multiple_of,
out_features,
bias=bias and rank == 0,
device=device,
dtype=dtype,
)
self.process_group = process_group
self.sequence_parallel = sequence_parallel

def forward(self, x):
"""
We're doing Tensor Parallel with sequence parallelism: we do the matmul and then
Expand Down
7 changes: 5 additions & 2 deletions internlm/model/loss.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

from flash_attn.losses.cross_entropy import CrossEntropyLoss as FlashCrossEntropyLoss
from torch import nn

from internlm.core.context import ParallelMode
Expand All @@ -24,7 +23,11 @@ def __init__(self, parallel_output=True, label_smoothing=0):
label_smoothing = 0
self.label_smoothing = label_smoothing

if parallel_output:
if gpc.config.model.use_flash_attn and parallel_output:
from flash_attn.losses.cross_entropy import (
CrossEntropyLoss as FlashCrossEntropyLoss,
)

self.loss_fn = FlashCrossEntropyLoss(
reduction="mean",
inplace_backward=True,
Expand Down
Loading

0 comments on commit 9b67f33

Please sign in to comment.