Skip to content

Commit

Permalink
Added logic for deleting nodes - need review
Browse files Browse the repository at this point in the history
  • Loading branch information
ddilbazTT committed Nov 5, 2024
1 parent 7706bad commit 2f0ee6e
Showing 1 changed file with 24 additions and 1 deletion.
25 changes: 24 additions & 1 deletion tt_torch/dynamo/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,9 +263,11 @@ def run_gm_op_by_op(self, *inputs):
input_index = 0
outputs = []
num_nodes = len(self.gm.graph.nodes)
node_to_users = {}
for idx, node in enumerate(self.gm.graph.nodes):
print(f"Compiling {idx}/{num_nodes}: {node.target}")
if node.op == "placeholder":
node_to_users[node] = len(node.users)
node_to_tensor[node] = inputs[input_index]
input_index += 1
elif node.op == "get_attr":
Expand All @@ -274,6 +276,7 @@ def run_gm_op_by_op(self, *inputs):
node_to_tensor[node] = buffer[1]
break
elif node.op == "call_function":
node_to_users[node] = len(node.users)
args = []
for arg in node.args:
if isinstance(arg, torch.fx.node.Node):
Expand Down Expand Up @@ -305,11 +308,31 @@ def run_gm_op_by_op(self, *inputs):
tensor = node.target(*args, **node.kwargs)

node_to_tensor[node] = tensor
args_set = set()
for arg in node.args:
if arg in args_set: # skip duplicate args
continue
args_set.add(arg)
if isinstance(arg, torch.fx.node.Node):
node_to_users[arg] -= 1
elif node.op == "output":
args = node.args[0]
output_tensors = [node_to_tensor[arg] for arg in args]
outputs = output_tensors


args_set = set()
for arg in args:
if arg in args_set:
continue
args_set.add(arg)
if isinstance(arg, torch.fx.node.Node):
node_to_users[arg] -= 1
for arg in reversed(list(node_to_users.keys())):
if node.name == "output":
continue
if node_to_users[arg] == 0:
del node_to_tensor[arg]
self.gm.graph.erase_node(arg)
self.compiler_config.save_unique_ops()
return outputs

Expand Down

0 comments on commit 2f0ee6e

Please sign in to comment.