diff --git a/llm_analysis/analysis.py b/llm_analysis/analysis.py index 92ac30d..f887865 100644 --- a/llm_analysis/analysis.py +++ b/llm_analysis/analysis.py @@ -2048,16 +2048,16 @@ def training( weight_memory_per_gpu + optimizer_state_memory_per_gpu + gradient_memory_per_gpu) - estimated_fwd_prefetch_memory_usage = unsharded_weight_memory_embedding + unsharded_weight_memory_per_layer + estimated_fwd_prefetch_memory_per_gpu = unsharded_weight_memory_embedding + unsharded_weight_memory_per_layer - estimated_bwd_prefetch_memory_usage = ( + estimated_bwd_prefetch_memory_per_gpu = ( 3 + int(fwd_prefetch) + int(bwd_prefetch)) * (unsharded_weight_memory_per_layer) memory_left = (self.gpu_config.mem_per_GPU_in_GB * 1024**3 - self.weight_grad_op_state_memory_per_gpu - - max(estimated_fwd_prefetch_memory_usage, - estimated_bwd_prefetch_memory_usage)) + max(estimated_fwd_prefetch_memory_per_gpu, + estimated_bwd_prefetch_memory_per_gpu)) logger.info( f"weight_memory_per_gpu: {_num_to_string(weight_memory_per_gpu)}B" @@ -2067,10 +2067,10 @@ def training( f" {_num_to_string(optimizer_state_memory_per_gpu)}B," " gradient_memory_per_gpu:" f" {_num_to_string(gradient_memory_per_gpu)}B", - " estimated_fwd_prefetch_memory_usage:" - f" {_num_to_string(estimated_fwd_prefetch_memory_usage)}B", - " estimated_bwd_prefetch_memory_usage:" - f" {_num_to_string(estimated_bwd_prefetch_memory_usage)}B", + " estimated_fwd_prefetch_memory_per_gpu:" + f" {_num_to_string(estimated_fwd_prefetch_memory_per_gpu)}B", + " estimated_bwd_prefetch_memory_per_gpu:" + f" {_num_to_string(estimated_bwd_prefetch_memory_per_gpu)}B", " memory_left:" f" {_num_to_string(memory_left)}B", ) @@ -2400,12 +2400,16 @@ def training( "(weight+op_state+act)_memory_per_gpu": optimizer_state_memory_per_gpu + weight_memory_per_gpu + activation_memory_per_gpu, + "estimated_fwd_prefetch_memory_per_gpu": + estimated_fwd_prefetch_memory_per_gpu, + "estimated_bwd_prefetch_memory_per_gpu": + estimated_bwd_prefetch_memory_per_gpu, "estimated_peak_fwd_memory_per_gpu": optimizer_state_memory_per_gpu + weight_memory_per_gpu + - activation_memory_per_gpu + estimated_fwd_prefetch_memory_usage, + activation_memory_per_gpu + estimated_fwd_prefetch_memory_per_gpu, "estimated_peak_bwd_memory_per_gpu": optimizer_state_memory_per_gpu + weight_memory_per_gpu + - activation_memory_per_gpu + estimated_bwd_prefetch_memory_usage, + activation_memory_per_gpu + estimated_bwd_prefetch_memory_per_gpu, "memory_left_per_gpu": memory_left, "latency_per_micro_batch":