diff --git a/llm_analysis/analysis.py b/llm_analysis/analysis.py index c908beb..71a3d6d 100644 --- a/llm_analysis/analysis.py +++ b/llm_analysis/analysis.py @@ -1582,7 +1582,7 @@ def inference( weight_memory_embedding_per_gpu = self.get_memory_embedding(ds_zero) weight_memory_layers_per_gpu, weight_memory_attn_per_gpu, weight_memory_mlp_per_gpu, weight_memory_layernorm_per_gpu = [ - x * self.model_config.num_layers + x * num_layers_per_gpu for x in self.get_weight_memory_per_layer(ds_zero, return_breakdown=True) ]