Don't clone to reduce memory usage, continue on a failed test
AleksKnezevic committed Dec 16, 2024
1 parent ae3cb55 commit 66fff62
Showing 3 changed files with 36 additions and 6 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/run-model-tests.yml
@@ -95,7 +95,7 @@ jobs:
          for testname in ${{ matrix.build.test_names }}; do
            testname=${testname//,/}
            echo "running test: $testname"
-           pytest -v tests/models/$testname
+           pytest -svv tests/models/$testname || true
          done
      - name: Tar results
4 changes: 2 additions & 2 deletions tt_torch/dynamo/backend.py
@@ -415,8 +415,8 @@ def __call__(self, *inputs):


def _base_backend(gm: torch.fx.GraphModule, example_inputs, compiler_config):
-    gm, graph_constants = pass_pipeline(gm, example_inputs, compiler_config)
-    gm.graph.print_tabular()
+    with torch.no_grad():
+        gm, graph_constants = pass_pipeline(gm, example_inputs, compiler_config)
    executor = Executor(gm, graph_constants, compiler_config)
    if compiler_config.compile_depth in (
        CompileDepth.EXECUTE_OP_BY_OP,
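Note: a minimal sketch (not from this repo; fold_example and w are hypothetical) of what wrapping the pass pipeline in torch.no_grad() buys. Tensors produced while the graph is transformed carry no autograd history, so no gradient graph is kept alive and peak memory stays lower:

import torch

# Hypothetical stand-in for a compiler pass that materializes new tensors.
def fold_example(weight: torch.Tensor) -> torch.Tensor:
    return weight * 2.0 + 1.0

w = torch.randn(1024, 1024, requires_grad=True)

with torch.no_grad():
    folded = fold_example(w)

# No grad_fn is recorded, so intermediates are not retained for backward.
assert folded.grad_fn is None and not folded.requires_grad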
36 changes: 33 additions & 3 deletions tt_torch/dynamo/passes.py
@@ -89,9 +89,40 @@ def bypass_redundant_getitem(gm):
    return gm


+def run_folding(gm):
+    # If there's no const subgraph module or attr output names to use, return
+    # early as there is no const folding to perform.
+    if gm.const_subgraph_module is None or gm.fx_const_folded_attrs_name is None:
+        return
+
+    assert not gm.has_folding_been_run
+    gm.has_folding_been_run = True
+
+    # Actually run const folding subgraph. Note that single attr const fold
+    # subgraphs output a single Tensor while multiple outputs are returned as
+    # Tuple[Tensor,].
+    folded_attrs = gm.const_subgraph_module()
+
+    def _create_param(i):
+        return torch.nn.Parameter(
+            i.detach()
+            if not isinstance(i, int)
+            else torch.Tensor([i]).to(device=gm.device_for_folded_attrs),
+            requires_grad=i.requires_grad if isinstance(i, torch.Tensor) else False,
+        )
+
+    params = (
+        torch.nn.ParameterList([_create_param(i) for i in folded_attrs])
+        if isinstance(folded_attrs, tuple)
+        else _create_param(folded_attrs)
+    )
+    setattr(gm, gm.fx_const_folded_attrs_name, params)
+
+
def constant_fold(gm, example_inputs):
    gm = const_fold.split_const_subgraphs(gm)
    gm.run_folding()
    print(f"Constant folding done")
    graph_constants = {}

    for node in gm.graph.nodes:
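Note: the run_folding above appears to be an adapted copy of torch.fx.experimental.const_fold.FoldedGraphModule.run_folding that wraps each folded attribute with i.detach() instead of taking a copy (per the commit title, the point is to avoid the clone). A minimal sketch, not from this repo, of why skipping the clone saves memory:

import torch

t = torch.randn(4, 4)

shared = t.detach()          # a view: shares t's storage, no extra copy
copied = t.detach().clone()  # materializes a second buffer

assert shared.data_ptr() == t.data_ptr()
assert copied.data_ptr() != t.data_ptr()

# Wrapping the detached view keeps a single copy of each folded constant.
param = torch.nn.Parameter(shared, requires_grad=False)
assert param.data_ptr() == t.data_ptr()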
@@ -124,12 +155,12 @@ def inline_parameters(gm):
    placeholders = {}
    for node in gm.graph.nodes:
        if node.op == "get_attr":
-            assert node.target in gm._parameters
+            assert hasattr(gm, node.target), f"Parameter {node.target} not found"
            gm.graph.inserting_before(node)
            if node.target not in placeholders:
                placeholder = gm.graph.placeholder(node.target)
                placeholders[node.target] = placeholder
-                parameters[node.target] = gm._parameters[node.target].data
+                parameters[node.target] = getattr(gm, node.target).data
            else:
                placeholder = placeholders[node.target]
            node.replace_all_uses_with(placeholder)
@@ -157,7 +188,6 @@ def pass_pipeline(gm: torch.fx.GraphModule, example_inputs, compiler_config):
    else:
        constants = []
    gm = bypass_redundant_getitem(gm)
-    gm.graph.print_tabular()
    gm, parameters = inline_parameters(gm)
    constant_inputs = order_constant_inputs(gm, parameters, constants)
    reduce_graph(gm)
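Note: the switch in inline_parameters from gm._parameters[...] to hasattr()/getattr() is presumably needed because the folded attributes attached by run_folding via setattr() are not always plain parameters; an nn.ParameterList, for instance, is registered as a submodule rather than in _parameters. A minimal sketch, not from this repo (class M and the attribute names are hypothetical):

import torch

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.w = torch.nn.Parameter(torch.randn(2))   # registered in _parameters
        setattr(self, "folded", torch.nn.ParameterList(
            [torch.nn.Parameter(torch.randn(2))]))    # registered in _modules

m = M()
assert "w" in m._parameters
assert "folded" not in m._parameters   # the old assert would reject this target
assert hasattr(m, "folded")            # the new check resolves it via __getattr__
print(getattr(m, "folded")[0].data)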
