Mamba2 doesn't support Multi-GPU training #676

Open

NadavSc opened this issue Jan 18, 2025 · 1 comment

Comments

NadavSc commented Jan 18, 2025

Hi! I'm using SFTTrainer (which inherits from the Transformers Trainer) to fine-tune Mamba2.
When using cuda_kernels_forward in Mamba2 on multiple GPUs, the following error appears (full traceback at the end):

```
config.pre_hook({**self.nargs, **kwargs, **config.all_kwargs()})
TypeError: 'NoneType' object is not a mapping
```

However, it works just fine when I use the slower path, torch_forward.
Do you know how to address this issue?
Thanks a lot.
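
For now, one way to avoid the crash is to force the slow path explicitly. Below is a minimal sketch, assuming transformers' modeling_mamba2 gates the fast path on a module-level `is_fast_path_available` flag before dispatching between cuda_kernels_forward and torch_forward; the flag name is an assumption and may differ between transformers versions.

```python
# Minimal workaround sketch (assumption: transformers' modeling_mamba2
# chooses between cuda_kernels_forward and torch_forward based on a
# module-level `is_fast_path_available` flag).
from transformers import AutoModelForCausalLM
from transformers.models.mamba2 import modeling_mamba2

# Pretend the mamba_ssm/causal_conv1d kernels are unavailable so that
# Mamba2Mixer.forward falls back to the pure-PyTorch torch_forward path
# and the Triton autotuner is never invoked.
modeling_mamba2.is_fast_path_available = False

model = AutoModelForCausalLM.from_pretrained('AntonV/mamba2-130m-hf')
```

This trades training speed for stability, so it is only a stopgap until the Triton path works on multiple GPUs.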

Reproduction

```python
from datasets import load_dataset
from trl import SFTTrainer, SFTConfig
from transformers import AutoTokenizer, AutoModelForCausalLM, DataCollatorForLanguageModeling

model_id = 'AntonV/mamba2-130m-hf'
dataset_name = 'yelp_review_full'

tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = 'right'
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
model = AutoModelForCausalLM.from_pretrained(model_id)

dataset = load_dataset(dataset_name, split='train', streaming=True)
train_dataset = dataset

training_args = SFTConfig(
    output_dir='./outputs',
    num_train_epochs=5,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    logging_dir='./logs',
    learning_rate=2e-3,
    save_steps=500,
    save_safetensors=False,
    max_steps=10000,
    report_to='none'
)
trainer = SFTTrainer(
    model=model,
    processing_class=tokenizer,
    data_collator=data_collator,
    args=training_args,
    train_dataset=train_dataset,
)
trainer.train()
```

Traceback

  File "/mnt/lbosm1/home/nadavsc/projects/LLMamba/train.py", line 82, in <module>
    main()
  File "/mnt/lbosm1/home/nadavsc/projects/LLMamba/train.py", line 79, in main
    trainer.train()
  File "/home/nadavsc/LIGHTBITS/envs/ssm/lib/python3.9/site-packages/transformers/trainer.py", line 2123, in train
    return inner_training_loop(
  File "/home/nadavsc/LIGHTBITS/envs/ssm/lib/python3.9/site-packages/transformers/trainer.py", line 2481, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
  File "/home/nadavsc/LIGHTBITS/envs/ssm/lib/python3.9/site-packages/transformers/trainer.py", line 3612, in training_step
    self.accelerator.backward(loss, **kwargs)
  File "/home/nadavsc/LIGHTBITS/envs/ssm/lib/python3.9/site-packages/accelerate/accelerator.py", line 2248, in backward
    loss.backward(**kwargs)
  File "/home/nadavsc/LIGHTBITS/envs/ssm/lib/python3.9/site-packages/torch/_tensor.py", line 521, in backward
    torch.autograd.backward(
  File "/home/nadavsc/LIGHTBITS/envs/ssm/lib/python3.9/site-packages/torch/autograd/__init__.py", line 289, in backward
    _engine_run_backward(
  File "/home/nadavsc/LIGHTBITS/envs/ssm/lib/python3.9/site-packages/torch/autograd/graph.py", line 768, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/home/nadavsc/LIGHTBITS/envs/ssm/lib/python3.9/site-packages/torch/autograd/function.py", line 306, in apply
    return user_fn(self, *args)
  File "/home/nadavsc/LIGHTBITS/envs/ssm/lib/python3.9/site-packages/torch/amp/autocast_mode.py", line 501, in decorate_bwd
    return bwd(*args, **kwargs)
  File "/mnt/lbosm1/home/nadavsc/projects/LLMamba/mamba_ssm/ops/triton/ssd_combined.py", line 893, in backward
    dx, ddt, dA, dB, dC, dD, _, ddt_bias, dinitial_states = _mamba_chunk_scan_combined_bwd(
  File "/mnt/lbosm1/home/nadavsc/projects/LLMamba/mamba_ssm/ops/triton/ssd_combined.py", line 414, in _mamba_chunk_scan_combined_bwd
    dx, ddt, dD_from_x = _chunk_scan_chunk_state_bwd_dx(x, dt, dA_cumsum, B, CB, dout, dstates, D=D, seq_idx=seq_idx, dx=dx)
  File "/mnt/lbosm1/home/nadavsc/projects/LLMamba/mamba_ssm/ops/triton/ssd_combined.py", line 250, in _chunk_scan_chunk_state_bwd_dx
    _chunk_scan_chunk_state_bwd_dx_kernel[grid_dx](
  File "/home/nadavsc/LIGHTBITS/envs/ssm/lib/python3.9/site-packages/triton/runtime/jit.py", line 345, in <lambda>
    return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
  File "/home/nadavsc/LIGHTBITS/envs/ssm/lib/python3.9/site-packages/triton/runtime/autotuner.py", line 170, in run
    config.pre_hook({**self.nargs, **kwargs, **config.all_kwargs()})
TypeError: 'NoneType' object is not a mapping

vasqu (Contributor) commented Jan 20, 2025

Maybe related to #84
