
Commit

wip
cli99 committed Oct 19, 2023
1 parent c15cf33 commit 6bad7b4
Showing 1 changed file with 47 additions and 25 deletions.
72 changes: 47 additions & 25 deletions llm_analysis/analysis.py
@@ -369,17 +369,21 @@ def get_memory_weight_per_layer(
Returns:
float: the memory (in bytes) required to store the weights of a transformer layer
"""
memory_weight_per_layer = (
(
self.get_num_params_per_layer_attn()
+ self.get_num_params_per_layer_mlp() / self.parallelism_config.ep_size + self.get_num_params_per_layer_router() + self.get_num_params_per_layer_layernorm()
)
* self.dtype_config.weight_bits
/ BITS_PER_BYTE
/ self.parallelism_config.tp_size
)
if ds_zero == DSZeRO.STAGE_3:
memory_weight_per_layer /= self.parallelism_config.dp_size
sharded_dp_size = self.parallelism_config.dp_size
else:
sharded_dp_size = 1

memory_weight_attn_per_layer = self.get_num_params_per_layer_attn() * self.dtype_config.weight_bits / BITS_PER_BYTE / self.parallelism_config.tp_size / sharded_dp_size

memory_weight_mlp_per_layer = (self.get_num_params_per_layer_mlp() / self.parallelism_config.ep_size + self.get_num_params_per_layer_router()) * self.dtype_config.weight_bits / BITS_PER_BYTE / self.parallelism_config.tp_size

memory_weight_layernorm_per_layer = self.get_num_params_per_layer_layernorm() * self.dtype_config.weight_bits / BITS_PER_BYTE / self.parallelism_config.tp_size

memory_weight_per_layer = memory_weight_attn_per_layer + memory_weight_mlp_per_layer + memory_weight_layernorm_per_layer

logger.debug(f'memory_weight_attn_per_layer: {_num_to_string(memory_weight_attn_per_layer)}B, memory_weight_mlp_per_layer: {_num_to_string(memory_weight_mlp_per_layer)}B, memory_weight_layernorm_per_layer: {_num_to_string(memory_weight_layernorm_per_layer)}B')

return memory_weight_per_layer

def get_memory_optimizer_state_per_layer(
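
As a quick cross-check of the per-layer weight accounting above, here is a minimal standalone sketch; the helper name and the sample parameter counts are hypothetical, and, as in the diff, only the attention weights pick up the extra ZeRO-3 data-parallel sharding factor while the MoE expert weights are divided by ep_size.

BITS_PER_BYTE = 8  # mirrors the constant used in the file

def weight_memory_per_layer(num_params_attn, num_params_mlp, num_params_router,
                            num_params_layernorm, weight_bits, tp_size=1,
                            ep_size=1, zero3_dp_size=1):
    # Hypothetical restatement of get_memory_weight_per_layer; pass
    # zero3_dp_size=1 for anything below ZeRO stage 3.
    bytes_per_weight = weight_bits / BITS_PER_BYTE
    attn = num_params_attn * bytes_per_weight / tp_size / zero3_dp_size
    mlp = (num_params_mlp / ep_size + num_params_router) * bytes_per_weight / tp_size
    layernorm = num_params_layernorm * bytes_per_weight / tp_size
    return attn + mlp + layernorm

# Made-up layer: 50M attention params, 135M MLP params, fp16 weights,
# tp=2, ZeRO-3 sharding over 8 data-parallel ranks.
print(weight_memory_per_layer(50e6, 135e6, 0, 16e3, 16, tp_size=2, zero3_dp_size=8))
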
@@ -534,6 +538,9 @@ def get_memory_activation_per_layer_mlp(
seq_len: int,
is_inference: bool = True,
activation_recomputation: ActivationRecomputation = ActivationRecomputation.NONE,
activation_quant_bits: int = None,
recompute_gelu: bool = False,
with_dropout: bool = False,
) -> float:
"""Get the memory (in bytes) required to store the activations of the
MLP in a transformer layer, given the batch size, sequence length, and
@@ -556,9 +563,12 @@ def get_memory_activation_per_layer_mlp(
sp_size = self.parallelism_config.sp_size
ep_size = self.parallelism_config.ep_size
hidden_dim = self.model_config.hidden_dim

bytes_per_activation = (
self.dtype_config.activation_bits / BITS_PER_BYTE
)
if activation_quant_bits:
bytes_per_activation = activation_quant_bits / BITS_PER_BYTE

if is_inference:
return (5 * seq_len * batch_size * hidden_dim / sp_size) * bytes_per_activation
@@ -568,18 +578,20 @@
# dropout mask only requires a single byte per element
drop_out_mask = (
seq_len * batch_size * hidden_dim / sp_size
)
) if with_dropout else 0

logger.debug(f'recompute_gelu = {recompute_gelu}')
if self.model_config.moe_num_experts == 1:
memory_activation_per_layer_mlp = (
(1 * seq_len * batch_size * hidden_dim / sp_size)
+ (2 * seq_len * batch_size * hidden_dim * self.model_config.expansion_ratio / tp_size)
+ ((2 if not recompute_gelu else 1) * seq_len * batch_size * hidden_dim * self.model_config.expansion_ratio / tp_size)
) * bytes_per_activation + drop_out_mask
else:
memory_activation_per_layer_mlp = self.model_config.moe_top_k * (
(1 * seq_len * batch_size * hidden_dim / sp_size)
+ (2 * seq_len * batch_size * hidden_dim * self.model_config.expansion_ratio * self.model_config.moe_num_experts / ep_size / tp_size)
+ ((2 if not recompute_gelu else 1) * seq_len * batch_size * hidden_dim * self.model_config.expansion_ratio * self.model_config.moe_num_experts / ep_size / tp_size)
) * bytes_per_activation + drop_out_mask

return memory_activation_per_layer_mlp

def get_memory_activation_per_layer_layernorm(
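
For the dense branch above (moe_num_experts == 1, training), a rough standalone sketch of the same arithmetic; the helper and the example sizes are made up, while the 2-vs-1 GELU activation copies, the optional activation-quantization bit width, and the one-byte-per-element dropout mask follow the diff.

BITS_PER_BYTE = 8

def mlp_activation_memory_dense(batch_size, seq_len, hidden_dim, expansion_ratio,
                                activation_bits=16, activation_quant_bits=None,
                                tp_size=1, sp_size=1,
                                recompute_gelu=False, with_dropout=False):
    # Optional quantization of the cached MLP activations, as in the diff.
    bits = activation_quant_bits if activation_quant_bits else activation_bits
    bytes_per_activation = bits / BITS_PER_BYTE
    # The dropout mask costs one byte per element and is skipped when dropout is off.
    dropout_mask = seq_len * batch_size * hidden_dim / sp_size if with_dropout else 0
    # One of the two expanded activation copies can be dropped when GELU is recomputed.
    gelu_copies = 1 if recompute_gelu else 2
    return (
        seq_len * batch_size * hidden_dim / sp_size
        + gelu_copies * seq_len * batch_size * hidden_dim * expansion_ratio / tp_size
    ) * bytes_per_activation + dropout_mask

# Hypothetical sizes: batch 4, seq 2048, hidden 4096, 4x expansion,
# int8 activation cache, GELU recomputation, tp=2.
print(mlp_activation_memory_dense(4, 2048, 4096, 4, activation_quant_bits=8,
                                  tp_size=2, recompute_gelu=True))
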
@@ -620,6 +632,8 @@ def get_memory_activation_per_layer(
is_inference: bool = True,
activation_recomputation: ActivationRecomputation = ActivationRecomputation.NONE,
layernorm_dtype_bytes: int = BYTES_FP32,
mlp_activation_quant_bits: int = None,
mlp_recompute_gelu: bool = False,
) -> float:
"""Get the memory (in bytes) required to store the activations of a
transformer layer, given the batch size, sequence length, and whether
@@ -656,7 +670,7 @@ def get_memory_activation_per_layer(

memory_activation_per_layer_mlp = (
self.get_memory_activation_per_layer_mlp(
batch_size, seq_len, is_inference, activation_recomputation
batch_size, seq_len, is_inference, activation_recomputation, activation_quant_bits=mlp_activation_quant_bits, recompute_gelu=mlp_recompute_gelu,
)
)

@@ -669,27 +683,25 @@ def get_memory_activation_per_layer(
)
)

logger.debug(
"memory_activation_per_layer_attn:"
f" {_num_to_string(memory_activation_per_layer_attn)}B,"
" memory_activation_per_layer_mlp:"
f" {_num_to_string(memory_activation_per_layer_mlp)}B,"
" memory_activation_per_layer_layernorm:"
f" {_num_to_string(memory_activation_per_layer_layernorm)}B"
)

if is_inference:
memory_activation_per_layer = max(memory_activation_per_layer_attn, memory_activation_per_layer_mlp, memory_activation_per_layer_layernorm)
logger.debug(
f"memory_activation_per_layer for batch_size {batch_size}:"
f" {_num_to_string(memory_activation_per_layer)}B"
f" (max(attn, mlp, layernorm): max({_num_to_string(memory_activation_per_layer_attn)}B ,"
f" {_num_to_string(memory_activation_per_layer_mlp)}B , 2 *"
f" {_num_to_string(2*memory_activation_per_layer_layernorm)}B))"
)
else:
memory_activation_per_layer = (
memory_activation_per_layer_attn
+ memory_activation_per_layer_mlp
+ 2 * memory_activation_per_layer_layernorm
)
logger.debug(
"memory_activation_per_layer:"
f"memory_activation_per_layer for batch_size {batch_size}:"
f" {_num_to_string(memory_activation_per_layer)}B"
f" ({_num_to_string(memory_activation_per_layer_attn)}B +"
f" (attn + mlp + layernorm: {_num_to_string(memory_activation_per_layer_attn)}B +"
f" {_num_to_string(memory_activation_per_layer_mlp)}B + 2 *"
f" {_num_to_string(2*memory_activation_per_layer_layernorm)}B)"
)
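
The way the pieces are combined in the two branches above can be restated in isolation; the helper below is only illustrative and the byte counts are made up.

def activation_memory_per_layer(attn_bytes, mlp_bytes, layernorm_bytes, is_inference):
    # Inference only needs the largest live activation at any time; training
    # keeps all of them, with two layernorm activations per transformer layer.
    if is_inference:
        return max(attn_bytes, mlp_bytes, layernorm_bytes)
    return attn_bytes + mlp_bytes + 2 * layernorm_bytes

# Hypothetical per-layer numbers in bytes.
print(activation_memory_per_layer(3.2e9, 2.7e9, 1.0e8, is_inference=False))  # 6.1e9
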
@@ -1793,6 +1805,8 @@ def training(
activation_recomputation: ActivationRecomputation = ActivationRecomputation.NONE,
ds_zero: DSZeRO = DSZeRO.NONE,
layernorm_dtype_bytes: int = BYTES_FP32,
mlp_activation_quant_bits: int = None,
mlp_recompute_gelu: bool = False,
output_dir: str = None,
output_file_suffix: str = "",
) -> dict:
@@ -1893,6 +1907,8 @@
is_inference=False,
activation_recomputation=activation_recomputation,
layernorm_dtype_bytes=layernorm_dtype_bytes,
mlp_activation_quant_bits=mlp_activation_quant_bits,
mlp_recompute_gelu=mlp_recompute_gelu,
)
* self.model_config.num_layers
)
@@ -1920,6 +1936,8 @@
is_inference=False,
activation_recomputation=activation_recomputation,
layernorm_dtype_bytes=layernorm_dtype_bytes,
mlp_activation_quant_bits=mlp_activation_quant_bits,
mlp_recompute_gelu=mlp_recompute_gelu,
)
* self.model_config.num_layers
)
@@ -2215,6 +2233,8 @@ def train(
ep_size: int = 1,
total_num_gpus: int = None,
layernorm_dtype_bytes: int = BYTES_FP32,
mlp_activation_quant_bits: int = None,
mlp_recompute_gelu: bool = False,
achieved_tflops: float = None,
flops_efficiency: float = None,
hbm_memory_efficiency: float = HBM_MEMORY_EFFICIENCY,
@@ -2309,6 +2329,8 @@ def train(
),
ds_zero=DSZeRO(ds_zero),
layernorm_dtype_bytes=layernorm_dtype_bytes,
mlp_activation_quant_bits=mlp_activation_quant_bits,
mlp_recompute_gelu=mlp_recompute_gelu,
output_dir=output_dir,
output_file_suffix=output_file_suffix,
)
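
With the two new knobs threaded through to train(), a call could look roughly like the following sketch; only mlp_activation_quant_bits and mlp_recompute_gelu come from this commit, the other keyword arguments are taken from the surrounding signature, and model/GPU/dtype selection is assumed to fall back on the function's defaults.

from llm_analysis.analysis import train

# Hypothetical invocation: int8 MLP activation cache plus GELU recomputation.
summary = train(
    ep_size=1,
    layernorm_dtype_bytes=2,        # assume fp16 layernorm activations
    mlp_activation_quant_bits=8,    # new in this commit
    mlp_recompute_gelu=True,        # new in this commit
    flops_efficiency=0.5,
    hbm_memory_efficiency=0.9,
    output_dir="outputs",
    output_file_suffix="_int8_mlp",
)
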
