Skip to content

Commit

Permalink
Modify the bug in weight memory calculation (#23)
Browse files Browse the repository at this point in the history
  • Loading branch information
BhAem authored Mar 1, 2024
1 parent 9932ff4 commit 6bacd4f
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion llm_analysis/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -1582,7 +1582,7 @@ def inference(

weight_memory_embedding_per_gpu = self.get_memory_embedding(ds_zero)
weight_memory_layers_per_gpu, weight_memory_attn_per_gpu, weight_memory_mlp_per_gpu, weight_memory_layernorm_per_gpu = [
x * self.model_config.num_layers
x * num_layers_per_gpu
for x in self.get_weight_memory_per_layer(ds_zero,
return_breakdown=True)
]
Expand Down

0 comments on commit 6bacd4f

Please sign in to comment.