
Commit

try to solve fsdp bug
zzhhjjj committed May 2, 2024
1 parent 280cb6c commit 673237b
Showing 3 changed files with 7 additions and 4 deletions.
6 changes: 3 additions & 3 deletions examples/config_tiny_llama.yaml
@@ -73,7 +73,7 @@ optimizer:
   learning_rate_scheduler:
     learning_rate: 0.0003
     lr_decay_starting_step: null
-    lr_decay_steps: 8
+    lr_decay_steps: 13
     lr_decay_style: cosine
     lr_warmup_steps: 2
     lr_warmup_style: linear
@@ -104,6 +104,6 @@ tokens:
   limit_test_batches: 0
   limit_val_batches: 0
   micro_batch_size: 2
-  sequence_length: 32
-  train_steps: 10
+  sequence_length: 256
+  train_steps: 15
   val_check_interval: -1
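
The new values keep lr_warmup_steps + lr_decay_steps equal to train_steps (2 + 13 = 15, as 2 + 8 = 10 before). Below is a minimal sketch of the resulting schedule, assuming a standard linear-warmup plus cosine-decay formula with a zero LR floor; nanotron's actual scheduler may differ in details such as the minimum LR and step indexing.

    import math

    def lr_at_step(step, peak_lr=3e-4, warmup_steps=2, decay_steps=13, min_lr=0.0):
        # Assumed schedule, for illustration only: linear warmup to peak_lr,
        # then cosine decay toward min_lr over decay_steps.
        if step < warmup_steps:
            return peak_lr * (step + 1) / warmup_steps
        progress = min((step - warmup_steps) / decay_steps, 1.0)
        return min_lr + 0.5 * (peak_lr - min_lr) * (1.0 + math.cos(math.pi * progress))

    # With train_steps = 15, warmup (2) plus decay (13) covers the full run.
    print([round(lr_at_step(s), 6) for s in range(15)])
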
3 changes: 3 additions & 0 deletions src/nanotron/models/llama.py
@@ -52,6 +52,9 @@

 if DISABLE_FLASH_ATTENTION:
     print("Warning: Flash attention was disabled!")
+    # FSDP
+    torch.backends.cuda.enable_mem_efficient_sdp(False)
+    torch.backends.cuda.enable_flash_sdp(False)


 RMSNorm = RMSNorm if DISABLE_FLASH_ATTENTION else TritonRMSNorm
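
Disabling the memory-efficient and flash SDP backends forces torch.nn.functional.scaled_dot_product_attention onto the math kernel, which is the workaround this commit tries for the FSDP issue. Below is a standalone sketch of the same toggle, assuming PyTorch 2.x with CUDA available; the tensor shapes are illustrative, not taken from the model.

    import torch
    import torch.nn.functional as F

    # Turn off the fused SDP backends; the math backend stays enabled by default.
    torch.backends.cuda.enable_mem_efficient_sdp(False)
    torch.backends.cuda.enable_flash_sdp(False)

    print(torch.backends.cuda.flash_sdp_enabled())          # False
    print(torch.backends.cuda.mem_efficient_sdp_enabled())  # False
    print(torch.backends.cuda.math_sdp_enabled())           # True

    # Illustrative shapes: (batch, heads, seq_len, head_dim).
    q = k = v = torch.randn(1, 8, 256, 64, device="cuda", dtype=torch.bfloat16)
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)  # math kernel
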
2 changes: 1 addition & 1 deletion tests/test_llama.py
@@ -1,5 +1,5 @@
 # Script to test correctness of training script by comparing loss value after 100th iteration with expected loss value
-# pytest -sv tests/test_train_llama.py or python tests/test_train_llama.py
+# pytest -sv tests/test_llama.py or python tests/test_train_llama.py

 import atexit
 import os
