Skip to content

Commit

Permalink
Revised intermediary removal
Browse files Browse the repository at this point in the history
  • Loading branch information
ddilbazTT committed Nov 7, 2024
1 parent c252afb commit 551dc63
Showing 1 changed file with 18 additions and 23 deletions.
41 changes: 18 additions & 23 deletions tt_torch/dynamo/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,11 +263,11 @@ def run_gm_op_by_op(self, *inputs):
input_index = 0
outputs = []
num_nodes = len(self.gm.graph.nodes)
node_to_users = {}
out_degree = {}
for idx, node in enumerate(self.gm.graph.nodes):
print(f"Compiling {idx}/{num_nodes}: {node.target}")
out_degree[node] = len(node.users)
if node.op == "placeholder":
node_to_users[node] = len(node.users)
node_to_tensor[node] = inputs[input_index]
input_index += 1
elif node.op == "get_attr":
Expand All @@ -276,7 +276,6 @@ def run_gm_op_by_op(self, *inputs):
node_to_tensor[node] = buffer[1]
break
elif node.op == "call_function":
node_to_users[node] = len(node.users)
args = []
for arg in node.args:
if isinstance(arg, torch.fx.node.Node):
Expand Down Expand Up @@ -308,34 +307,30 @@ def run_gm_op_by_op(self, *inputs):
tensor = node.target(*args, **node.kwargs)

node_to_tensor[node] = tensor
args_set = set()
for arg in node.args:
if arg in args_set: # skip duplicate args
continue
args_set.add(arg)
if isinstance(arg, torch.fx.node.Node):
node_to_users[arg] -= 1
elif node.op == "output":
args = node.args[0]
output_tensors = [node_to_tensor[arg] for arg in args]
outputs = output_tensors
args_set = set()
for arg in node.args:
if arg in args_set:
continue
args_set.add(arg)
if isinstance(arg, torch.fx.node.Node):
out_degree[arg] -= 1
if out_degree[arg] == 0 and arg.op != "output":
del node_to_tensor[arg]
out_degree.pop(arg)
# Handle any intermediaries left - might be redundant
intermediates = [node for node, users in out_degree.items() if users == 0 and node.op != "output"]
for node in intermediates:
del node_to_tensor[node]

args_set = set()
for arg in args:
if arg in args_set:
continue
args_set.add(arg)
if isinstance(arg, torch.fx.node.Node):
node_to_users[arg] -= 1
for arg in reversed(list(node_to_users.keys())):
if node.name == "output":
continue
if node_to_users[arg] == 0:
del node_to_tensor[arg]
self.gm.graph.erase_node(arg)
self.compiler_config.save_unique_ops()
return outputs



def __call__(self, *inputs):
if self.compiler_config.compile_depth == CompileDepth.EXECUTE:
assert self.binary is not None, "Binary must be set for EXECUTE mode"
Expand Down

0 comments on commit 551dc63

Please sign in to comment.