From aac80d111840ccc324a105d499060e814ca7f2c0 Mon Sep 17 00:00:00 2001
From: mtairum
Date: Tue, 4 Feb 2025 18:34:53 +0000
Subject: [PATCH] #0: Fix MLP W3 kernel config

---
 models/demos/llama3/tt/llama_mlp.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/models/demos/llama3/tt/llama_mlp.py b/models/demos/llama3/tt/llama_mlp.py
index 1384c101aaa..c4f0971abd9 100644
--- a/models/demos/llama3/tt/llama_mlp.py
+++ b/models/demos/llama3/tt/llama_mlp.py
@@ -106,7 +106,9 @@ def forward(self, x: ttnn.Tensor, mode) -> ttnn.Tensor:
             x,
             self.w3,
             compute_kernel_config=(
-                self.args.compute_kernel_config_lofi if self.four_bit_mlp else self.args.compute_kernel_config_hifi2
+                self.args.compute_kernel_config_lofi
+                if self.four_bit_mlp
+                else self.args.compute_kernel_config_hifi2_fp16
             ),
             core_grid=None,  # FIXME: validate on TG ttnn.CoreGrid(y=8, x=8) if not pc_3 else None,
             dtype=ttnn.bfloat16,