Skip to content

Commit

Permalink
Merge pull request #20 from tenstorrent/ddilbaz/delete_nodes
Browse files Browse the repository at this point in the history
Added logic for deleting intermediate tensors
  • Loading branch information
ddilbazTT authored Nov 7, 2024
2 parents 7706bad + d244633 commit 6503c70
Showing 1 changed file with 12 additions and 0 deletions.
12 changes: 12 additions & 0 deletions tt_torch/dynamo/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,8 +263,10 @@ def run_gm_op_by_op(self, *inputs):
input_index = 0
outputs = []
num_nodes = len(self.gm.graph.nodes)
out_degree = {}
for idx, node in enumerate(self.gm.graph.nodes):
print(f"Compiling {idx}/{num_nodes}: {node.target}")
out_degree[node] = len(node.users)
if node.op == "placeholder":
node_to_tensor[node] = inputs[input_index]
input_index += 1
Expand Down Expand Up @@ -309,6 +311,16 @@ def run_gm_op_by_op(self, *inputs):
args = node.args[0]
output_tensors = [node_to_tensor[arg] for arg in args]
outputs = output_tensors
args_set = set()
for arg in node.args:
if arg in args_set:
continue
args_set.add(arg)
if isinstance(arg, torch.fx.node.Node):
out_degree[arg] -= 1
if out_degree[arg] == 0 and arg.op != "output":
del node_to_tensor[arg]
out_degree.pop(arg)

self.compiler_config.save_unique_ops()
return outputs
Expand Down

0 comments on commit 6503c70

Please sign in to comment.