diff --git a/llm_analysis/analysis.py b/llm_analysis/analysis.py
index c5b7b57..92ac30d 100644
--- a/llm_analysis/analysis.py
+++ b/llm_analysis/analysis.py
@@ -692,8 +692,6 @@ def get_activation_memory_per_layer_mlp(
         bytes_per_gelu_input = mlp_activation_quant_bits / BITS_PER_BYTE
         bytes_per_2linear_input = mlp_activation_quant_bits / BITS_PER_BYTE
 
-        num_experts_per_gpu = self.model_config.moe_num_experts / ep_size
-
         if is_inference:
             return max(
                 bytes_per_1linear_input,
@@ -1952,6 +1950,8 @@ def training(
         activation_recomputation:
         ActivationRecomputation = ActivationRecomputation.NONE,
         ds_zero: DSZeRO = DSZeRO.NONE,
+        fwd_prefetch: bool = True,
+        bwd_prefetch: bool = True,
         layernorm_dtype_bytes: int = BYTES_FP32,
         master_weights_dtype_bytes: int = BYTES_FP32,
         other_op_bytes: int = None,
@@ -2047,8 +2047,17 @@ def training(
         self.weight_grad_op_state_memory_per_gpu = (
             weight_memory_per_gpu + optimizer_state_memory_per_gpu +
             gradient_memory_per_gpu)
+
+        estimated_fwd_prefetch_memory_usage = unsharded_weight_memory_embedding + unsharded_weight_memory_per_layer
+
+        estimated_bwd_prefetch_memory_usage = (
+            3 + int(fwd_prefetch) +
+            int(bwd_prefetch)) * (unsharded_weight_memory_per_layer)
+
         memory_left = (self.gpu_config.mem_per_GPU_in_GB * 1024**3 -
-                       self.weight_grad_op_state_memory_per_gpu)
+                       self.weight_grad_op_state_memory_per_gpu -
+                       max(estimated_fwd_prefetch_memory_usage,
+                           estimated_bwd_prefetch_memory_usage))
 
         logger.info(
             f"weight_memory_per_gpu: {_num_to_string(weight_memory_per_gpu)}B"
@@ -2057,37 +2066,45 @@ def training(
             " optimizer_state_memory_per_gpu:"
             f" {_num_to_string(optimizer_state_memory_per_gpu)}B,"
             " gradient_memory_per_gpu:"
-            f" {_num_to_string(gradient_memory_per_gpu)}B, memory_left:"
-            f" {_num_to_string(memory_left)}B")
+            f" {_num_to_string(gradient_memory_per_gpu)}B,"
+            " estimated_fwd_prefetch_memory_usage:"
+            f" {_num_to_string(estimated_fwd_prefetch_memory_usage)}B,"
+            " estimated_bwd_prefetch_memory_usage:"
+            f" {_num_to_string(estimated_bwd_prefetch_memory_usage)}B,"
+            " memory_left:"
+            f" {_num_to_string(memory_left)}B"
+        )
 
         if memory_left < 0:
             logger.warning(
-                "model weight/optimizer stage/gradient is too large (requiring"
-                f" {_num_to_string(weight_memory_per_gpu)}B /"
-                f" {_num_to_string(optimizer_state_memory_per_gpu)}B /"
-                f" {_num_to_string(gradient_memory_per_gpu)}B) to fit in total GPU"
-                " memory")
+                "model weight/optimizer state/gradient or fwd/bwd prefetch memory usage is too large to fit in GPU memory"
+            )
 
         # With pipeline parallelism, each stage contains L/p layers so the first stage must store p × L/p = L layers worth of activations regardless of the pipeline parallel size p; activation memory required for the input embeddings, the last layer-norm, and the output layer are ignored here. Refer to https://arxiv.org/abs/2205.05198 for more details.
 
-        activation_memory_batch_size_1, activation_memory_attn_batch_size_1, mlp_activation_memory_batch_size_1, layernorm_activation_memory_batch_size_1 = [
+        activation_memory_per_layer_batch_size_1, attn_activation_memory_per_layer_batch_size_1, mlp_activation_memory_per_layer_batch_size_1, layernorm_activation_memory_per_layer_batch_size_1 = self.get_activation_memory_per_layer(
+            1,
+            seq_len,
+            is_inference=False,
+            activation_recomputation=activation_recomputation,
+            layernorm_dtype_bytes=layernorm_dtype_bytes,
+            flash_attn=flash_attn,
+            softmax_dropout=softmax_dropout,
+            mlp_activation_quant_bits=mlp_activation_quant_bits,
+            mlp_1linear_quant_bits=mlp_1linear_quant_bits,
+            mlp_gelu_input_quant_bits=mlp_gelu_input_quant_bits,
+            mlp_2linear_quant_bits=mlp_2linear_quant_bits,
+            mlp_recompute_gelu=mlp_recompute_gelu,
+            return_breakdown=True,
+        )
+        activation_memory_batch_size_1, attn_activation_memory_batch_size_1, mlp_activation_memory_batch_size_1, layernorm_activation_memory_batch_size_1 = [
             x * self.model_config.num_layers
-            for x in self.get_activation_memory_per_layer(
-                1,
-                seq_len,
-                is_inference=False,
-                activation_recomputation=activation_recomputation,
-                layernorm_dtype_bytes=layernorm_dtype_bytes,
-                flash_attn=flash_attn,
-                softmax_dropout=softmax_dropout,
-                mlp_activation_quant_bits=mlp_activation_quant_bits,
-                mlp_1linear_quant_bits=mlp_1linear_quant_bits,
-                mlp_gelu_input_quant_bits=mlp_gelu_input_quant_bits,
-                mlp_2linear_quant_bits=mlp_2linear_quant_bits,
-                mlp_recompute_gelu=mlp_recompute_gelu,
-                return_breakdown=True,
-            )
+            for x in (activation_memory_per_layer_batch_size_1,
+                      attn_activation_memory_per_layer_batch_size_1,
+                      mlp_activation_memory_per_layer_batch_size_1,
+                      layernorm_activation_memory_per_layer_batch_size_1)
         ]
+
         activation_memory_embedding_output_batch_size_1 = self.get_activation_memory_output_embedding(
             1, seq_len)
         logger.info(
@@ -2124,7 +2141,7 @@ def training(
         )
 
         if batch_size_per_gpu == 1:
-            activation_memory_per_gpu, activation_memory_attn_per_gpu, activation_memory_mlp_per_gpu, activation_memory_layernorm_per_gpu = activation_memory_batch_size_1, activation_memory_attn_batch_size_1, mlp_activation_memory_batch_size_1, layernorm_activation_memory_batch_size_1
+            activation_memory_per_gpu, activation_memory_attn_per_gpu, activation_memory_mlp_per_gpu, activation_memory_layernorm_per_gpu = activation_memory_batch_size_1, attn_activation_memory_batch_size_1, mlp_activation_memory_batch_size_1, layernorm_activation_memory_batch_size_1
         else:
            activation_memory_per_gpu, activation_memory_attn_per_gpu, activation_memory_mlp_per_gpu, activation_memory_layernorm_per_gpu = [
                x * self.model_config.num_layers
@@ -2383,15 +2400,12 @@ def training(
             "(weight+op_state+act)_memory_per_gpu":
                 optimizer_state_memory_per_gpu + weight_memory_per_gpu +
                 activation_memory_per_gpu,
-            "end_fwd_memory_per_gpu":
+            "estimated_peak_fwd_memory_per_gpu":
                 optimizer_state_memory_per_gpu + weight_memory_per_gpu +
-                activation_memory_per_gpu + unsharded_weight_memory_embedding +
-                unsharded_weight_memory_per_layer,
-            "end_last3_bwd_memory_per_gpu":
+                activation_memory_per_gpu + estimated_fwd_prefetch_memory_usage,
+            "estimated_peak_bwd_memory_per_gpu":
                 optimizer_state_memory_per_gpu + weight_memory_per_gpu +
-                activation_memory_per_gpu + 3 *
-                (unsharded_weight_memory_per_layer +
-                 unsharded_weight_memory_mlp_per_layer),
+                activation_memory_per_gpu + estimated_bwd_prefetch_memory_usage,
             "memory_left_per_gpu":
                 memory_left,
             "latency_per_micro_batch":
@@ -2536,6 +2550,8 @@ def train(
     total_num_tokens: int = None,
     activation_recomputation: int = 0,
     ds_zero: int = 0,
+    fwd_prefetch: bool = True,
+    bwd_prefetch: bool = True,
     dp_size: int = None,
     tp_size: int = 1,
     pp_size: int = 1,
@@ -2658,6 +2674,8 @@ def train(
         activation_recomputation=ActivationRecomputation(
             activation_recomputation),
         ds_zero=DSZeRO(ds_zero),
+        fwd_prefetch=fwd_prefetch,
+        bwd_prefetch=bwd_prefetch,
         layernorm_dtype_bytes=layernorm_dtype_bytes,
         master_weights_dtype_bytes=master_weights_dtype_bytes,
         other_op_bytes=other_op_bytes,
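
The memory-budget change is the core of this patch: before reporting `memory_left`, the analysis now reserves room for the unsharded weights that prefetching keeps resident (embedding plus one layer's weights in the forward pass; `3 + fwd_prefetch + bwd_prefetch` layers' weights in the backward pass) and subtracts the larger of the two from the per-GPU budget, which is also why the summary keys change from `end_fwd_memory_per_gpu`/`end_last3_bwd_memory_per_gpu` to `estimated_peak_fwd_memory_per_gpu`/`estimated_peak_bwd_memory_per_gpu`. Below is a minimal standalone sketch of that arithmetic for review purposes; the helper name and the example byte counts are invented for illustration and are not part of `llm_analysis`.

```python
# Standalone sketch of the prefetch-overhead estimate added in this diff.
# `estimate_prefetch_reserve` is a hypothetical helper written only to mirror
# the new training() arithmetic; the byte counts in the example are made up.


def estimate_prefetch_reserve(
    unsharded_weight_memory_embedding: float,
    unsharded_weight_memory_per_layer: float,
    fwd_prefetch: bool = True,
    bwd_prefetch: bool = True,
) -> float:
    """Bytes reserved on top of weight/gradient/optimizer-state memory."""
    # Forward: embedding weights plus one layer's unsharded weights may be live at once.
    estimated_fwd = (unsharded_weight_memory_embedding +
                     unsharded_weight_memory_per_layer)
    # Backward: 3 layers assumed live, plus one more per enabled prefetch direction.
    estimated_bwd = (3 + int(fwd_prefetch) +
                     int(bwd_prefetch)) * unsharded_weight_memory_per_layer
    # training() subtracts the worse of the two from the per-GPU memory budget.
    return max(estimated_fwd, estimated_bwd)


if __name__ == "__main__":
    GiB = 1024**3
    reserve = estimate_prefetch_reserve(1.0 * GiB, 0.75 * GiB)
    print(f"prefetch reserve: {reserve / GiB:.2f} GiB")  # max(1.75, 3.75) -> 3.75 GiB
```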
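For completeness, a sketch of how a caller might exercise the new flags through the `train()` wrapper to compare peak-memory estimates. It assumes `train()` also accepts the usual `model_name`/`gpu_name`/`batch_size_per_gpu`/`seq_len` arguments (not shown in this diff) and returns the summary dict built in `training()`; the config names are placeholders.

```python
from llm_analysis.analysis import train

# Sketch only: config names are placeholders, and all keyword arguments other
# than fwd_prefetch/bwd_prefetch are assumed from the existing train() signature.
for fwd, bwd in [(True, True), (False, False)]:
    summary = train(
        model_name="<your-model-config>",
        gpu_name="<your-gpu-config>",
        batch_size_per_gpu=2,
        seq_len=2048,
        ds_zero=3,  # prefetch headroom matters most when weights are sharded
        fwd_prefetch=fwd,
        bwd_prefetch=bwd,
    )
    print(fwd, bwd,
          summary["estimated_peak_fwd_memory_per_gpu"],
          summary["estimated_peak_bwd_memory_per_gpu"])
```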