diff --git a/llm_analysis/analysis.py b/llm_analysis/analysis.py index 0b19ea3..5c79ac6 100644 --- a/llm_analysis/analysis.py +++ b/llm_analysis/analysis.py @@ -2067,7 +2067,7 @@ def training( estimated_fwd_prefetch_memory_per_gpu = unsharded_weight_memory_embedding + unsharded_weight_memory_per_layer estimated_bwd_prefetch_memory_per_gpu = ( - 3 + int(fwd_prefetch) + + int(fwd_prefetch) + int(bwd_prefetch)) * (unsharded_weight_memory_per_layer) estimated_prefetch_memory_per_gpu = max(