From ab2fe8806111ff3b45f3fbcc9dfddf8454251746 Mon Sep 17 00:00:00 2001 From: ddilbaz Date: Thu, 7 Nov 2024 19:50:15 +0000 Subject: [PATCH] Golden function for tt and torch results --- tt_torch/dynamo/backend.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tt_torch/dynamo/backend.py b/tt_torch/dynamo/backend.py index 8ad9f749..6008821c 100644 --- a/tt_torch/dynamo/backend.py +++ b/tt_torch/dynamo/backend.py @@ -110,12 +110,13 @@ def compile_process(receiver, sender): class Executor: - def __init__(self, gm, compiler_config=None): + def __init__(self, gm, compiler_config=None, required_atol=1e-2): self.gm = gm self.binary = None if compiler_config is None: compiler_config = CompilerConfig() self.compiler_config = compiler_config + self.required_atol = required_atol def set_binary(self, binary): self.binary = binary @@ -298,7 +299,10 @@ def run_gm_op_by_op(self, *inputs): == CompileDepth.EXECUTE_OP_BY_OP and binary is not None ): - tensor = self.run_op(binary, *args) + tt_tensor = self.run_op(binary, *args) + golden_tensor = node.target(*args, **node.kwargs) + atol = torch.max(torch.abs(golden_tensor - tt_tensor)).item() + assert (atol <= self.required_atol), f"ATOL too high: {atol}" op.compilation_status = OpCompilationStatus.EXECUTED else: tensor = node.target(*args, **node.kwargs)