
Commit

add allgather activation in moe
cli99 committed Nov 22, 2023
1 parent 1f73c81 commit 820e199
Showing 1 changed file with 16 additions and 11 deletions.
27 changes: 16 additions & 11 deletions llm_analysis/analysis.py
@@ -702,8 +702,13 @@ def get_activation_memory_per_layer_mlp(
# MoE MLP
# The router stores inputs batch size * seq len * feature dim
# The softmax stores inputs batch size * seq len * feature dim
# W1 stores on average TopK * batch size * seq len * feature dim
activation_memory_per_layer_mlp = 2 * bytes_per_activation * seq_len * batch_size * hidden_dim / sp_size

# The WeightedSum of the all2all+WeightedSum stores TopK * batch size * seq len * feature dim inputs plus batch size * seq len * TopK routing weights
activation_memory_per_layer_mlp += bytes_per_1linear_input * seq_len * batch_size * hidden_dim * self.model_config.moe_top_k
activation_memory_per_layer_mlp += bytes_per_1linear_input * batch_size * seq_len * self.model_config.moe_top_k

# W1 stores on average TopK * batch size * seq len * feature dim
activation_memory_per_layer_mlp += bytes_per_1linear_input * seq_len * batch_size * hidden_dim * self.model_config.moe_top_k / sp_size
else:
# dense MLP
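
For orientation, the accounting added in this hunk can be read as one standalone formula. The sketch below is illustrative only: the function name and the default byte sizes are assumptions, while bytes_per_activation, bytes_per_1linear_input, moe_top_k, and sp_size mirror the names in the diff.

def moe_mlp_activation_bytes(batch_size, seq_len, hidden_dim, moe_top_k,
                             bytes_per_activation=2, bytes_per_1linear_input=2,
                             sp_size=1):
    # Router and softmax inputs: 2 * batch_size * seq_len * hidden_dim activations.
    mem = 2 * bytes_per_activation * seq_len * batch_size * hidden_dim / sp_size
    # WeightedSum after the all2all: TopK * batch_size * seq_len * hidden_dim inputs
    # plus batch_size * seq_len * TopK routing weights.
    mem += bytes_per_1linear_input * seq_len * batch_size * hidden_dim * moe_top_k
    mem += bytes_per_1linear_input * batch_size * seq_len * moe_top_k
    # W1 inputs: on average TopK * batch_size * seq_len * hidden_dim.
    mem += bytes_per_1linear_input * seq_len * batch_size * hidden_dim * moe_top_k / sp_size
    return mem

# Example: batch 1, seq 4096, hidden 4096, top-2, sp_size 1 -> ~192 MiB per MoE layer.
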
@@ -810,7 +815,7 @@ def get_activation_memory_per_layer(
f"activation_memory_per_layer for micro batch size {batch_size} with activation_recomputation {activation_recomputation}: {_num_to_string(activation_memory_per_layer)}B"
)
if return_breakdown:
return activation_memory_per_layer, 0, 0, 0
return activation_memory_per_layer, 0, 0, activation_memory_per_layer
else:
return activation_memory_per_layer

@@ -1658,8 +1663,8 @@ def inference(
)

if use_kv_cache:
if (batch_size_per_gpu *
(seq_len + num_tokens_to_generate) < self.get_pivot()):
if (batch_size_per_gpu * (seq_len + num_tokens_to_generate)
< self.get_pivot()):
logger.warning(
"kv_cache is only useful when batch_size *"
" (seq+num_tokens_to_generate)"
@@ -1872,16 +1877,16 @@ def config_batch_size_and_gradient_accumulation_steps(
gradient_accumulation_steps = global_batch_size // (
batch_size_per_gpu * dp_size)
assert (global_batch_size % (batch_size_per_gpu * dp_size) == 0
and gradient_accumulation_steps > 0
), "no valid gradient_accumulation_steps, {assert_msg}"
and gradient_accumulation_steps
> 0), "no valid gradient_accumulation_steps, {assert_msg}"
elif global_batch_size and gradient_accumulation_steps:
# batch_size_per_gpu is None, the other two are not None
batch_size_per_gpu = global_batch_size // (
gradient_accumulation_steps * dp_size)
assert (global_batch_size %
(gradient_accumulation_steps * dp_size) == 0
and batch_size_per_gpu > 0
), "no valid batch_size_per_gpu, {assert_msg}"
and batch_size_per_gpu
> 0), "no valid batch_size_per_gpu, {assert_msg}"
elif batch_size_per_gpu and gradient_accumulation_steps or batch_size_per_gpu:
# batch_size_per_gpu is not None
if batch_size_per_gpu > max_batch_size_per_gpu:
@@ -1916,9 +1921,9 @@ def config_batch_size_and_gradient_accumulation_steps(
else:
# (global_batch_size and batch_size_per_gpu are None) or (all are None)
batch_size_per_gpu = max_batch_size_per_gpu
gradient_accumulation_steps = (1 if
gradient_accumulation_steps is None
else gradient_accumulation_steps)
gradient_accumulation_steps = (1 if gradient_accumulation_steps
is None else
gradient_accumulation_steps)
global_batch_size = (batch_size_per_gpu *
gradient_accumulation_steps *
self.parallelism_config.dp_size)
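
Both code paths in this function enforce the same invariant, global_batch_size = batch_size_per_gpu * gradient_accumulation_steps * dp_size, and solve for whichever value is missing. Below is a minimal standalone sketch of that logic; the function name and error messages are illustrative and not the library's API.

def solve_batch_config(global_batch_size=None, batch_size_per_gpu=None,
                       gradient_accumulation_steps=None, dp_size=1,
                       max_batch_size_per_gpu=1):
    # Invariant: global_batch_size == batch_size_per_gpu * gradient_accumulation_steps * dp_size
    if global_batch_size and batch_size_per_gpu:
        gradient_accumulation_steps = global_batch_size // (batch_size_per_gpu * dp_size)
        assert (global_batch_size % (batch_size_per_gpu * dp_size) == 0
                and gradient_accumulation_steps > 0), "no valid gradient_accumulation_steps"
    elif global_batch_size and gradient_accumulation_steps:
        batch_size_per_gpu = global_batch_size // (gradient_accumulation_steps * dp_size)
        assert (global_batch_size % (gradient_accumulation_steps * dp_size) == 0
                and batch_size_per_gpu > 0), "no valid batch_size_per_gpu"
    else:
        # As in the final branch of the diff: default to the largest feasible micro
        # batch and a single accumulation step, then derive the global batch size.
        batch_size_per_gpu = batch_size_per_gpu or max_batch_size_per_gpu
        gradient_accumulation_steps = gradient_accumulation_steps or 1
        global_batch_size = batch_size_per_gpu * gradient_accumulation_steps * dp_size
    return global_batch_size, batch_size_per_gpu, gradient_accumulation_steps

# Example: global_batch_size=512, batch_size_per_gpu=8, dp_size=8 -> 8 accumulation steps.
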
