From fd6ad18eb19f1b934fcdb9049ab162925f7e9831 Mon Sep 17 00:00:00 2001
From: Cheng Li
Date: Wed, 22 May 2024 09:53:47 -0700
Subject: [PATCH] update bwd prefetch memory calc

---
 llm_analysis/analysis.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llm_analysis/analysis.py b/llm_analysis/analysis.py
index 0b19ea3..5c79ac6 100644
--- a/llm_analysis/analysis.py
+++ b/llm_analysis/analysis.py
@@ -2067,7 +2067,7 @@ def training(
             estimated_fwd_prefetch_memory_per_gpu = unsharded_weight_memory_embedding + unsharded_weight_memory_per_layer
 
             estimated_bwd_prefetch_memory_per_gpu = (
-                3 + int(fwd_prefetch) +
+                int(fwd_prefetch) +
                 int(bwd_prefetch)) * (unsharded_weight_memory_per_layer)
 
             estimated_prefetch_memory_per_gpu = max(