remove dependency on flash_attn when use_flash_attn is set to false
sallyjunjun committed Jan 23, 2024
1 parent cba90e6 commit 9bbcdc4
Showing 11 changed files with 558 additions and 89 deletions.
181 changes: 166 additions & 15 deletions internlm/model/embedding.py
@@ -3,12 +3,9 @@

from typing import Tuple

import rotary_emb
import torch
import torch.nn.functional as F
from einops import rearrange
from flash_attn.layers.rotary import ApplyRotaryEmb as LegacyApplyRotaryEmb
from flash_attn.layers.rotary import ApplyRotaryEmbQKV_ as LegacyApplyRotaryEmbQKV_
from torch import Tensor, nn

from internlm.core.context import ParallelMode
@@ -63,6 +60,22 @@ def forward(self, input_: Tensor) -> Tensor:
return output


def apply_rotary_torch(x1, x2, cos, sin, conj):
assert x1.device == x2.device == cos.device == sin.device, "All inputs must be on the same device"
assert x1.dtype == x2.dtype == cos.dtype == sin.dtype, "All inputs must have the same dtype"
assert x1.size() == x2.size(), "Input x1 and x2 must have the same sizes"
assert cos.size() == sin.size(), "Input cos and sin must have the same sizes"

if conj:
out1 = x1 * cos + x2 * sin
out2 = -x1 * sin + x2 * cos
else:
out1 = x1 * cos - x2 * sin
out2 = x1 * sin + x2 * cos

return out1, out2


class ApplyRotaryEmbQKV_(torch.autograd.Function):
"""
ApplyRotaryEmbQKV_
@@ -86,11 +99,17 @@ def forward(ctx, qkv, cos, sin, cos_k=None, sin_k=None):
sin_k = sin if sin_k is None else sin_k
assert sin.shape == cos_k.shape == sin_k.shape == (rotary_seqlen, rotary_dim // 2)
q1, q2 = qkv[:, 0, :, :rotary_dim].chunk(2, dim=-1)
rotary_emb.apply_rotary(q1, q2, rearrange(cos, "s d -> s 1 d"), rearrange(sin, "s d -> s 1 d"), q1, q2, False)
if gpc.config.model.use_flash_attn:
rotary_emb.apply_rotary(q1, q2, rearrange(cos, "s d -> s 1 d"), rearrange(sin, "s d -> s 1 d"), q1, q2, False)
else:
            q1_rot, q2_rot = apply_rotary_torch(q1, q2, rearrange(cos, "s d -> s 1 d"), rearrange(sin, "s d -> s 1 d"), False)
            q1.copy_(q1_rot)  # apply_rotary_torch is out-of-place; write the rotated halves back into qkv
            q2.copy_(q2_rot)
k1, k2 = qkv[:, 1, :, :rotary_dim].chunk(2, dim=-1)
rotary_emb.apply_rotary(
k1, k2, rearrange(cos_k, "s d -> s 1 d"), rearrange(sin_k, "s d -> s 1 d"), k1, k2, False
)
if gpc.config.model.use_flash_attn:
rotary_emb.apply_rotary(
k1, k2, rearrange(cos_k, "s d -> s 1 d"), rearrange(sin_k, "s d -> s 1 d"), k1, k2, False
)
else:
            k1_rot, k2_rot = apply_rotary_torch(k1, k2, rearrange(cos_k, "s d -> s 1 d"), rearrange(sin_k, "s d -> s 1 d"), False)
            k1.copy_(k1_rot)  # write the rotated halves back into qkv
            k2.copy_(k2_rot)
ctx.save_for_backward(cos, sin, cos_k, sin_k)
return qkv

@@ -100,19 +119,136 @@ def backward(ctx, dqkv):
rotary_dim = cos.shape[-1]
rotary_dim *= 2
dq1, dq2 = dqkv[:, 0, :, :rotary_dim].chunk(2, dim=-1)
rotary_emb.apply_rotary(
dq1, dq2, rearrange(cos, "s d -> s 1 d"), rearrange(sin, "s d -> s 1 d"), dq1, dq2, True
)
if gpc.config.model.use_flash_attn:
rotary_emb.apply_rotary(
dq1, dq2, rearrange(cos, "s d -> s 1 d"), rearrange(sin, "s d -> s 1 d"), dq1, dq2, True
)
else:
            dq1_rot, dq2_rot = apply_rotary_torch(dq1, dq2, rearrange(cos, "s d -> s 1 d"), rearrange(sin, "s d -> s 1 d"), True)
            dq1.copy_(dq1_rot)  # apply_rotary_torch is out-of-place; write the results back into dqkv
            dq2.copy_(dq2_rot)
dk1, dk2 = dqkv[:, 1, :, :rotary_dim].chunk(2, dim=-1)
rotary_emb.apply_rotary(
dk1, dk2, rearrange(cos_k, "s d -> s 1 d"), rearrange(sin_k, "s d -> s 1 d"), dk1, dk2, True
)
if gpc.config.model.use_flash_attn:
rotary_emb.apply_rotary(
dk1, dk2, rearrange(cos_k, "s d -> s 1 d"), rearrange(sin_k, "s d -> s 1 d"), dk1, dk2, True
)
else:
            dk1_rot, dk2_rot = apply_rotary_torch(dk1, dk2, rearrange(cos_k, "s d -> s 1 d"), rearrange(sin_k, "s d -> s 1 d"), True)
            dk1.copy_(dk1_rot)  # write the results back into dqkv
            dk2.copy_(dk2_rot)
return dqkv, None, None, None, None


class TorchApplyRotaryEmb(torch.autograd.Function):

@staticmethod
def forward(ctx, x, cos, sin, interleaved=False, inplace=False):
"""
x: (batch_size, seqlen, nheads, headdim)
cos, sin: (seqlen, rotary_dim / 2)
interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
of 1st half and 2nd half (GPT-NeoX style).
rotary_dim must be <= headdim
Apply rotary embedding to the first rotary_dim of x.
"""
batch, seqlen, nheads, headdim = x.shape
rotary_seqlen, rotary_dim = cos.shape
rotary_dim *= 2
assert rotary_dim <= headdim
assert seqlen <= rotary_seqlen
assert sin.shape == (rotary_seqlen, rotary_dim // 2)
x_ro = x[..., :rotary_dim]
x1, x2 = x_ro.chunk(2, dim=-1) if not interleaved else (x_ro[..., ::2], x_ro[..., 1::2])
out = torch.empty_like(x) if not inplace else x
out_ro = out[..., :rotary_dim]
if inplace:
o1, o2 = x1, x2
else:
o1, o2 = (out_ro.chunk(2, dim=-1) if not interleaved
else (out_ro[..., ::2], out_ro[..., 1::2]))
        o1_rot, o2_rot = apply_rotary_torch(x1, x2, rearrange(cos[:seqlen], 's d -> s 1 d'),
                                            rearrange(sin[:seqlen], 's d -> s 1 d'), False)
        o1.copy_(o1_rot)  # apply_rotary_torch is out-of-place; write into `out` (or into x when inplace)
        o2.copy_(o2_rot)
if not inplace and rotary_dim < headdim:
out[..., rotary_dim:].copy_(x[..., rotary_dim:])
ctx.save_for_backward(cos, sin)
ctx.interleaved = interleaved
ctx.inplace = inplace
return out if not inplace else x

@staticmethod
def backward(ctx, do):
cos, sin = ctx.saved_tensors
_, seqlen, _, headdim = do.shape
rotary_dim = cos.shape[-1]
rotary_dim *= 2
inplace = ctx.inplace
do_ro = do[..., :rotary_dim]
do1, do2 = (do_ro.chunk(2, dim=-1) if not ctx.interleaved
else (do_ro[..., ::2], do_ro[..., 1::2]))
dx = torch.empty_like(do) if not inplace else do
if inplace:
dx1, dx2 = do1, do2
else:
dx_ro = dx[..., :rotary_dim]
dx1, dx2 = (dx_ro.chunk(2, dim=-1) if not ctx.interleaved
else (dx_ro[..., ::2], dx_ro[..., 1::2]))
        dx1_rot, dx2_rot = apply_rotary_torch(do1, do2, rearrange(cos[:seqlen], 's d -> s 1 d'),
                                              rearrange(sin[:seqlen], 's d -> s 1 d'), True)
        dx1.copy_(dx1_rot)  # write into dx (or into do when inplace)
        dx2.copy_(dx2_rot)
if not inplace and rotary_dim < headdim:
dx[..., rotary_dim:].copy_(do[..., rotary_dim:])
return dx, None, None, None, None


class TorchApplyRotaryEmbQKV_(torch.autograd.Function):

@staticmethod
def forward(ctx, qkv, cos, sin, cos_k=None, sin_k=None, interleaved=False):
"""
qkv: (batch_size, seqlen, 3, nheads, headdim)
cos, sin: (seqlen, rotary_dim / 2)
cos_k, sin_k: (seqlen, rotary_dim / 2), optional
interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of
1st half and 2nd half (GPT-NeoX style).
rotary_dim must be <= headdim
Apply rotary embedding *inplace* to the first rotary_dim of q and k.
"""
batch, seqlen, three, nheads, headdim = qkv.shape
assert three == 3
rotary_seqlen, rotary_dim = cos.shape
rotary_dim *= 2
assert rotary_dim <= headdim
assert seqlen <= rotary_seqlen
cos_k = cos if cos_k is None else cos_k
sin_k = sin if sin_k is None else sin_k
assert sin.shape == cos_k.shape == sin_k.shape == (rotary_seqlen, rotary_dim // 2)
q_ro = qkv[:, :, 0, :, :rotary_dim]
q1, q2 = q_ro.chunk(2, dim=-1) if not interleaved else (q_ro[..., ::2], q_ro[..., 1::2])
        q1_rot, q2_rot = apply_rotary_torch(q1, q2, rearrange(cos[:seqlen], 's d -> s 1 d'),
                                            rearrange(sin[:seqlen], 's d -> s 1 d'), False)
        q1.copy_(q1_rot)  # rotation is applied in place on qkv, so copy the results back into its views
        q2.copy_(q2_rot)
k_ro = qkv[:, :, 1, :, :rotary_dim]
k1, k2 = k_ro.chunk(2, dim=-1) if not interleaved else (k_ro[..., ::2], k_ro[..., 1::2])
        k1_rot, k2_rot = apply_rotary_torch(k1, k2, rearrange(cos_k[:seqlen], 's d -> s 1 d'),
                                            rearrange(sin_k[:seqlen], 's d -> s 1 d'), False)
        k1.copy_(k1_rot)
        k2.copy_(k2_rot)
ctx.save_for_backward(cos, sin, cos_k, sin_k)
ctx.interleaved = interleaved
return qkv

@staticmethod
def backward(ctx, dqkv):
cos, sin, cos_k, sin_k = ctx.saved_tensors
_, seqlen, _, _, headdim = dqkv.shape
rotary_dim = cos.shape[-1]
rotary_dim *= 2
dq_ro = dqkv[:, :, 0, :, :rotary_dim]
dq1, dq2 = (dq_ro.chunk(2, dim=-1) if not ctx.interleaved
else (dq_ro[..., ::2], dq_ro[..., 1::2]))
        dq1_rot, dq2_rot = apply_rotary_torch(dq1, dq2, rearrange(cos[:seqlen], 's d -> s 1 d'),
                                              rearrange(sin[:seqlen], 's d -> s 1 d'), True)
        dq1.copy_(dq1_rot)  # copy the rotated gradients back into dqkv
        dq2.copy_(dq2_rot)
dk_ro = dqkv[:, :, 1, :, :rotary_dim]
dk1, dk2 = (dk_ro.chunk(2, dim=-1) if not ctx.interleaved
else (dk_ro[..., ::2], dk_ro[..., 1::2]))
        dk1_rot, dk2_rot = apply_rotary_torch(dk1, dk2, rearrange(cos_k[:seqlen], 's d -> s 1 d'),
                                              rearrange(sin_k[:seqlen], 's d -> s 1 d'), True)
        dk1.copy_(dk1_rot)
        dk2.copy_(dk2_rot)
return dqkv, None, None, None, None, None


apply_rotary_emb_qkv_ = ApplyRotaryEmbQKV_.apply
legacy_apply_rotary_embed_qkv = LegacyApplyRotaryEmbQKV_.apply
legacy_apply_rotary_embed = LegacyApplyRotaryEmb.apply


class RotaryEmbedding(torch.nn.Module):
@@ -202,12 +338,25 @@ def _forward(self, qkv: torch.Tensor, indexes=0) -> Tuple[torch.Tensor, torch.Te
self._sin_k_cached[indexes],
)

def _get_legacy_apply_rotary_functions(self):
if gpc.config.model.use_flash_attn:
import rotary_emb
from flash_attn.layers.rotary import ApplyRotaryEmb as LegacyApplyRotaryEmb
from flash_attn.layers.rotary import ApplyRotaryEmbQKV_ as LegacyApplyRotaryEmbQKV_
legacy_apply_rotary_embed_qkv = LegacyApplyRotaryEmbQKV_.apply
legacy_apply_rotary_embed = LegacyApplyRotaryEmb.apply
else:
legacy_apply_rotary_embed_qkv = TorchApplyRotaryEmbQKV_.apply
legacy_apply_rotary_embed = TorchApplyRotaryEmb.apply
return legacy_apply_rotary_embed_qkv, legacy_apply_rotary_embed

def _eval_forward(self, qkv, seqlen_offset=0):
"""
seqlen_offset: can be used in generation where the qkv being passed in is only the last
token in the batch.
"""
self._update_cos_sin_cache(qkv, seqlen_offset + qkv.shape[1])
legacy_apply_rotary_embed_qkv, _ = self._get_legacy_apply_rotary_functions()
if self.scale is None:
return legacy_apply_rotary_embed_qkv(
qkv, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:]
@@ -225,12 +374,14 @@ def _single_forward(self, x, indexes=0):
assert self.scale is None
self._update_cos_sin_cache(x, indexes)
x = x[None, ...]
_, legacy_apply_rotary_embed = self._get_legacy_apply_rotary_functions()
ret = legacy_apply_rotary_embed(x, self._cos_cached[indexes], self._sin_cached[indexes]).squeeze(0)
return ret

def _single_eval_forward(self, x, seqlen_offset=0):
assert self.scale is None
self._update_cos_sin_cache(x, seqlen_offset + x.shape[1])
_, legacy_apply_rotary_embed = self._get_legacy_apply_rotary_functions()
return legacy_apply_rotary_embed(x, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:])


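Note on the new pure-PyTorch rotary path: apply_rotary_torch implements the standard GPT-NeoX-style rotation, out1 = x1*cos - x2*sin and out2 = x1*sin + x2*cos, and conj=True applies the inverse rotation used by the backward passes. The following is a minimal self-contained sketch, not taken from the repository, that re-declares the helper locally and checks that the conj=True call undoes the conj=False call; shapes and values are illustrative.

import torch


def apply_rotary_torch(x1, x2, cos, sin, conj):
    # Same formula as the helper added in this commit: plain tensor ops, no flash_attn kernel.
    if conj:
        return x1 * cos + x2 * sin, -x1 * sin + x2 * cos
    return x1 * cos - x2 * sin, x1 * sin + x2 * cos


seqlen, nheads, half_dim = 8, 2, 4
x1 = torch.randn(seqlen, nheads, half_dim)
x2 = torch.randn(seqlen, nheads, half_dim)
theta = torch.rand(seqlen, 1, half_dim)  # broadcast shape matching rearrange(cos, "s d -> s 1 d")
cos, sin = torch.cos(theta), torch.sin(theta)

o1, o2 = apply_rotary_torch(x1, x2, cos, sin, conj=False)  # forward rotation
r1, r2 = apply_rotary_torch(o1, o2, cos, sin, conj=True)   # inverse rotation (backward)
assert torch.allclose(r1, x1, atol=1e-6) and torch.allclose(r2, x2, atol=1e-6)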
67 changes: 62 additions & 5 deletions internlm/model/linear.py
@@ -4,13 +4,12 @@
from typing import Optional

import torch
from flash_attn.ops.fused_dense import ColumnParallelLinear, RowParallelLinear
from flash_attn.utils.distributed import all_reduce, reduce_scatter
from torch import nn
from torch.distributed import ProcessGroup

from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.model.utils import Silu, fused_dense_func_torch
from internlm.model.utils import Silu, fused_dense_func_torch, all_reduce, reduce_scatter


class ScaleColumnParallelLinear(nn.Linear):
@@ -114,7 +113,33 @@ def forward(self, input): # pylint: disable=W0622
)


class ColumnParallelLinearTorch(ColumnParallelLinear):
class ColumnParallelLinearTorch(nn.Linear):
def __init__(
self,
in_features: int,
out_features: int,
process_group: ProcessGroup,
bias: bool = True,
sequence_parallel=True,
multiple_of=1,
device=None,
dtype=None,
) -> None:
world_size = torch.distributed.get_world_size(process_group)
if out_features % multiple_of:
raise ValueError(f"out_features ({out_features}) must be a multiple of {multiple_of}")
multiple = out_features // multiple_of
# We want to split @multiple across world_size, but it could be an uneven split
div = multiple // world_size
mod = multiple % world_size
# The first @mod ranks get @div + 1 copies, the rest get @div copies
local_multiple = div + int(torch.distributed.get_rank(process_group) < mod)
super().__init__(
in_features, local_multiple * multiple_of, bias=bias, device=device, dtype=dtype
)
self.process_group = process_group
self.sequence_parallel = sequence_parallel

def forward(self, x):
# If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:
# we do an all_gather of x before doing the matmul.
@@ -125,7 +150,39 @@ def forward(self, x):
)


class RowParallelLinearTorch(RowParallelLinear):
class RowParallelLinearTorch(nn.Linear):
def __init__(
self,
in_features: int,
out_features: int,
process_group: ProcessGroup,
bias: bool = True,
sequence_parallel=True,
multiple_of=1,
device=None,
dtype=None,
) -> None:
world_size = torch.distributed.get_world_size(process_group)
rank = torch.distributed.get_rank(process_group)
if in_features % multiple_of:
raise ValueError(f"in_features ({in_features}) must be a multiple of {multiple_of}")
multiple = in_features // multiple_of
# We want to split @multiple across world_size, but it could be an uneven split
div = multiple // world_size
mod = multiple % world_size
# The first @mod ranks get @div + 1 copies, the rest get @div copies
        local_multiple = div + int(rank < mod)
# Only rank 0 will have bias
super().__init__(
local_multiple * multiple_of,
out_features,
bias=bias and rank == 0,
device=device,
dtype=dtype,
)
self.process_group = process_group
self.sequence_parallel = sequence_parallel

def forward(self, x):
"""
We're doing Tensor Parallel with sequence parallelism: we do the matmul and then
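The new ColumnParallelLinearTorch and RowParallelLinearTorch constructors reproduce the uneven sharding rule of the flash_attn classes they replace: the feature dimension is split into multiple_of-sized blocks and the first mod ranks receive one extra block. A small sketch of that arithmetic, with world_size and rank standing in for the torch.distributed calls so it runs without a process group:

def local_features(features: int, multiple_of: int, world_size: int, rank: int) -> int:
    # Mirrors the split used in ColumnParallelLinearTorch / RowParallelLinearTorch.
    assert features % multiple_of == 0
    multiple = features // multiple_of
    div, mod = divmod(multiple, world_size)
    local_multiple = div + int(rank < mod)  # the first `mod` ranks get one extra block
    return local_multiple * multiple_of


# 4096 features in 128-sized blocks over 3 ranks: 32 blocks -> 11, 11, 10 blocks per rank.
sizes = [local_features(4096, 128, world_size=3, rank=r) for r in range(3)]
assert sizes == [1408, 1408, 1280] and sum(sizes) == 4096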
4 changes: 2 additions & 2 deletions internlm/model/loss.py
@@ -1,7 +1,6 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

from flash_attn.losses.cross_entropy import CrossEntropyLoss as FlashCrossEntropyLoss
from torch import nn

from internlm.core.context import ParallelMode
@@ -24,7 +23,8 @@ def __init__(self, parallel_output=True, label_smoothing=0):
label_smoothing = 0
self.label_smoothing = label_smoothing

if parallel_output:
if gpc.config.model.use_flash_attn and parallel_output:
from flash_attn.losses.cross_entropy import CrossEntropyLoss as FlashCrossEntropyLoss
self.loss_fn = FlashCrossEntropyLoss(
reduction="mean",
inplace_backward=True,
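With this guard, the flash_attn cross-entropy is only imported and constructed when both use_flash_attn and parallel_output are enabled; otherwise the loss is built from plain PyTorch. A hedged sketch of the resulting selection logic, assuming the fallback branch uses torch.nn.CrossEntropyLoss with the same label smoothing (the exact keyword arguments in the repository may differ):

from torch import nn


def build_loss_fn(use_flash_attn: bool, parallel_output: bool, label_smoothing: float = 0.0, process_group=None):
    if use_flash_attn and parallel_output:
        # Imported lazily so flash_attn is only required when this branch is taken.
        from flash_attn.losses.cross_entropy import CrossEntropyLoss as FlashCrossEntropyLoss

        return FlashCrossEntropyLoss(
            reduction="mean",
            inplace_backward=True,
            process_group=process_group,
            label_smoothing=label_smoothing,
        )
    # Assumed fallback: a standard PyTorch loss, no flash_attn dependency.
    return nn.CrossEntropyLoss(reduction="mean", label_smoothing=label_smoothing)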
13 changes: 9 additions & 4 deletions internlm/model/metrics.py
@@ -1,8 +1,9 @@
from typing import List

import torch
from flash_attn.losses.cross_entropy import CrossEntropyLoss as FlashCrossEntropyLoss

from torch_scatter import scatter
from torch import nn

from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
@@ -208,9 +209,13 @@ def __init__(self, device, dp_pg, dataset_types: List[str] = None) -> None:
self.ds_loss = torch.zeros(self.total_type_count, dtype=torch.float, device=device)
self.ds_token_num = torch.zeros(self.total_type_count, dtype=torch.float, device=device)

self.loss_fn = FlashCrossEntropyLoss(
reduction="none", inplace_backward=True, process_group=gpc.get_group(ParallelMode.TENSOR)
)
if gpc.config.model.use_flash_attn:
from flash_attn.losses.cross_entropy import CrossEntropyLoss as FlashCrossEntropyLoss
self.loss_fn = FlashCrossEntropyLoss(
reduction="none", inplace_backward=True, process_group=gpc.get_group(ParallelMode.TENSOR)
)
else:
self.loss_fn = nn.CrossEntropyLoss(reduction="none")

def update(self, logits, labels, type_ids=None):
with torch.no_grad():
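The metrics module uses the same guard, so setting the flag to False removes every flash_attn import touched by this commit (rotary kernels, fused parallel linears, and both cross-entropy losses). A hypothetical config fragment showing how the flag reaches these modules through gpc.config.model; only use_flash_attn is taken from the diff, the other fields are illustrative:

model = dict(
    num_layers=32,
    hidden_size=4096,
    num_attention_heads=32,
    use_flash_attn=False,  # selects the pure-PyTorch code paths added in this commit
)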
