Skip to content

Commit

Permalink
Added logic for deleting nodes - need review
Browse files Browse the repository at this point in the history
  • Loading branch information
ddilbazTT committed Nov 7, 2024
1 parent 7706bad commit b2f2dc9
Showing 1 changed file with 20 additions and 0 deletions.
20 changes: 20 additions & 0 deletions tt_torch/dynamo/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,8 +263,10 @@ def run_gm_op_by_op(self, *inputs):
input_index = 0
outputs = []
num_nodes = len(self.gm.graph.nodes)
out_degree = {}
for idx, node in enumerate(self.gm.graph.nodes):
print(f"Compiling {idx}/{num_nodes}: {node.target}")
out_degree[node] = len(node.users)
if node.op == "placeholder":
node_to_tensor[node] = inputs[input_index]
input_index += 1
Expand Down Expand Up @@ -309,6 +311,24 @@ def run_gm_op_by_op(self, *inputs):
args = node.args[0]
output_tensors = [node_to_tensor[arg] for arg in args]
outputs = output_tensors
args_set = set()
for arg in node.args:
if arg in args_set:
continue
args_set.add(arg)
if isinstance(arg, torch.fx.node.Node):
out_degree[arg] -= 1
if out_degree[arg] == 0 and arg.op != "output":
del node_to_tensor[arg]
out_degree.pop(arg)
# Handle any intermediaries left - might be redundant
intermediates = [
node
for node, users in out_degree.items()
if users == 0 and node.op != "output"
]
for node in intermediates:
del node_to_tensor[node]

self.compiler_config.save_unique_ops()
return outputs
Expand Down

0 comments on commit b2f2dc9

Please sign in to comment.