Multi-GPU Context Parallel Mamba2 #664

Open · wants to merge 3 commits into base: main

83 changes: 83 additions & 0 deletions mamba_ssm/modules/context_parallel.py
@@ -0,0 +1,83 @@
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F

def send_and_receive_(x, receive_buffer, send_to_rank, receive_from_rank, group):
    """Exchange boundary tokens with neighbouring ranks via batched point-to-point ops."""
    assert send_to_rank is not None or receive_from_rank is not None
    ops = []
    if send_to_rank is not None:
        ops.append(dist.P2POp(dist.isend, x, send_to_rank, group))
    if receive_from_rank is not None:
        ops.append(dist.P2POp(dist.irecv, receive_buffer, receive_from_rank, group))

    reqs = dist.batch_isend_irecv(ops)
    for req in reqs:
        req.wait()
    dist.barrier()

class ContextParallelMixerFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, padding=0, process_group=torch.distributed.group.WORLD):
        # Prepends the last `padding` tokens of the shard on rank n to the shard on rank n+1.
        # These are mixed into the subsequent tokens of rank n+1 by the convolution and their
        # positions are then discarded; the convolution is causal, so the mixing only goes in one direction.
        rank, world_size = dist.get_rank(process_group), dist.get_world_size(process_group)
        if world_size == 1:
            return x

        send_to_rank = rank + 1 if rank < world_size - 1 else None
        receive_from_rank = rank - 1 if rank > 0 else None
        pre_tokens = x[:, -padding:].contiguous()
        assert pre_tokens.shape[1] == padding
        receive_buffer = torch.zeros_like(pre_tokens, requires_grad=True).contiguous()  # TODO not used by rank 0
        send_and_receive_(pre_tokens, receive_buffer, send_to_rank, receive_from_rank, process_group)
        if rank > 0:
            x = F.pad(x, (0, 0, padding, 0), 'constant', 0)
            x[:, :padding] = receive_buffer
        ctx.padding = padding
        ctx.process_group = process_group
        return x

    @staticmethod
    def backward(ctx, grad_x):
        """
        grad_x is the gradient of the padded input, including the padding tokens received from
        the previous rank. The input to forward was not padded, so this slice of the gradient
        has to be split off and transferred back to the previous rank.
        """
        process_group = ctx.process_group
        rank, world_size = dist.get_rank(process_group), dist.get_world_size(process_group)
        padding = ctx.padding
        if world_size == 1:
            return grad_x, None, None
        send_to_rank = rank - 1 if rank > 0 else None
        receive_from_rank = rank + 1 if rank < world_size - 1 else None
        pre_tokens_grad = grad_x[:, :padding].contiguous()
        if rank > 0:
            grad_x_out = grad_x[:, padding:].contiguous()
        else:
            grad_x_out = grad_x.clone()
        assert pre_tokens_grad.shape[1] == padding
        receive_buffer = torch.zeros_like(pre_tokens_grad).contiguous()  # TODO not used by the last rank
        send_and_receive_(pre_tokens_grad, receive_buffer, send_to_rank, receive_from_rank, process_group)
        if rank < world_size - 1:
            grad_x_out[:, -padding:] += receive_buffer
        return grad_x_out, None, None

class ContextParallelMixerLayer(nn.Module):
    def __init__(self, padding=0, process_group=torch.distributed.group.WORLD):
        super(ContextParallelMixerLayer, self).__init__()
        self.padding = padding
        self.process_group = process_group

    def forward(self, x):
        return ContextParallelMixerFn.apply(x, self.padding, self.process_group)
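
A minimal sketch (not part of this PR) of how ContextParallelMixerLayer is expected to behave across ranks. The script name, shapes, and d_conv value below are illustrative, and it assumes a torchrun launch with one GPU per rank and the NCCL backend, e.g. torchrun --nproc_per_node=2 cp_mixer_demo.py:

import torch
import torch.distributed as dist

from mamba_ssm.modules.context_parallel import ContextParallelMixerLayer

def main():
    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    torch.cuda.set_device(rank)

    batch, shard_len, d_model, d_conv = 2, 512, 256, 4
    # Each rank holds one contiguous shard of the full sequence.
    x_shard = torch.randn(batch, shard_len, d_model, device="cuda", requires_grad=True)

    # padding = d_conv - 1 tokens are borrowed from the previous rank, so the causal conv
    # sees the same left context it would see without sharding the sequence.
    mixer = ContextParallelMixerLayer(padding=d_conv - 1)
    y = mixer(x_shard)

    # Rank 0 keeps its original length; every other rank is prepended with d_conv - 1 tokens.
    assert y.shape[1] == shard_len + (0 if rank == 0 else d_conv - 1)

    # Gradients for the borrowed tokens flow back to the rank that sent them.
    y.sum().backward()
    assert x_shard.grad.shape == x_shard.shape

    dist.destroy_process_group()

if __name__ == "__main__":
    main()
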
58 changes: 47 additions & 11 deletions mamba_ssm/modules/mamba2.py
@@ -29,11 +29,24 @@
from mamba_ssm.distributed.distributed_utils import all_reduce, reduce_scatter

from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined
from mamba_ssm.ops.triton.ssd_combined_cp import mamba_split_conv1d_scan_combined as mamba_split_conv1d_scan_combined_cp
from mamba_ssm.ops.triton.ssd_combined import mamba_split_conv1d_scan_combined

from mamba_ssm.modules.context_parallel import ContextParallelMixerLayer


from huggingface_hub import PyTorchModelHubMixin


import torch.distributed as dist

# Context parallel - split the input sequence across ranks.
# We will want to shard the sequence outside of the Mamba2 class so it can be reused across
# multiple layers (a sketch of this sharding step appears after the diff below); roughly:
# auto_shard_seq = not force_ring_reduce_off and self.auto_shard_seq and is_distributed()
# mask = None
# (u, _), batch_sizes, num_sharded_batches = sharded_batch_to_sharded_seq(u, mask, self.ring_seq_size)
# End context parallel

class Mamba2(nn.Module, PyTorchModelHubMixin):
def __init__(
self,
@@ -61,6 +74,7 @@ def __init__(
layer_idx=None, # Absorb kwarg for general module
process_group=None,
sequence_parallel=True,


Review comment: We are entering the boolean trap here with this API design...
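
Illustrative only, not something this PR implements: one way out of the boolean trap is a single mode argument, so call sites read Mamba2(..., parallel_mode="context") rather than a stack of positional booleans. A minimal sketch:

from typing import Literal

ParallelMode = Literal["none", "sequence", "context"]

def resolve_parallel_mode(mode: ParallelMode) -> tuple[bool, bool]:
    # Map one self-describing flag onto the (sequence_parallel, context_parallel) booleans
    # the existing code paths expect, and reject invalid values early.
    if mode not in ("none", "sequence", "context"):
        raise ValueError(f"unknown parallel mode: {mode!r}")
    return mode == "sequence", mode == "context"

# e.g. sequence_parallel, context_parallel = resolve_parallel_mode("context")
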

context_parallel=False,
device=None,
dtype=None,
):
@@ -73,13 +87,14 @@
self.expand = expand
self.process_group = process_group
self.sequence_parallel = sequence_parallel
self.world_size = 1 if process_group is None else process_group.size()
self.local_rank = 0 if process_group is None else process_group.rank()
#FIXME this is probably for sequence parallel. For context parallel we just replace it later - we don't want to divide inner dimensions by world size since we don't shard those - but we probably could...
self.world_size = 1 #if process_group is None else process_group.size()
self.local_rank = 0 #if process_group is None else process_group.rank()
self.d_inner = (self.expand * self.d_model) // self.world_size
assert self.d_inner * self.world_size == self.expand * self.d_model
self.headdim = headdim
self.d_ssm = self.d_inner if d_ssm is None else d_ssm // self.world_size
assert ngroups % self.world_size == 0
assert ngroups % self.world_size == 0
self.ngroups = ngroups // self.world_size
assert self.d_ssm % self.headdim == 0
self.nheads = self.d_ssm // self.headdim
@@ -91,12 +106,26 @@
self.chunk_size = chunk_size
self.use_mem_eff_path = use_mem_eff_path
self.layer_idx = layer_idx
self.context_parallel = context_parallel

assert not (self.context_parallel and self.sequence_parallel)
if self.context_parallel or self.sequence_parallel and not self.process_group:
#TODO clean up how the process group is passed along with world_size/local_rank so there is one source of truth
assert torch.distributed.is_initialized()
self.process_group = torch.distributed.group.WORLD
self.world_size = torch.distributed.get_world_size()
self.local_rank = torch.distributed.get_rank()

if self.context_parallel:
self.cpmixer = ContextParallelMixerLayer(padding=d_conv - 1, process_group=self.process_group)
else:
self.cpmixer = None

# Order: [z, x, B, C, dt]
d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads
if self.process_group is None:
if not self.sequence_parallel:
self.in_proj = nn.Linear(self.d_model, d_in_proj, bias=bias, **factory_kwargs)
else:
else: #FIXME not sure why this checks the sequence_parallel flag - why would you use ColumnParallelLinear without sequence parallel?
self.in_proj = ColumnParallelLinear(self.d_model, d_in_proj * self.world_size, bias=bias,
process_group=self.process_group, sequence_parallel=self.sequence_parallel,
**factory_kwargs)
@@ -144,7 +173,7 @@ def __init__(
self.norm = RMSNormGated(self.d_ssm, eps=1e-5, norm_before_gate=self.norm_before_gate,
group_size=self.d_ssm // ngroups, **factory_kwargs)

if self.process_group is None:
if not self.sequence_parallel: #FIXME not sure why this checks the sequence_parallel flag - why would you use RowParallelLinear without sequence parallel?
self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)
else:
self.out_proj = RowParallelLinear(self.d_inner * self.world_size, self.d_model, bias=bias,
@@ -175,14 +204,20 @@ def forward(self, u, seqlen=None, seq_idx=None, cu_seqlens=None, inference_param
out, _, _ = self.step(u, conv_state, ssm_state)
return out

zxbcdt = self.in_proj(u) # (B, L, d_in_proj) or (B * L, d_in_proj)
if self.cpmixer: #Context parallel - send the boundary tokens that the causal conv needs to mix in to the next GPU
u = self.cpmixer(u)
zxbcdt = self.in_proj(u)
if seqlen_og is not None:
zxbcdt = rearrange(zxbcdt, "(b l) d -> b l d", l=seqlen)
# If the model is loaded in fp16, without the .float() here, A might be -inf
A = -torch.exp(self.A_log.float()) # (nheads) or (d_inner, d_state)
dt_limit_kwargs = {} if self.dt_limit == (0.0, float("inf")) else dict(dt_limit=self.dt_limit)

if self.use_mem_eff_path and inference_params is None:
out = mamba_split_conv1d_scan_combined(
fn = mamba_split_conv1d_scan_combined_cp if self.context_parallel else mamba_split_conv1d_scan_combined
out = fn(
zxbcdt,
rearrange(self.conv1d.weight, "d 1 w -> d w"),
self.conv1d.bias,
@@ -199,13 +234,14 @@
headdim=None if self.D_has_hdim else self.headdim,
ngroups=self.ngroups,
norm_before_gate=self.norm_before_gate,
process_group=self.process_group,
**dt_limit_kwargs,
)
if seqlen_og is not None:
out = rearrange(out, "b l d -> (b l) d")
if self.process_group is not None:
reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
out = reduce_fn(out, self.process_group)
#if self.process_group is not None: #FIXME this was here before adding context parallel
# reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
# out = reduce_fn(out, self.process_group)
else:
d_mlp = (zxbcdt.shape[-1] - 2 * self.d_ssm - 2 * self.ngroups * self.d_state - self.nheads) // 2
z0, x0, z, xBC, dt = torch.split(
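
As a follow-up to the "shard this outside of the Mamba2 class" comment above, a hypothetical sketch (not part of this PR) of sharding a sequence once and reusing the shard across a stack of context-parallel layers. shard_sequence, gather_sequence, and run_context_parallel are made-up helper names; it assumes torch.distributed is initialized, the sequence length is divisible by the world size, and every rank's output shard has the same length:

import torch
import torch.distributed as dist

from mamba_ssm.modules.mamba2 import Mamba2

def shard_sequence(u: torch.Tensor) -> torch.Tensor:
    # Split (batch, seqlen, d_model) along the sequence dimension and keep this rank's shard.
    rank, world_size = dist.get_rank(), dist.get_world_size()
    assert u.shape[1] % world_size == 0, "seqlen must be divisible by world_size"
    return u.chunk(world_size, dim=1)[rank].contiguous()

def gather_sequence(y_shard: torch.Tensor) -> torch.Tensor:
    # Inverse of shard_sequence: all-gather the shards and re-concatenate along seqlen.
    # Note: dist.all_gather does not carry gradients; this is for inference-style checks.
    shards = [torch.empty_like(y_shard) for _ in range(dist.get_world_size())]
    dist.all_gather(shards, y_shard.contiguous())
    return torch.cat(shards, dim=1)

def run_context_parallel(u: torch.Tensor, layers: torch.nn.ModuleList) -> torch.Tensor:
    # Shard once, outside the Mamba2 class, so the same shard flows through every layer.
    x = shard_sequence(u)
    for layer in layers:
        x = layer(x)
    return gather_sequence(x)

# e.g. layers = torch.nn.ModuleList(
#     [Mamba2(d_model=256, context_parallel=True).cuda() for _ in range(4)]
# )
# y = run_context_parallel(u.cuda(), layers)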