From 820e199d99dcfac4cb2847610f1cad53c32b1885 Mon Sep 17 00:00:00 2001 From: Cheng Li Date: Wed, 22 Nov 2023 10:31:25 -0800 Subject: [PATCH] add allgather activation in moe --- llm_analysis/analysis.py | 27 ++++++++++++++++----------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/llm_analysis/analysis.py b/llm_analysis/analysis.py index c912cf2..6d08b52 100644 --- a/llm_analysis/analysis.py +++ b/llm_analysis/analysis.py @@ -702,8 +702,13 @@ def get_activation_memory_per_layer_mlp( # MoE MLP # The router stores inputs batch size * seq len * feature dim # The softmax stores inputs batch size * seq len * feature dim - # W1 stores on average TopK * batch size * seq len * feature dim activation_memory_per_layer_mlp = 2 * bytes_per_activation * seq_len * batch_size * hidden_dim / sp_size + + # The WeightedSum of the all2all+WeightedSum stores ftk + batch size * seq len * expert count elements + activation_memory_per_layer_mlp += bytes_per_1linear_input * seq_len * batch_size * hidden_dim * self.model_config.moe_top_k + activation_memory_per_layer_mlp += bytes_per_1linear_input * batch_size * seq_len * self.model_config.moe_top_k + + # W1 stores on average TopK * batch size * seq len * feature dim activation_memory_per_layer_mlp += bytes_per_1linear_input * seq_len * batch_size * hidden_dim * self.model_config.moe_top_k / sp_size else: # dense MLP @@ -810,7 +815,7 @@ def get_activation_memory_per_layer( f"activation_memory_per_layer for micro batch size {batch_size} with activation_recomputation {activation_recomputation}: {_num_to_string(activation_memory_per_layer)}B" ) if return_breakdown: - return activation_memory_per_layer, 0, 0, 0 + return activation_memory_per_layer, 0, 0, activation_memory_per_layer else: return activation_memory_per_layer @@ -1658,8 +1663,8 @@ def inference( ) if use_kv_cache: - if (batch_size_per_gpu * - (seq_len + num_tokens_to_generate) < self.get_pivot()): + if (batch_size_per_gpu * (seq_len + num_tokens_to_generate) + < self.get_pivot()): logger.warning( "kv_cache is only useful when batch_size *" " (seq+num_tokens_to_generate)" @@ -1872,16 +1877,16 @@ def config_batch_size_and_gradient_accumulation_steps( gradient_accumulation_steps = global_batch_size // ( batch_size_per_gpu * dp_size) assert (global_batch_size % (batch_size_per_gpu * dp_size) == 0 - and gradient_accumulation_steps > 0 - ), "no valid gradient_accumulation_steps, {assert_msg}" + and gradient_accumulation_steps + > 0), "no valid gradient_accumulation_steps, {assert_msg}" elif global_batch_size and gradient_accumulation_steps: # batch_size_per_gpu is None, the other two are not None batch_size_per_gpu = global_batch_size // ( gradient_accumulation_steps * dp_size) assert (global_batch_size % (gradient_accumulation_steps * dp_size) == 0 - and batch_size_per_gpu > 0 - ), "no valid batch_size_per_gpu, {assert_msg}" + and batch_size_per_gpu + > 0), "no valid batch_size_per_gpu, {assert_msg}" elif batch_size_per_gpu and gradient_accumulation_steps or batch_size_per_gpu: # batch_size_per_gpu is not None if batch_size_per_gpu > max_batch_size_per_gpu: @@ -1916,9 +1921,9 @@ def config_batch_size_and_gradient_accumulation_steps( else: # (global_batch_size and batch_size_per_gpu are None) or (all are None) batch_size_per_gpu = max_batch_size_per_gpu - gradient_accumulation_steps = (1 if - gradient_accumulation_steps is None - else gradient_accumulation_steps) + gradient_accumulation_steps = (1 if gradient_accumulation_steps + is None else + gradient_accumulation_steps) global_batch_size = (batch_size_per_gpu * gradient_accumulation_steps * self.parallelism_config.dp_size)