diff --git a/tt_torch/dynamo/backend.py b/tt_torch/dynamo/backend.py index 6a1f0a51..814c3021 100644 --- a/tt_torch/dynamo/backend.py +++ b/tt_torch/dynamo/backend.py @@ -263,8 +263,10 @@ def run_gm_op_by_op(self, *inputs): input_index = 0 outputs = [] num_nodes = len(self.gm.graph.nodes) + out_degree = {} for idx, node in enumerate(self.gm.graph.nodes): print(f"Compiling {idx}/{num_nodes}: {node.target}") + out_degree[node] = len(node.users) if node.op == "placeholder": node_to_tensor[node] = inputs[input_index] input_index += 1 @@ -309,6 +311,24 @@ def run_gm_op_by_op(self, *inputs): args = node.args[0] output_tensors = [node_to_tensor[arg] for arg in args] outputs = output_tensors + args_set = set() + for arg in node.args: + if arg in args_set: + continue + args_set.add(arg) + if isinstance(arg, torch.fx.node.Node): + out_degree[arg] -= 1 + if out_degree[arg] == 0 and arg.op != "output": + del node_to_tensor[arg] + out_degree.pop(arg) + # Handle any intermediaries left - might be redundant + intermediates = [ + node + for node, users in out_degree.items() + if users == 0 and node.op != "output" + ] + for node in intermediates: + del node_to_tensor[node] self.compiler_config.save_unique_ops() return outputs