Skip to content

Commit

Permalink
update bwd prefetch memory calc
Browse files Browse the repository at this point in the history
  • Loading branch information
cli99 committed May 22, 2024
1 parent 2b92db3 commit fd6ad18
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion llm_analysis/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -2067,7 +2067,7 @@ def training(
estimated_fwd_prefetch_memory_per_gpu = unsharded_weight_memory_embedding + unsharded_weight_memory_per_layer

estimated_bwd_prefetch_memory_per_gpu = (
3 + int(fwd_prefetch) +
int(fwd_prefetch) +
int(bwd_prefetch)) * (unsharded_weight_memory_per_layer)

estimated_prefetch_memory_per_gpu = max(
Expand Down

0 comments on commit fd6ad18

Please sign in to comment.