diff --git a/tt_torch/dynamo/backend.py b/tt_torch/dynamo/backend.py
index 6a1f0a51..603361fb 100644
--- a/tt_torch/dynamo/backend.py
+++ b/tt_torch/dynamo/backend.py
@@ -263,9 +263,11 @@ def run_gm_op_by_op(self, *inputs):
         input_index = 0
         outputs = []
         num_nodes = len(self.gm.graph.nodes)
+        node_to_users = {}
         for idx, node in enumerate(self.gm.graph.nodes):
             print(f"Compiling {idx}/{num_nodes}: {node.target}")
             if node.op == "placeholder":
+                node_to_users[node] = len(node.users)
                 node_to_tensor[node] = inputs[input_index]
                 input_index += 1
             elif node.op == "get_attr":
@@ -274,6 +276,7 @@
                         node_to_tensor[node] = buffer[1]
                         break
             elif node.op == "call_function":
+                node_to_users[node] = len(node.users)
                 args = []
                 for arg in node.args:
                     if isinstance(arg, torch.fx.node.Node):
@@ -305,11 +308,31 @@
 
                 tensor = node.target(*args, **node.kwargs)
                 node_to_tensor[node] = tensor
+                args_set = set()
+                for arg in node.args:
+                    if arg in args_set:  # skip duplicate args
+                        continue
+                    args_set.add(arg)
+                    if isinstance(arg, torch.fx.node.Node):
+                        node_to_users[arg] -= 1
 
             elif node.op == "output":
                 args = node.args[0]
                 output_tensors = [node_to_tensor[arg] for arg in args]
                 outputs = output_tensors
+                args_set = set()
+                for arg in args:
+                    if arg in args_set:  # skip duplicate args
+                        continue
+                    args_set.add(arg)
+                    if isinstance(arg, torch.fx.node.Node):
+                        node_to_users[arg] -= 1
+            for arg in reversed(list(node_to_users.keys())):
+                if node.name == "output":
+                    continue
+                if node_to_users[arg] == 0:
+                    del node_to_tensor[arg]
+                    self.gm.graph.erase_node(arg)
 
         self.compiler_config.save_unique_ops()
         return outputs
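
For context: the patch threads a per-node user count (node_to_users) through the op-by-op execution loop so that intermediate tensors in node_to_tensor can be dropped, and their graph nodes erased, as soon as their last consumer has run, presumably to cap peak host memory during op-by-op execution. Below is a minimal, standalone sketch of the same reference-counting idea, assuming a torch.fx-traceable module; run_op_by_op is an illustrative name rather than tt-torch API, and the sketch keeps only the dictionary cleanup, omitting the diff's get_attr handling, op compilation, and graph.erase_node() call.

    # A minimal sketch of reference-counted tensor cleanup for op-by-op
    # execution of an FX graph. Hypothetical helper, not tt-torch API.
    import torch
    import torch.fx
    from torch.fx.node import map_arg

    def run_op_by_op(gm: torch.fx.GraphModule, *inputs):
        node_to_tensor = {}
        # Seed each node's remaining-user count from the FX graph.
        node_to_users = {n: len(n.users) for n in gm.graph.nodes}
        input_iter = iter(inputs)
        outputs = None
        for node in gm.graph.nodes:
            if node.op == "placeholder":
                node_to_tensor[node] = next(input_iter)
            elif node.op == "call_function":
                # Replace Node arguments with the tensors computed for them.
                args = map_arg(node.args, lambda n: node_to_tensor[n])
                kwargs = map_arg(node.kwargs, lambda n: node_to_tensor[n])
                node_to_tensor[node] = node.target(*args, **kwargs)
            elif node.op == "output":
                outputs = map_arg(node.args[0], lambda n: node_to_tensor[n])
            # Each distinct input Node has now been consumed once more
            # (all_input_nodes is already deduplicated).
            for arg in node.all_input_nodes:
                node_to_users[arg] -= 1
            # Drop intermediates nobody will read again so their tensors can
            # be freed; like the diff, skip cleanup at the output node.
            if node.op != "output":
                dead = [n for n, c in node_to_users.items()
                        if c == 0 and n in node_to_tensor]
                for n in dead:
                    del node_to_tensor[n]
        return outputs

    # Tiny usage example: (x + 1) is released as soon as the multiply runs.
    class M(torch.nn.Module):
        def forward(self, x):
            return (x + 1) * 2

    print(run_op_by_op(torch.fx.symbolic_trace(M()), torch.ones(2)))  # tensor([4., 4.])

Instead of the diff's manual args_set deduplication, the sketch leans on Node.all_input_nodes, which already yields each distinct input node exactly once, including for the nested argument tuple of the output node.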