diff --git a/llm_analysis/analysis.py b/llm_analysis/analysis.py
index c908beb..71a3d6d 100644
--- a/llm_analysis/analysis.py
+++ b/llm_analysis/analysis.py
@@ -1582,7 +1582,7 @@ def inference(
 
         weight_memory_embedding_per_gpu = self.get_memory_embedding(ds_zero)
         weight_memory_layers_per_gpu, weight_memory_attn_per_gpu, weight_memory_mlp_per_gpu, weight_memory_layernorm_per_gpu = [
-            x * self.model_config.num_layers
+            x * num_layers_per_gpu
             for x in self.get_weight_memory_per_layer(ds_zero,
                                                       return_breakdown=True)
         ]