diff --git a/llm_analysis/analysis.py b/llm_analysis/analysis.py
index 8d0c97d..7652f7b 100644
--- a/llm_analysis/analysis.py
+++ b/llm_analysis/analysis.py
@@ -36,21 +36,27 @@ from llm_analysis.utils import _latency_to_string, _num_to_string, within_range
 
+@total_ordering
 class ActivationRecomputation(Enum):
     NONE = 0
     """No activation recomputation; requires the most amount of memory."""
 
+    ATTN_COMPUTE = 1
+    """Selectively checkpoints the attention computation (QK^T matrix multiply,
+    softmax, softmax dropout, and attention over V) in the attention module of a
+    transformer layer; this part takes up a considerable amount of memory but is
+    not computationally expensive to recompute."""
+
+    ATTN = 2
+    """Selectively checkpoints the input to the attention module in a transformer
+    layer; requires an extra forward pass on attention."""
+
+    NORM_ATTN_NORM = 3
+    """Selectively checkpoints the input to the sequence of modules
+    (layernorm-attention-layernorm) in a transformer layer; requires an extra
+    forward pass on (layernorm-attention-layernorm)."""
+
+    FULL = 4
+    """Full activation recomputation stores the input to EVERY transformer layer;
+    requires the least amount of memory; requires an extra forward pass of the
+    layer."""
 
-    SELECTIVE = 1
-    """Selectively checkpoints and recomputes only parts of each transformer layer that
-    take up a considerable amount of memory but are not computationally expensive to
-    recompute, i.e. QK^T matrix multiply, softmax, softmax dropout, and attention over
-    V."""
-
-    FULL = 2
-    """Full activation recomputation stores the input to EVERY transformer layer, which
-    is sharded across the tensor parallel group, thus requiring an extra all-gather
-    (ignored for now) per layer and add communication overhead; requires the lease
-    amount of memory; requires an extra forward pass."""
+    def __lt__(self, other):
+        if self.__class__ is other.__class__:
+            return self.value < other.value
+        return NotImplemented
 
 
 @total_ordering
@@ -581,13 +587,18 @@ def get_activation_memory_per_layer_attn(
         bytes_per_activation = (self.dtype_config.activation_bits /
                                 BITS_PER_BYTE)
 
-        if (not is_inference
-            ) and activation_recomputation == ActivationRecomputation.FULL:
+        if is_inference:
+            assert activation_recomputation == ActivationRecomputation.NONE, f'Inference does not need activation recomputation, but got activation_recomputation = {activation_recomputation}'
+
+        if activation_recomputation >= ActivationRecomputation.NORM_ATTN_NORM:
+            return 0
+        elif activation_recomputation == ActivationRecomputation.ATTN:
             return (seq_len * batch_size * hidden_dim * bytes_per_activation /
                     sp_size)
 
-        attn_compute = 0
-        if activation_recomputation != activation_recomputation.SELECTIVE:
+        if activation_recomputation == ActivationRecomputation.ATTN_COMPUTE:
+            memory_attn_compute = 0
+        elif activation_recomputation == ActivationRecomputation.NONE:
             if flash_attn:
                 memory_attn_compute = (2 * seq_len * batch_size * hidden_dim +
                                        4 * n_head * seq_len * batch_size
@@ -598,12 +609,14 @@ def get_activation_memory_per_layer_attn(
                 # dropout mask only requires a single byte per element
                 memory_attn_compute += n_head * seq_len**2 * batch_size / tp_size
         else:
-            memory_attn_compute = 0
+            raise ValueError(
+                f'Invalid activation_recomputation: {activation_recomputation}'
+            )
 
         if is_inference:
             return max(
                 3 * bytes_per_activation * seq_len * batch_size * hidden_dim /
-                sp_size, memory_attn_compute)
+                tp_size, memory_attn_compute)
 
         activation_memory_per_layer_attn = (
             seq_len * batch_size * hidden_dim / sp_size +
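Note (not part of the diff): the `activation_recomputation >= ActivationRecomputation.NORM_ATTN_NORM` check above only works because of the newly added `@total_ordering` decorator, which derives `<=`, `>`, and `>=` from the defined `__lt__` and `Enum`'s identity-based `__eq__`. A minimal, self-contained sketch of the mechanism:

```python
from enum import Enum
from functools import total_ordering


@total_ordering
class ActivationRecomputation(Enum):
    NONE = 0
    ATTN_COMPUTE = 1
    ATTN = 2
    NORM_ATTN_NORM = 3
    FULL = 4

    def __lt__(self, other):
        # Only compare members of the same enum class.
        if self.__class__ is other.__class__:
            return self.value < other.value
        return NotImplemented


# total_ordering fills in <=, >, and >= from __lt__ and __eq__:
assert ActivationRecomputation.FULL >= ActivationRecomputation.NORM_ATTN_NORM
assert ActivationRecomputation.ATTN < ActivationRecomputation.NORM_ATTN_NORM
```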
@@ -652,18 +665,19 @@ def get_activation_memory_per_layer_mlp(
         Returns:
             float: the memory (in bytes) required to store the activations of the MLP in a transformer layer
         """
-        if (not is_inference
-            ) and activation_recomputation == ActivationRecomputation.FULL:
-            return 0
-
         tp_size = self.parallelism_config.tp_size
         sp_size = self.parallelism_config.sp_size
         ep_size = self.parallelism_config.ep_size
         hidden_dim = self.model_config.hidden_dim
-
         bytes_per_activation = (self.dtype_config.activation_bits /
                                 BITS_PER_BYTE)
 
+        if is_inference:
+            assert activation_recomputation == ActivationRecomputation.NONE, f'Inference does not need activation recomputation, but got activation_recomputation = {activation_recomputation}'
+
+        if activation_recomputation == ActivationRecomputation.FULL:
+            return 0
+
         bytes_per_1linear_input = bytes_per_gelu_input = bytes_per_2linear_input = bytes_per_activation
         if mlp_1linear_quant_bits:
             bytes_per_1linear_input = mlp_1linear_quant_bits / BITS_PER_BYTE
@@ -709,8 +723,6 @@ def get_activation_memory_per_layernorm(
         self,
         batch_size: int,
         seq_len: int,
-        activation_recomputation:
-        ActivationRecomputation = ActivationRecomputation.NONE,
         dtype_bytes: int = BYTES_FP32,
     ) -> float:
         """Get the memory (in bytes) required to store the activations of a
@@ -720,16 +732,12 @@ def get_activation_memory_per_layernorm(
         Args:
             batch_size (int): micro batch size
             seq_len (int): sequence length
-            activation_recomputation (ActivationRecomputation, optional): \
-                activation recomputation strategy. Defaults to ActivationRecomputation.NONE.
             dtype_bytes (int, optional): number of bytes in the data type for the \
                 layernorm activation. Defaults to BYTES_FP32. Need to be at least FP16 to maintain accuracy.
 
         Returns:
             float: the memory (in bytes) required to store the activations of a single layernorm in a transformer layer
         """
-        if activation_recomputation == ActivationRecomputation.FULL:
-            return 0
         return (seq_len * batch_size * self.model_config.hidden_dim /
                 self.parallelism_config.sp_size) * dtype_bytes
 
@@ -780,13 +788,18 @@ def get_activation_memory_per_layer(
         Returns:
             Union[float, tuple]: the memory (in bytes) required to store the activations of a transformer layer or a tuple of its breakdown
         """
-        if (not is_inference
-            ) and activation_recomputation == ActivationRecomputation.FULL:
+        if is_inference:
+            assert activation_recomputation == ActivationRecomputation.NONE, f'Inference does not need activation recomputation, but got activation_recomputation = {activation_recomputation}'
+
+        if activation_recomputation == ActivationRecomputation.FULL:
             activation_memory_per_layer = (seq_len * batch_size *
                                            self.model_config.hidden_dim *
                                            self.dtype_config.activation_bits /
                                            BITS_PER_BYTE /
                                            self.parallelism_config.tp_size)
+            logger.info(
+                f"activation_memory_per_layer for micro batch size {batch_size} with activation_recomputation {activation_recomputation}: {_num_to_string(activation_memory_per_layer)}B"
+            )
             if return_breakdown:
                 return activation_memory_per_layer, 0, 0, 0
             else:
@@ -818,7 +831,6 @@ def get_activation_memory_per_layer(
         activation_memory_per_layernorm = self.get_activation_memory_per_layernorm(
             batch_size,
             seq_len,
-            activation_recomputation,
             layernorm_dtype_bytes,
         )
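Note (not part of the diff): under the `FULL` branch above only each layer's input survives, so the per-layer footprint is `seq_len * batch_size * hidden_dim * activation_bits / BITS_PER_BYTE / tp_size`. A back-of-the-envelope check with hypothetical, GPT-3-175B-like numbers:

```python
# Hypothetical values: 2048-token sequences, micro batch 1, hidden_dim 12288,
# FP16 activations, 8-way tensor parallelism.
seq_len, batch_size, hidden_dim = 2048, 1, 12288
activation_bits, tp_size = 16, 8
BITS_PER_BYTE = 8

activation_memory_per_layer = (seq_len * batch_size * hidden_dim *
                               activation_bits / BITS_PER_BYTE / tp_size)
print(f"{activation_memory_per_layer / 1024**2:.1f} MiB")  # 6.0 MiB per layer
```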
@@ -833,11 +845,21 @@ def get_activation_memory_per_layer(
                 f" {_num_to_string(activation_memory_per_layer_mlp)}B , 2 *"
                 f" {_num_to_string(2*activation_memory_per_layernorm)}B))")
         else:
-            activation_memory_per_layer = (activation_memory_per_layer_attn +
-                                           activation_memory_per_layer_mlp +
-                                           2 * activation_memory_per_layernorm)
+            if activation_recomputation == ActivationRecomputation.NORM_ATTN_NORM:
+                activation_memory_per_layer = (
+                    activation_memory_per_layer_attn +
+                    activation_memory_per_layer_mlp +
+                    activation_memory_per_layernorm)
+                logger.info(
+                    f"activation_memory_per_layer for micro batch size {batch_size} with activation_recomputation {activation_recomputation}:"
+                    f" {_num_to_string(activation_memory_per_layer)}B"
+                    f" (attn + mlp + layernorm: {_num_to_string(activation_memory_per_layer_attn)}B +"
+                    f" {_num_to_string(activation_memory_per_layer_mlp)}B +"
+                    f" {_num_to_string(activation_memory_per_layernorm)}B)")
+            else:
+                activation_memory_per_layer = (
+                    activation_memory_per_layer_attn +
+                    activation_memory_per_layer_mlp +
+                    2 * activation_memory_per_layernorm)
             logger.info(
-                f"activation_memory_per_layer for micro batch size {batch_size}:"
+                f"activation_memory_per_layer for micro batch size {batch_size} with activation_recomputation {activation_recomputation}:"
                 f" {_num_to_string(activation_memory_per_layer)}B"
                 f" (attn + mlp + layernorm: {_num_to_string(activation_memory_per_layer_attn)}B +"
                 f" {_num_to_string(activation_memory_per_layer_mlp)}B + 2 *"
@@ -986,12 +1008,12 @@ def get_num_flops_bwd_total(self, batch_size: int, seq_len: int) -> int:
         """
         return 2 * self.get_num_flops_fwd_total(batch_size, seq_len)
 
-    def get_num_flops_total_selective_recompute_attn(self, batch_size: int,
-                                                     seq_len: int) -> int:
+    def get_num_flops_total_attn_compute(self, batch_size: int,
+                                         seq_len: int) -> int:
         """Get the number of floating point operations (flops) for recomputation when
-        using selective activation recomputation. The count is model-specific and does
-        not depend on the parallelism strategy.
-
+        selectively checkpointing the attention computation
+        (QK^T matrix multiply, softmax, softmax dropout, and attention over V).
+        The count is model-specific and does not depend on the parallelism strategy.
 
         Args:
             batch_size (int): batch size
             seq_len (int): sequence length
@@ -1125,23 +1147,19 @@ def get_latency_fwd_per_layer_mlp(
         return max(compute_latency, memory_latency) + alltoall_latency
 
-    def get_latency_fwd_per_layer_layernorm(
+    def get_latency_fwd_per_layernorm(
         self,
         batch_size: int,
         seq_len: int,
-        activation_recomputation:
-        ActivationRecomputation = ActivationRecomputation.NONE,
         dtype_bytes: int = BYTES_FP32,
     ) -> float:
         """Get the latency for the forward pass of a single layernorm in a transformer
-        layer, given the batch size, sequence length, activation recomputation strategy,
-        and data type. The latency is the memory latency as layernorm is a memory-bound
-        operation.
+        layer, given the batch size, sequence length, and data type. The latency is
+        the memory latency as layernorm is a memory-bound operation.
 
         Args:
             batch_size (int): batch size
             seq_len (int): sequence length
-            activation_recomputation (ActivationRecomputation, optional): activation recomputation strategy. Defaults to ActivationRecomputation.NONE.
             dtype_bytes (int, optional): number of bytes in the data type for the layernorm activation. Defaults to BYTES_FP32. Need to be at least FP16 to maintain accuracy.
 
         Returns:
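Note (not part of the diff): the renamed `get_latency_fwd_per_layernorm` keeps the same memory-bound model, activation bytes moved divided by HBM bandwidth. A standalone sketch with illustrative values (the helper name and the numbers are assumptions, not taken from the diff):

```python
def layernorm_latency(seq_len: int, batch_size: int, hidden_dim: int,
                      sp_size: int, dtype_bytes: int,
                      hbm_bandwidth_gb_per_s: float) -> float:
    # Activation bytes moved, sharded across the sequence-parallel group.
    activation_memory = (seq_len * batch_size * hidden_dim / sp_size) * dtype_bytes
    return activation_memory / (hbm_bandwidth_gb_per_s * 10**9)


# 2048 tokens, hidden_dim 12288, FP32 activations, ~2039 GB/s HBM (A100-80GB-like):
print(layernorm_latency(2048, 1, 12288, 1, 4, 2039))  # ~4.9e-05 s
```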
         Returns:
@@ -1155,8 +1173,8 @@ def get_latency_fwd_per_layer_layernorm(
             self.get_gpu_hbm_bandwidth() * 10**9)
         return activation_memory_latency
 
-    def get_latency_fwd_per_layer_tp_comm(self, batch_size: int, seq_len: int,
-                                          dtype_bytes: int) -> float:
+    def get_latency_fwd_per_tp_comm(self, batch_size: int, seq_len: int,
+                                    dtype_bytes: int) -> float:
         """Get the latency of a single allreduce communication across the tensor
         parallel group in the forward pass of a transformer layer, given the batch
         size, sequence length, and data type, and assuming a ring allreduce implementation.
@@ -1258,30 +1276,28 @@ def get_latency_fwd_per_layer(
         latency_fwd_per_layer_mlp = self.get_latency_fwd_per_layer_mlp(
             batch_size, seq_len, is_inference, activation_recomputation)
 
-        latency_fwd_per_layer_layernorm = (
-            self.get_latency_fwd_per_layer_layernorm(
-                batch_size,
-                seq_len,
-                activation_recomputation,
-                layernorm_dtype_bytes,
-            ))
+        latency_fwd_per_layernorm = self.get_latency_fwd_per_layernorm(
+            batch_size,
+            seq_len,
+            layernorm_dtype_bytes,
+        )
         logger.debug(
-            f"latency_fwd_per_layer_layernorm: {round(latency_fwd_per_layer_layernorm*1000, 3)} ms"
+            f"latency_fwd_per_layernorm: {round(latency_fwd_per_layernorm*1000, 3)} ms"
         )
 
-        latency_fwd_per_layer_tp_comm = self.get_latency_fwd_per_layer_tp_comm(
+        latency_fwd_per_tp_comm = self.get_latency_fwd_per_tp_comm(
             batch_size,
             seq_len,
             self.dtype_config.activation_bits / BITS_PER_BYTE,
         )
         logger.debug(
-            f"latency_fwd_per_layer_tp_comm: {round(latency_fwd_per_layer_tp_comm*1000, 3)} ms"
+            f"latency_fwd_per_tp_comm: {round(latency_fwd_per_tp_comm*1000, 3)} ms"
         )
 
         latency_fwd_per_layer_shared_dp_comm = self.get_latency_fwd_per_layer_shared_dp_comm(
         )
 
-        latency_per_layer = latency_fwd_per_layer_attn + latency_fwd_per_layer_mlp + 2 * latency_fwd_per_layer_layernorm + 2 * latency_fwd_per_layer_tp_comm
+        latency_per_layer = latency_fwd_per_layer_attn + latency_fwd_per_layer_mlp + 2 * latency_fwd_per_layernorm + 2 * latency_fwd_per_tp_comm
 
         if ds_zero > DSZeRO.STAGE_1 and latency_fwd_per_layer_shared_dp_comm > latency_per_layer:
             logger.warning(
@@ -1294,15 +1310,15 @@ def get_latency_fwd_per_layer(
             f"latency_per_layer: {round(latency_per_layer*1000, 3)} ms (max(attn + mlp + 2*layernorm + 2*tp_comm, shared_dp_comm):"
             f" max({round(latency_fwd_per_layer_attn*1000, 3)} +"
             f" {round(latency_fwd_per_layer_mlp*1000, 3)} +"
-            f" {round(2*latency_fwd_per_layer_layernorm*1000, 3)} +"
-            f" {round(2*latency_fwd_per_layer_tp_comm*1000, 3)},"
+            f" {round(2*latency_fwd_per_layernorm*1000, 3)} +"
+            f" {round(2*latency_fwd_per_tp_comm*1000, 3)},"
             f" {round(latency_fwd_per_layer_shared_dp_comm*1000, 3)}))")
 
         breakdown_per_layer = {
             "attn": latency_fwd_per_layer_attn,
             "mlp": latency_fwd_per_layer_mlp,
-            "layernorm": 2 * latency_fwd_per_layer_layernorm,
-            "tp_comm": 2 * latency_fwd_per_layer_tp_comm,
+            "layernorm": 2 * latency_fwd_per_layernorm,
+            "tp_comm": 2 * latency_fwd_per_tp_comm,
             "sharded_dp_comm": latency_fwd_per_layer_shared_dp_comm
         }
@@ -1327,7 +1343,7 @@ def get_latency_fwd_input_embedding(
         memory_latency = (self.model_config.vocab_size *
                           self.model_config.hidden_dim * dtype_bytes /
                           (self.get_gpu_hbm_bandwidth() * 10**9))
-        comm_latency = self.get_latency_fwd_per_layer_tp_comm(
+        comm_latency = self.get_latency_fwd_per_tp_comm(
             batch_size, seq_len, dtype_bytes)
         return memory_latency + comm_latency
@@ -2065,7 +2081,6 @@ def training(
         activation_memory_batch_size_1 += self.get_activation_memory_per_layernorm(
             1,
             seq_len,
-            activation_recomputation,
             layernorm_dtype_bytes,
         )
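Note (not part of the diff): `get_latency_fwd_per_tp_comm`'s docstring assumes a ring allreduce, under which each rank transfers roughly `2 * (tp_size - 1) / tp_size` times the message size over the inter-GPU link. This is the textbook ring-allreduce cost model, sketched here with illustrative names and bandwidth:

```python
def ring_allreduce_latency(message_bytes: float, tp_size: int,
                           link_bandwidth_gb_per_s: float) -> float:
    if tp_size == 1:
        return 0.0
    # Reduce-scatter plus all-gather: 2 * (n - 1) / n of the message per rank.
    traffic = 2 * (tp_size - 1) / tp_size * message_bytes
    return traffic / (link_bandwidth_gb_per_s * 10**9)


# 2048 x 12288 FP16 activations across 8-way TP over a 300 GB/s link:
print(ring_allreduce_latency(2048 * 12288 * 2, 8, 300))  # ~2.9e-04 s
```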
@@ -2119,7 +2134,6 @@ def training(
         activation_memory_per_gpu += self.get_activation_memory_per_layernorm(
             batch_size_per_gpu,
             seq_len,
-            activation_recomputation,
             layernorm_dtype_bytes,
         )
@@ -2143,10 +2157,12 @@ def training(
         if activation_recomputation == ActivationRecomputation.FULL:
             num_flops_recompute = num_flops_fwd_total
-        elif activation_recomputation == ActivationRecomputation.SELECTIVE:
-            num_flops_recompute = (
-                self.get_num_flops_total_selective_recompute_attn(
-                    batch_size_per_gpu, seq_len))
+        elif activation_recomputation == ActivationRecomputation.NORM_ATTN_NORM or activation_recomputation == ActivationRecomputation.ATTN:
+            num_flops_recompute = self.get_num_flops_fwd_per_layer_attn(
+                batch_size_per_gpu, seq_len)
+        elif activation_recomputation == ActivationRecomputation.ATTN_COMPUTE:
+            num_flops_recompute = self.get_num_flops_total_attn_compute(
+                batch_size_per_gpu, seq_len)
 
         if num_flops_recompute < 0.05 * num_flops_fwd_total:
             logger.warning(
                 f"num_flops_recompute ({num_flops_recompute}) is too large to"
@@ -2157,13 +2173,21 @@ def training(
         num_flops_total_per_micro_batch = (num_flops_fwd_total +
                                            num_flops_bwd_total +
                                            num_flops_recompute)
-
         logger.info(
             "num_flops_total_per_micro_batch:"
             f" {_num_to_string(num_flops_total_per_micro_batch, divisor=1000)} ({_num_to_string(num_flops_fwd_total, divisor=1000)} fwd"
             f" + {_num_to_string(num_flops_bwd_total, divisor=1000)} bwd +"
             f" {_num_to_string(num_flops_recompute, divisor=1000)} recompute)")
 
+        # estimated by flops only:
+        latency_per_micro_batch_using_flops = num_flops_total_per_micro_batch / (
+            (self.parallelism_config.tp_size * self.parallelism_config.pp_size)
+            * self.get_TFLOPS_per_gpu() * 1e12)
+        logger.info(
+            f'latency_per_micro_batch_using_flops = {round(latency_per_micro_batch_using_flops*1000, 3)} ms'
+        )
+        latency_per_iter_using_flops = latency_per_micro_batch_using_flops * gradient_accumulation_steps
+
         latency_fwd, latency_fwd_breakdown = self.get_latency_fwd(
             batch_size_per_gpu,
             seq_len,
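Note (not part of the diff): the flops-only estimate added above divides the per-micro-batch flops by the aggregate achievable throughput of the `tp_size * pp_size` GPUs that process one micro batch. A worked example with hypothetical numbers:

```python
# Illustrative values only: total fwd + bwd + recompute flops per micro batch,
# 8-way TP x 4-way PP, and 150 achievable (not peak) TFLOPS per GPU.
num_flops_total_per_micro_batch = 9.0e15
tp_size, pp_size = 8, 4
tflops_per_gpu = 150

latency_per_micro_batch_using_flops = num_flops_total_per_micro_batch / (
    (tp_size * pp_size) * tflops_per_gpu * 1e12)
print(f"{latency_per_micro_batch_using_flops * 1000:.1f} ms")  # 1875.0 ms
```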
@@ -2172,29 +2196,37 @@ def training(
             layernorm_dtype_bytes=layernorm_dtype_bytes,
             ds_zero=ds_zero,
         )
-        # estimated by flops only:
-        latency_per_micro_batch_using_flops = num_flops_total_per_micro_batch / (
-            (self.parallelism_config.tp_size * self.parallelism_config.pp_size)
-            * self.get_TFLOPS_per_gpu() * 1e12)
-        logger.info(
-            f'latency_per_micro_batch_using_flops = {round(latency_per_micro_batch_using_flops*1000, 3)} ms'
-        )
-
-        latency_per_micro_batch = latency_fwd * 3 + num_flops_recompute / (
-            (self.parallelism_config.tp_size * self.parallelism_config.pp_size)
-            * self.get_TFLOPS_per_gpu() * 1e12)
-
-        latency_weight_update = self.get_latency_weight_update()
+        if activation_recomputation == ActivationRecomputation.FULL:
+            latency_recompute = latency_fwd
+        elif activation_recomputation == ActivationRecomputation.NORM_ATTN_NORM:
+            latency_recompute = self.get_latency_fwd_per_layer_attn(
+                batch_size_per_gpu, seq_len, False, activation_recomputation
+            ) + 2 * self.get_latency_fwd_per_layernorm(
+                batch_size_per_gpu, seq_len, layernorm_dtype_bytes)
+        elif activation_recomputation == ActivationRecomputation.ATTN:
+            latency_recompute = self.get_latency_fwd_per_layer_attn(
+                batch_size_per_gpu, seq_len, False, activation_recomputation)
+        elif activation_recomputation == ActivationRecomputation.ATTN_COMPUTE:
+            latency_recompute = self.get_num_flops_total_attn_compute(
+                batch_size_per_gpu, seq_len) / (
                    (self.parallelism_config.tp_size *
+                     self.parallelism_config.pp_size) *
+                    self.get_TFLOPS_per_gpu() * 1e12)
+        elif activation_recomputation == ActivationRecomputation.NONE:
+            latency_recompute = 0
+        latency_per_micro_batch = latency_fwd * 3 + latency_recompute
+        latency_weight_update = self.get_latency_weight_update()
         latency_per_iter = (
             latency_per_micro_batch * gradient_accumulation_steps +
             latency_weight_update)
-        latency_per_iter_using_flops = latency_per_micro_batch_using_flops * gradient_accumulation_steps
-
         logger.info(
-            f"latency_per_micro_batch: {round(latency_per_micro_batch * 1000, 3)} ms, "
+            f"latency_per_micro_batch: {round(latency_per_micro_batch * 1000, 3)} ms ({round(latency_fwd * 1000, 3)} latency_fwd * 3 + {round(latency_recompute * 1000, 3)} latency_recompute)"
+        )
+        logger.info(
             f"latency_per_iter: {round(latency_per_iter * 1000, 3)} ms "
-            f"({round(latency_per_micro_batch * 1000, 3)} ms latency_fwd * {gradient_accumulation_steps} gradient_accumulation_steps + {round(latency_weight_update * 1000, 3)} ms weight_update)"
+            f"({round(latency_per_micro_batch * 1000, 3)} ms latency_per_micro_batch * {gradient_accumulation_steps} gradient_accumulation_steps + {round(latency_weight_update * 1000, 3)} ms weight_update)"
         )
 
         total_num_gpus = (self.parallelism_config.tp_size *
diff --git a/tests/test_training.py b/tests/test_training.py
index 49b9f89..e3955fb 100644
--- a/tests/test_training.py
+++ b/tests/test_training.py
@@ -281,7 +281,8 @@ def test_training_mt_nlg_1():
         ds_zero=DSZeRO.STAGE_3,
     )
 
-    assert within_range(summary_dict["latency_per_iter"], 60.0, TOLERANCE)
+    assert within_range(summary_dict["latency_per_iter_using_flops"], 60.0,
+                        TOLERANCE)
 
 # deepspeed megatron mt-nlg-530b paper https://arxiv.org/abs/2201.11990
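Note (not part of the diff): the updated assertion now checks `latency_per_iter_using_flops` against the 60 s per-iteration figure from the MT-NLG setup. A sketch of how such a check plausibly evaluates, assuming `within_range` (from `llm_analysis.utils`) is a relative-error test and `TOLERANCE` is a fraction; both details are assumptions here, not taken from the diff:

```python
# Hypothetical re-implementation for illustration; the real helper lives in
# llm_analysis.utils and may differ.
def within_range(value: float, target: float, tolerance: float) -> bool:
    # Relative deviation from the target, e.g. tolerance=0.1 allows +/-10%.
    return abs(value - target) / target <= tolerance


TOLERANCE = 0.1                       # assumed value
latency_per_iter_using_flops = 58.7   # seconds, illustrative
print(within_range(latency_per_iter_using_flops, 60.0, TOLERANCE))  # True
```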