Multi-GPU Context Parallel Mamba2 #664

Open · wants to merge 3 commits into main

Conversation

@josiahbjorgaard commented Jan 10, 2025

I've made an implementation of Context Parallelism for Mamba2 here. It uses a sequential step at the state-transfer stage, but otherwise runs in parallel. I've validated that the results agree to within floating-point error between single-GPU and multi-GPU contexts, for both the forward and backward pass calculations.

It uses a hack of the causal_conv1d function: a number of tokens equivalent to the convolution window is transferred between GPUs, and the results for the few prepended tokens on each GPU are then discarded. This requires a new ContextMixer layer to be inserted before each Mamba2 layer, which is done automatically by a modification to the Mamba2 class. The actual GPU-to-GPU transfer is done in a loop in the ssd_combined function.
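
Roughly, the halo exchange looks like the following simplified sketch (names such as `conv_halo_exchange` and `cp_group` are illustrative rather than the identifiers used in the PR, and it assumes the CP group spans consecutive ranks holding consecutive sequence chunks):

```python
# Simplified sketch of the conv-window "halo exchange"; not the actual PR code.
import torch
import torch.distributed as dist

def conv_halo_exchange(x, d_conv, cp_group=None):
    """x: (batch, local_seqlen, dim) chunk of the full sequence on this rank."""
    rank = dist.get_rank(cp_group)
    world = dist.get_world_size(cp_group)
    halo = d_conv - 1

    send = x[:, -halo:, :].contiguous()   # last tokens, needed by the next rank
    recv = torch.empty_like(send)

    ops = []
    if rank + 1 < world:
        ops.append(dist.P2POp(dist.isend, send, rank + 1, group=cp_group))
    if rank > 0:
        ops.append(dist.P2POp(dist.irecv, recv, rank - 1, group=cp_group))
    if ops:
        for req in dist.batch_isend_irecv(ops):
            req.wait()

    if rank > 0:
        # Prepend the previous rank's tokens; after the causal conv, discard
        # the first `halo` outputs so each rank keeps exactly its own tokens.
        x = torch.cat([recv, x], dim=1)
    return x, (halo if rank > 0 else 0)   # number of leading outputs to drop
```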

Please let me know how I can further improve the PR to make it a mergeable contribution. Also feel free to reach out if you'd like help setting up a multi-GPU context parallel run.

N.B. this PR does not include splitting of the initial input sequence or aggregating gradients after loss, both of which would need to be performed by the training loop code.
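
For reference, the training loop would be responsible for something along these lines (illustrative only; the helper names are made up and not part of this PR):

```python
# Illustrative only: what the training loop could do around the model.
import torch
import torch.distributed as dist

def shard_sequence(input_ids, cp_group=None):
    """Give each context-parallel rank its contiguous chunk of the sequence."""
    rank = dist.get_rank(cp_group)
    world = dist.get_world_size(cp_group)
    return torch.chunk(input_ids, world, dim=1)[rank]

def allreduce_gradients(model, cp_group=None):
    """Average gradients across context-parallel ranks after backward()."""
    world = dist.get_world_size(cp_group)
    for p in model.parameters():
        if p.grad is not None:
            dist.all_reduce(p.grad, op=dist.ReduceOp.SUM, group=cp_group)
            p.grad.div_(world)
```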

```
@@ -61,6 +74,7 @@ def __init__(
    layer_idx=None, # Absorb kwarg for general module
    process_group=None,
    sequence_parallel=True,
```

We are entering the boolean trap here with this API design...

@ZYHowell

It looks like you are using process_group for context parallelism, while it was designed for tensor and sequence parallelism. As long as TP and CP are compatible, I believe you can simply introduce a new cp_process_group and remove most of the lines you commented out. Do you have a plan for developing this part?

@josiahbjorgaard (Author) commented Feb 1, 2025

@ZYHowell I think TP and CP should be compatible and I had thought about separating the process groups in this way. Are you interested in implementing that?

@Skylion007 any suggestions on how best to approach the logic for selecting the context-parallel path? I also think the configuration options are not ideal. Perhaps supplying cp_process_group could serve as the flag that selects the CP path? Unfortunately, this currently requires a large amount of code repeated between the standard and context-parallel versions of the ssd code.
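
Something like the following is the shape I have in mind (purely illustrative; the kwarg name and dispatch are assumptions rather than what the PR currently does):

```python
# Illustrative only: use the presence of a dedicated cp_process_group as the
# switch between the standard and context-parallel paths, avoiding an extra
# boolean kwarg. Not the PR's actual Mamba2 class.
import torch.nn as nn

class Mamba2Sketch(nn.Module):
    def __init__(self, d_model, process_group=None, cp_process_group=None):
        super().__init__()
        self.d_model = d_model
        self.process_group = process_group        # tensor/sequence parallel group
        self.cp_process_group = cp_process_group  # context parallel group

    @property
    def context_parallel(self):
        # Presence of a dedicated CP group is the flag; no boolean needed.
        return self.cp_process_group is not None

    def forward(self, x):
        if self.context_parallel:
            return self._forward_context_parallel(x)
        return self._forward_standard(x)

    def _forward_standard(self, x):
        return x  # placeholder for the existing ssd path

    def _forward_context_parallel(self, x):
        return x  # placeholder for the context-parallel ssd path
```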
