Skip to content

Commit

Permalink
Fix weight computation for MLP (#8)
Browse files Browse the repository at this point in the history
  • Loading branch information
weimingzha0 authored Oct 10, 2023
1 parent b247ae1 commit 5fcc4b2
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion llm_analysis/analysis.py
Original file line number Diff line number Diff line change
Expand Up
@@ -302,7 +302,7 @@ def get_num_params_per_layer_mlp(self) -> int:
Returns:
int: the number of parameters in the two MLP linear layers
"""
- return 8 * self.model_config.hidden_dim**2 # 4+4
+ return 2 * self.model_config.hidden_dim*self.model_config.ffn_embed_dim

def get_num_params_per_layer(self) -> int:
"""Get the number of parameters in a transformer layer, including the
Expand Down

0 comments on commit 5fcc4b2

Please sign in to comment.