wip
cli99 committed Mar 7, 2024
1 parent 8135689 commit 77ac4e8
Showing 1 changed file with 52 additions and 34 deletions.
86 changes: 52 additions & 34 deletions llm_analysis/analysis.py
@@ -692,8 +692,6 @@ def get_activation_memory_per_layer_mlp(
bytes_per_gelu_input = mlp_activation_quant_bits / BITS_PER_BYTE
bytes_per_2linear_input = mlp_activation_quant_bits / BITS_PER_BYTE

- num_experts_per_gpu = self.model_config.moe_num_experts / ep_size
-
if is_inference:
return max(
bytes_per_1linear_input,
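
Aside: the bytes_per_* quantities above are plain quant-bit-to-byte conversions. A minimal sketch, assuming 8-bit MLP activation quantization (the constant names mirror the source; the value is illustrative):

    # Quantization bits -> bytes per activation element.
    BITS_PER_BYTE = 8
    mlp_activation_quant_bits = 8  # e.g. int8 activations (illustrative)
    bytes_per_gelu_input = mlp_activation_quant_bits / BITS_PER_BYTE
    print(bytes_per_gelu_input)  # 1.0 byte per element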
@@ -1952,6 +1950,8 @@ def training(
activation_recomputation:
ActivationRecomputation = ActivationRecomputation.NONE,
ds_zero: DSZeRO = DSZeRO.NONE,
+ fwd_prefetch: bool = True,
+ bwd_prefetch: bool = True,
layernorm_dtype_bytes: int = BYTES_FP32,
master_weights_dtype_bytes: int = BYTES_FP32,
other_op_bytes: int = None,
@@ -2047,8 +2047,17 @@ def training(
self.weight_grad_op_state_memory_per_gpu = (
weight_memory_per_gpu + optimizer_state_memory_per_gpu +
gradient_memory_per_gpu)

+ estimated_fwd_prefetch_memory_usage = unsharded_weight_memory_embedding + unsharded_weight_memory_per_layer
+
+ estimated_bwd_prefetch_memory_usage = (
+     3 + int(fwd_prefetch) +
+     int(bwd_prefetch)) * (unsharded_weight_memory_per_layer)
+
memory_left = (self.gpu_config.mem_per_GPU_in_GB * 1024**3 -
- self.weight_grad_op_state_memory_per_gpu)
+ self.weight_grad_op_state_memory_per_gpu -
+ max(estimated_fwd_prefetch_memory_usage,
+     estimated_bwd_prefetch_memory_usage))

logger.info(
f"weight_memory_per_gpu: {_num_to_string(weight_memory_per_gpu)}B"
@@ -2057,37 +2066,45 @@ def training(
" optimizer_state_memory_per_gpu:"
f" {_num_to_string(optimizer_state_memory_per_gpu)}B,"
" gradient_memory_per_gpu:"
f" {_num_to_string(gradient_memory_per_gpu)}B, memory_left:"
f" {_num_to_string(memory_left)}B")
f" {_num_to_string(gradient_memory_per_gpu)}B",
" estimated_fwd_prefetch_memory_usage:"
f" {_num_to_string(estimated_fwd_prefetch_memory_usage)}B",
" estimated_bwd_prefetch_memory_usage:"
f" {_num_to_string(estimated_bwd_prefetch_memory_usage)}B",
" memory_left:"
f" {_num_to_string(memory_left)}B",
)

if memory_left < 0:
logger.warning(
"model weight/optimizer stage/gradient is too large (requiring"
f" {_num_to_string(weight_memory_per_gpu)}B /"
f" {_num_to_string(optimizer_state_memory_per_gpu)}B /"
f" {_num_to_string(gradient_memory_per_gpu)}B) to fit in total GPU"
" memory")
"model weight/optimizer state/gradient or fwd/bwd prefetch memory usage is too large to fit in GPU memory"
)

+ # With pipeline parallelism, each stage holds L/p layers, but the first stage must keep p microbatches in flight and therefore stores p * (L/p) = L layers' worth of activations regardless of the pipeline parallel size p; the activation memory required for the input embeddings, the last layer-norm, and the output layer is ignored here. Refer to https://arxiv.org/abs/2205.05198 for more details.

- activation_memory_batch_size_1, activation_memory_attn_batch_size_1, mlp_activation_memory_batch_size_1, layernorm_activation_memory_batch_size_1 = [
+ activation_memory_per_layer_batch_size_1, attn_activation_memory_per_layer_batch_size_1, mlp_activation_memory_per_layer_batch_size_1, layernorm_activation_memory_per_layer_batch_size_1 = self.get_activation_memory_per_layer(
+     1,
+     seq_len,
+     is_inference=False,
+     activation_recomputation=activation_recomputation,
+     layernorm_dtype_bytes=layernorm_dtype_bytes,
+     flash_attn=flash_attn,
+     softmax_dropout=softmax_dropout,
+     mlp_activation_quant_bits=mlp_activation_quant_bits,
+     mlp_1linear_quant_bits=mlp_1linear_quant_bits,
+     mlp_gelu_input_quant_bits=mlp_gelu_input_quant_bits,
+     mlp_2linear_quant_bits=mlp_2linear_quant_bits,
+     mlp_recompute_gelu=mlp_recompute_gelu,
+     return_breakdown=True,
+ )
+ activation_memory_batch_size_1, attn_activation_memory_batch_size_1, mlp_activation_memory_batch_size_1, layernorm_activation_memory_batch_size_1 = [
x * self.model_config.num_layers
- for x in self.get_activation_memory_per_layer(
-     1,
-     seq_len,
-     is_inference=False,
-     activation_recomputation=activation_recomputation,
-     layernorm_dtype_bytes=layernorm_dtype_bytes,
-     flash_attn=flash_attn,
-     softmax_dropout=softmax_dropout,
-     mlp_activation_quant_bits=mlp_activation_quant_bits,
-     mlp_1linear_quant_bits=mlp_1linear_quant_bits,
-     mlp_gelu_input_quant_bits=mlp_gelu_input_quant_bits,
-     mlp_2linear_quant_bits=mlp_2linear_quant_bits,
-     mlp_recompute_gelu=mlp_recompute_gelu,
-     return_breakdown=True,
- )
+ for x in (activation_memory_per_layer_batch_size_1,
+           attn_activation_memory_per_layer_batch_size_1,
+           mlp_activation_memory_per_layer_batch_size_1,
+           layernorm_activation_memory_per_layer_batch_size_1)
]

activation_memory_embedding_output_batch_size_1 = self.get_activation_memory_output_embedding(
1, seq_len)
logger.info(
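
Aside: the pipeline-parallelism comment above can be checked with a few lines of arithmetic. A toy example, assuming the 1F1B-style schedule of the cited paper (arXiv:2205.05198) in which the first stage keeps p microbatches in flight; all numbers are illustrative:

    num_layers = 48          # L: total transformer layers
    pp_size = 4              # p: pipeline parallel size
    act_per_layer_gib = 1.5  # activation GiB per layer per microbatch (made up)

    layers_per_stage = num_layers // pp_size  # L/p = 12 layers on stage 0
    microbatches_in_flight = pp_size          # stage 0 holds p microbatches
    # p * (L/p) = L layers' worth of activations, independent of p:
    first_stage_act_gib = microbatches_in_flight * layers_per_stage * act_per_layer_gib
    print(first_stage_act_gib)  # 72.0, the same as without pipeline parallelism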
@@ -2124,7 +2141,7 @@ def training(
)

if batch_size_per_gpu == 1:
- activation_memory_per_gpu, activation_memory_attn_per_gpu, activation_memory_mlp_per_gpu, activation_memory_layernorm_per_gpu = activation_memory_batch_size_1, activation_memory_attn_batch_size_1, mlp_activation_memory_batch_size_1, layernorm_activation_memory_batch_size_1
+ activation_memory_per_gpu, activation_memory_attn_per_gpu, activation_memory_mlp_per_gpu, activation_memory_layernorm_per_gpu = activation_memory_batch_size_1, attn_activation_memory_batch_size_1, mlp_activation_memory_batch_size_1, layernorm_activation_memory_batch_size_1
else:
activation_memory_per_gpu, activation_memory_attn_per_gpu, activation_memory_mlp_per_gpu, activation_memory_layernorm_per_gpu = [
x * self.model_config.num_layers
@@ -2383,15 +2400,12 @@ def training(
"(weight+op_state+act)_memory_per_gpu":
optimizer_state_memory_per_gpu + weight_memory_per_gpu +
activation_memory_per_gpu,
"end_fwd_memory_per_gpu":
"estimated_peak_fwd_memory_per_gpu":
optimizer_state_memory_per_gpu + weight_memory_per_gpu +
- activation_memory_per_gpu + unsharded_weight_memory_embedding +
- unsharded_weight_memory_per_layer,
- "end_last3_bwd_memory_per_gpu":
+ activation_memory_per_gpu + estimated_fwd_prefetch_memory_usage,
+ "estimated_peak_bwd_memory_per_gpu":
optimizer_state_memory_per_gpu + weight_memory_per_gpu +
- activation_memory_per_gpu + 3 *
- (unsharded_weight_memory_per_layer +
-  unsharded_weight_memory_mlp_per_layer),
+ activation_memory_per_gpu + estimated_bwd_prefetch_memory_usage,
"memory_left_per_gpu":
memory_left,
"latency_per_micro_batch":
@@ -2536,6 +2550,8 @@ def train(
total_num_tokens: int = None,
activation_recomputation: int = 0,
ds_zero: int = 0,
+ fwd_prefetch: bool = True,
+ bwd_prefetch: bool = True,
dp_size: int = None,
tp_size: int = 1,
pp_size: int = 1,
@@ -2658,6 +2674,8 @@ def train(
activation_recomputation=ActivationRecomputation(
activation_recomputation),
ds_zero=DSZeRO(ds_zero),
+ fwd_prefetch=fwd_prefetch,
+ bwd_prefetch=bwd_prefetch,
layernorm_dtype_bytes=layernorm_dtype_bytes,
master_weights_dtype_bytes=master_weights_dtype_bytes,
other_op_bytes=other_op_bytes,
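
Aside: as this hunk shows, train() converts its plain-int arguments into enums (ds_zero=DSZeRO(ds_zero), activation_recomputation=ActivationRecomputation(activation_recomputation)) before delegating to training(). A minimal self-contained sketch of that pattern; the enum members below are stand-ins, not verified against this commit:

    from enum import IntEnum

    class DSZeRO(IntEnum):  # stand-in for llm_analysis's DSZeRO
        NONE = 0
        STAGE_1 = 1
        STAGE_2 = 2
        STAGE_3 = 3

    def train(ds_zero: int = 0, fwd_prefetch: bool = True, bwd_prefetch: bool = True) -> None:
        stage = DSZeRO(ds_zero)  # the int from the CLI becomes a typed enum
        print(f"ZeRO {stage.name}, fwd_prefetch={fwd_prefetch}, bwd_prefetch={bwd_prefetch}")

    train(ds_zero=3)  # -> ZeRO STAGE_3, fwd_prefetch=True, bwd_prefetch=True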