diff --git a/tt_torch/dynamo/backend.py b/tt_torch/dynamo/backend.py index 603361fb..05d3f0a8 100644 --- a/tt_torch/dynamo/backend.py +++ b/tt_torch/dynamo/backend.py @@ -263,11 +263,11 @@ def run_gm_op_by_op(self, *inputs): input_index = 0 outputs = [] num_nodes = len(self.gm.graph.nodes) - node_to_users = {} + out_degree = {} for idx, node in enumerate(self.gm.graph.nodes): print(f"Compiling {idx}/{num_nodes}: {node.target}") + out_degree[node] = len(node.users) if node.op == "placeholder": - node_to_users[node] = len(node.users) node_to_tensor[node] = inputs[input_index] input_index += 1 elif node.op == "get_attr": @@ -276,7 +276,6 @@ def run_gm_op_by_op(self, *inputs): node_to_tensor[node] = buffer[1] break elif node.op == "call_function": - node_to_users[node] = len(node.users) args = [] for arg in node.args: if isinstance(arg, torch.fx.node.Node): @@ -308,34 +307,30 @@ def run_gm_op_by_op(self, *inputs): tensor = node.target(*args, **node.kwargs) node_to_tensor[node] = tensor - args_set = set() - for arg in node.args: - if arg in args_set: # skip duplicate args - continue - args_set.add(arg) - if isinstance(arg, torch.fx.node.Node): - node_to_users[arg] -= 1 elif node.op == "output": args = node.args[0] output_tensors = [node_to_tensor[arg] for arg in args] outputs = output_tensors + args_set = set() + for arg in node.args: + if arg in args_set: + continue + args_set.add(arg) + if isinstance(arg, torch.fx.node.Node): + out_degree[arg] -= 1 + if out_degree[arg] == 0 and arg.op != "output": + del node_to_tensor[arg] + out_degree.pop(arg) + # Handle any intermediaries left - might be redundant + intermediates = [node for node, users in out_degree.items() if users == 0 and node.op != "output"] + for node in intermediates: + del node_to_tensor[node] - args_set = set() - for arg in args: - if arg in args_set: - continue - args_set.add(arg) - if isinstance(arg, torch.fx.node.Node): - node_to_users[arg] -= 1 - for arg in reversed(list(node_to_users.keys())): - if node.name == "output": - continue - if node_to_users[arg] == 0: - del node_to_tensor[arg] - self.gm.graph.erase_node(arg) self.compiler_config.save_unique_ops() return outputs + + def __call__(self, *inputs): if self.compiler_config.compile_depth == CompileDepth.EXECUTE: assert self.binary is not None, "Binary must be set for EXECUTE mode"