diff --git a/llm_analysis/analysis.py b/llm_analysis/analysis.py
index 1299b28..5354d3d 100644
--- a/llm_analysis/analysis.py
+++ b/llm_analysis/analysis.py
@@ -369,17 +369,21 @@ def get_memory_weight_per_layer(
         Returns:
             float: the memory (in bytes) required to store the weights of a transformer layer
         """
-        memory_weight_per_layer = (
-            (
-                self.get_num_params_per_layer_attn()
-                + self.get_num_params_per_layer_mlp() / self.parallelism_config.ep_size + self.get_num_params_per_layer_router() + self.get_num_params_per_layer_layernorm()
-            )
-            * self.dtype_config.weight_bits
-            / BITS_PER_BYTE
-            / self.parallelism_config.tp_size
-        )
         if ds_zero == DSZeRO.STAGE_3:
-            memory_weight_per_layer /= self.parallelism_config.dp_size
+            sharded_dp_size = self.parallelism_config.dp_size
+        else:
+            sharded_dp_size = 1
+
+        memory_weight_attn_per_layer = self.get_num_params_per_layer_attn() * self.dtype_config.weight_bits / BITS_PER_BYTE / self.parallelism_config.tp_size / sharded_dp_size
+
+        memory_weight_mlp_per_layer = (self.get_num_params_per_layer_mlp() / self.parallelism_config.ep_size + self.get_num_params_per_layer_router()) * self.dtype_config.weight_bits / BITS_PER_BYTE / self.parallelism_config.tp_size
+
+        memory_weight_layernorm_per_layer = self.get_num_params_per_layer_layernorm() * self.dtype_config.weight_bits / BITS_PER_BYTE / self.parallelism_config.tp_size
+
+        memory_weight_per_layer = memory_weight_attn_per_layer + memory_weight_mlp_per_layer + memory_weight_layernorm_per_layer
+
+        logger.debug(f'memory_weight_attn_per_layer: {_num_to_string(memory_weight_attn_per_layer)}B, memory_weight_mlp_per_layer: {_num_to_string(memory_weight_mlp_per_layer)}B, memory_weight_layernorm_per_layer: {_num_to_string(memory_weight_layernorm_per_layer)}B')
+
         return memory_weight_per_layer
 
     def get_memory_optimizer_state_per_layer(
@@ -534,6 +538,9 @@ def get_memory_activation_per_layer_mlp(
         seq_len: int,
         is_inference: bool = True,
         activation_recomputation: ActivationRecomputation = ActivationRecomputation.NONE,
+        activation_quant_bits: int = None,
+        recompute_gelu: bool = False,
+        with_dropout: bool = False,
     ) -> float:
         """Get the memory (in bytes) required to store the activations of the
         MLP in a transformer layer, given the batch size, sequence length, and
@@ -556,9 +563,12 @@
         sp_size = self.parallelism_config.sp_size
         ep_size = self.parallelism_config.ep_size
         hidden_dim = self.model_config.hidden_dim
+
         bytes_per_activation = (
             self.dtype_config.activation_bits / BITS_PER_BYTE
         )
+        if activation_quant_bits:
+            bytes_per_activation = activation_quant_bits / BITS_PER_BYTE
 
         if is_inference:
             return (5 * seq_len * batch_size * hidden_dim / sp_size) * bytes_per_activation
@@ -568,18 +578,20 @@
         # dropout mask only requires a single byte per element
         drop_out_mask = (
             seq_len * batch_size * hidden_dim / sp_size
-        )
+        ) if with_dropout else 0
+        logger.debug(f'recompute_gelu = {recompute_gelu}')
 
         if self.model_config.moe_num_experts == 1:
             memory_activation_per_layer_mlp = (
                 (1 * seq_len * batch_size * hidden_dim / sp_size)
-                + (2 * seq_len * batch_size * hidden_dim * self.model_config.expansion_ratio / tp_size)
+                + ((2 if not recompute_gelu else 1) * seq_len * batch_size * hidden_dim * self.model_config.expansion_ratio / tp_size)
             ) * bytes_per_activation + drop_out_mask
         else:
             memory_activation_per_layer_mlp = self.model_config.moe_top_k *(
                 (1 * seq_len * batch_size * hidden_dim / sp_size)
-                + (2 * seq_len * batch_size * hidden_dim * self.model_config.expansion_ratio * self.model_config.moe_num_experts/ ep_size / tp_size)
+                + ((2 if not recompute_gelu else 1) * seq_len * batch_size * hidden_dim * self.model_config.expansion_ratio * self.model_config.moe_num_experts / ep_size / tp_size)
             ) * bytes_per_activation + drop_out_mask
+
         return memory_activation_per_layer_mlp
 
     def get_memory_activation_per_layer_layernorm(
@@ -620,6 +632,8 @@ def get_memory_activation_per_layer(
         is_inference: bool = True,
         activation_recomputation: ActivationRecomputation = ActivationRecomputation.NONE,
         layernorm_dtype_bytes: int = BYTES_FP32,
+        mlp_activation_quant_bits: int = None,
+        mlp_recompute_gelu: bool = False,
     ) -> float:
         """Get the memory (in bytes) required to store the activations of a
         transformer layer, given the batch size, sequence length, and whether
@@ -656,7 +670,7 @@
 
         memory_activation_per_layer_mlp = (
             self.get_memory_activation_per_layer_mlp(
-                batch_size, seq_len, is_inference, activation_recomputation
+                batch_size, seq_len, is_inference, activation_recomputation, activation_quant_bits=mlp_activation_quant_bits, recompute_gelu=mlp_recompute_gelu,
             )
         )
 
@@ -669,17 +683,15 @@
             )
         )
 
-        logger.debug(
-            "memory_activation_per_layer_attn:"
-            f" {_num_to_string(memory_activation_per_layer_attn)}B,"
-            " memory_activation_per_layer_mlp:"
-            f" {_num_to_string(memory_activation_per_layer_mlp)}B,"
-            " memory_activation_per_layer_layernorm:"
-            f" {_num_to_string(memory_activation_per_layer_layernorm)}B"
-        )
-
         if is_inference:
             memory_activation_per_layer = max(memory_activation_per_layer_attn, memory_activation_per_layer_mlp, memory_activation_per_layer_layernorm)
+            logger.debug(
+                f"memory_activation_per_layer for batch_size {batch_size}:"
+                f" {_num_to_string(memory_activation_per_layer)}B"
+                f" (max(attn, mlp, layernorm): max({_num_to_string(memory_activation_per_layer_attn)}B,"
+                f" {_num_to_string(memory_activation_per_layer_mlp)}B,"
+                f" {_num_to_string(memory_activation_per_layer_layernorm)}B))"
+            )
         else:
             memory_activation_per_layer = (
                 memory_activation_per_layer_attn
@@ -687,9 +699,9 @@
                 + 2 * memory_activation_per_layer_layernorm
             )
             logger.debug(
-                "memory_activation_per_layer:"
+                f"memory_activation_per_layer for batch_size {batch_size}:"
                 f" {_num_to_string(memory_activation_per_layer)}B"
-                f" ({_num_to_string(memory_activation_per_layer_attn)}B +"
+                f" (attn + mlp + layernorm: {_num_to_string(memory_activation_per_layer_attn)}B +"
                 f" {_num_to_string(memory_activation_per_layer_mlp)}B + 2 *"
                 f" {_num_to_string(2*memory_activation_per_layer_layernorm)}B)"
             )
@@ -1793,6 +1805,8 @@ def training(
         activation_recomputation: ActivationRecomputation = ActivationRecomputation.NONE,
         ds_zero: DSZeRO = DSZeRO.NONE,
         layernorm_dtype_bytes: int = BYTES_FP32,
+        mlp_activation_quant_bits: int = None,
+        mlp_recompute_gelu: bool = False,
         output_dir: str = None,
         output_file_suffix: str = "",
     ) -> dict:
@@ -1893,6 +1907,8 @@
                 is_inference=False,
                 activation_recomputation=activation_recomputation,
                 layernorm_dtype_bytes=layernorm_dtype_bytes,
+                mlp_activation_quant_bits=mlp_activation_quant_bits,
+                mlp_recompute_gelu=mlp_recompute_gelu,
             )
             * self.model_config.num_layers
         )
@@ -1920,6 +1936,8 @@
                 is_inference=False,
                 activation_recomputation=activation_recomputation,
                 layernorm_dtype_bytes=layernorm_dtype_bytes,
+                mlp_activation_quant_bits=mlp_activation_quant_bits,
+                mlp_recompute_gelu=mlp_recompute_gelu,
             )
             * self.model_config.num_layers
         )
@@ -2215,6 +2233,8 @@ def train(
     ep_size: int = 1,
     total_num_gpus: int = None,
     layernorm_dtype_bytes: int = BYTES_FP32,
+    mlp_activation_quant_bits: int = None,
+    mlp_recompute_gelu: bool = False,
     achieved_tflops: float = None,
     flops_efficiency: float = None,
    hbm_memory_efficiency: float = HBM_MEMORY_EFFICIENCY,
@@ -2309,6 +2329,8 @@ def train(
         ),
         ds_zero=DSZeRO(ds_zero),
         layernorm_dtype_bytes=layernorm_dtype_bytes,
+        mlp_activation_quant_bits=mlp_activation_quant_bits,
+        mlp_recompute_gelu=mlp_recompute_gelu,
         output_dir=output_dir,
         output_file_suffix=output_file_suffix,
     )
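
The patch threads two new MLP activation-memory knobs, mlp_activation_quant_bits and mlp_recompute_gelu, from the train() entry point through training() down to get_memory_activation_per_layer_mlp(). The sketch below is not part of the patch; it only shows how the new keyword arguments would be passed. Names not visible in the diff (model_name, gpu_name, dtype_name, tp_size), the chosen config strings, the efficiency values, and the assumption that train() returns the summary dict produced by training() are illustrative assumptions.

# Minimal usage sketch (illustrative, not part of the patch).
from llm_analysis.analysis import train

summary = train(
    model_name="facebook/opt-1.3b",   # assumed model config name
    gpu_name="a100-sxm-40gb",         # assumed GPU config name
    dtype_name="w16a16e16",           # assumed dtype config name
    tp_size=2,
    ep_size=1,
    ds_zero=3,                        # assumed to map to DSZeRO.STAGE_3 inside train()
    flops_efficiency=0.5,             # illustrative value
    hbm_memory_efficiency=0.9,        # illustrative value
    mlp_activation_quant_bits=8,      # new: count MLP activations at 8 bits instead of activation_bits
    mlp_recompute_gelu=True,          # new: keep one (not two) expanded MLP activation per layer
)
# training() is annotated "-> dict"; the call above assumes train() returns that summary dict.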