From 16fb250e93f3e24fb53bee3122cb4b1eb302e604 Mon Sep 17 00:00:00 2001
From: brataTT
Date: Tue, 7 Jan 2025 15:24:33 +0000
Subject: [PATCH] Remove unused getitems in op-by-op flow [#105]

---
 tests/torch/test_basic.py     | 23 +++++++++++++++++++++++
 tests/torch/test_maxpool2d.py | 27 +++++++++++++++++++++++++++
 tt_torch/dynamo/backend.py    | 12 ++++++++++++
 3 files changed, 62 insertions(+)
 create mode 100644 tests/torch/test_maxpool2d.py

diff --git a/tests/torch/test_basic.py b/tests/torch/test_basic.py
index 835310f..527bfde 100644
--- a/tests/torch/test_basic.py
+++ b/tests/torch/test_basic.py
@@ -412,6 +412,29 @@ def forward(self, x):
     )
 
 
+def test_unused_output():
+    class Basic_var_only(nn.Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, x):
+            var, mean = torch.var_mean(x)
+            return var
+
+    class Basic_mean_only(nn.Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, x):
+            var, mean = torch.var_mean(x)
+            return mean
+
+    for module in [Basic_var_only, Basic_mean_only]:
+        cc = CompilerConfig()
+        cc.compile_depth = tt_torch.tools.utils.CompileDepth.COMPILE_OP_BY_OP
+        verify_module(module(), input_shapes=[(256, 256)], compiler_config=cc)
+
+
 @pytest.mark.parametrize(
     ("input_range", "input_shapes", "input_type"),
     [
diff --git a/tests/torch/test_maxpool2d.py b/tests/torch/test_maxpool2d.py
new file mode 100644
index 0000000..636fb74
--- /dev/null
+++ b/tests/torch/test_maxpool2d.py
@@ -0,0 +1,27 @@
+# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
+#
+# SPDX-License-Identifier: Apache-2.0
+import torch
+from torch import nn
+import pytest
+
+import tt_torch
+from tt_torch.tools.verify import verify_module
+from tt_torch.tools.utils import CompilerConfig, CompileDepth
+
+
+def test_maxpool2d():
+    class Basic(nn.Module):
+        def __init__(self):
+            super().__init__()
+
+        def forward(self, x):
+            return torch.nn.functional.max_pool2d(x, kernel_size=2, stride=2)
+
+    cc = CompilerConfig()
+    cc.compile_depth = CompileDepth.EXECUTE_OP_BY_OP
+    verify_module(
+        Basic(),
+        inputs=[torch.randn(1, 1, 224, 224).to(torch.bfloat16)],
+        compiler_config=cc,
+    )
diff --git a/tt_torch/dynamo/backend.py b/tt_torch/dynamo/backend.py
index 49cf5b9..1afb006 100644
--- a/tt_torch/dynamo/backend.py
+++ b/tt_torch/dynamo/backend.py
@@ -188,7 +188,15 @@ def compile_op(self, node, *inputs, **kwargs):
         ):
             getitem_nodes = []
             graph_node.meta["val"] = node.meta["val"]
+
             for idx, tensor_meta in enumerate(node.meta["tensor_meta"]):
+                # Filter out outputs that have no getitem user in the source graph
+                users = self.gm.graph.find_nodes(
+                    op="call_function", target=operator.getitem
+                )
+                if not any(user_node.args == (node, idx) for user_node in users):
+                    continue
+
                 getitem_node = graph.call_function(
                     operator.getitem, args=(graph_node, idx)
                 )
@@ -199,6 +207,10 @@
             out = graph.output((graph_node,))
         if "tensor_meta" not in node.meta:
             raise ValueError(f"Node {node} does not have tensor_meta")
+        if len(node.users) != len(graph_node.users):
+            raise ValueError(
+                f"Op node {node} has a different number of users ({len(graph_node.users)}) than the global graph ({len(node.users)})"
+            )
         op.compilation_status = OpCompilationStatus.CREATED_GRAPH
         out.meta["tensor_meta"] = node.meta["tensor_meta"]
 
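
Reviewer note, not part of the patch: the backend change relies on the fact that
tracing only materializes a getitem node for the tuple outputs that are actually
consumed. Below is a minimal standalone sketch of that behavior using plain
torch.fx.symbolic_trace rather than the Dynamo-based flow the patch targets; the
VarOnly module and variable names are illustrative, not from the patch, and the
tuple is indexed explicitly because symbolic_trace cannot unpack a Proxy the way
the Dynamo-traced test above does.

import operator

import torch
from torch import fx, nn


class VarOnly(nn.Module):
    # Hypothetical module: consumes only the first output of a two-output op.
    def forward(self, x):
        out = torch.var_mean(x)  # tuple output: (var, mean)
        return out[0]            # mean is never consumed


gm = fx.symbolic_trace(VarOnly())
var_mean_node = next(n for n in gm.graph.nodes if n.target == torch.var_mean)

# Mirror the patch's liveness check: an output index is live only if some
# getitem node in the graph extracts it from the multi-output node.
for idx in range(2):
    live = any(
        n.op == "call_function"
        and n.target == operator.getitem
        and n.args == (var_mean_node, idx)
        for n in gm.graph.nodes
    )
    print(idx, live)  # prints: 0 True, then 1 False

The patch applies this same per-index test before creating getitems in the
reduced per-op graph, and the new user-count check guards that the per-op graph
node ends up with exactly as many consumers as the node in the global graph.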