-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Bug] Solves issue when certain inputs/constants aren't properly decl…
…ared during MLIR emit Previously, MLIR emit was hiting edge cases when declaring constant inputs. More precisely, they were mostly skipped. This fix redefines how inputs are recognized (using kInput node type), and properly distinguish regular and constant inputs vs model parameters. Issue uncovered during #112 op bringup (reciprocal). At the same time, PR related to #112 is testing this case. Additionally, inference and training MNIST are also covering this feature for functionality. Fixes #201
- Loading branch information
1 parent
8bcb735
commit 403db1c
Showing
8 changed files
with
147 additions
and
18 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC | ||
|
||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import os | ||
import pytest | ||
|
||
import pytest | ||
import torch | ||
from torch import nn | ||
|
||
import forge | ||
from forge.op.eval.common import compare_with_golden_pcc | ||
|
||
def test_multiple_inputs(): | ||
class MultipleInputs(nn.Module): | ||
def __init__(self): | ||
super().__init__() | ||
|
||
def forward(self, a, b, c): | ||
return a + b + c | ||
|
||
inputs = [torch.rand(1, 32, 32), torch.rand(1, 32, 32), torch.rand(1, 32, 32)] | ||
|
||
framework_model = MultipleInputs() | ||
fw_out = framework_model(*inputs) | ||
|
||
compiled_model = forge.compile(framework_model, sample_inputs=inputs) | ||
co_out = compiled_model(*inputs) | ||
|
||
co_out = [co.to("cpu") for co in co_out] | ||
assert [compare_with_golden_pcc(golden=fo, calculated=co, pcc=0.99) for fo, co in zip(fw_out, co_out)] | ||
|
||
|
||
@pytest.mark.parametrize("a_shape, b_shape, c_shape", [ | ||
((1, 1, 32, 64), (1, 1, 64, 128), (1, 1, 128, 32)), | ||
((1, 1, 64, 32), (1, 1, 32, 128), (1, 1, 128, 64)), | ||
((1, 1, 128, 64), (1, 1, 64, 256), (1, 1, 256, 128)), | ||
((1, 1, 256, 128), (1, 1, 128, 512), (1, 1, 512, 256)) | ||
]) | ||
def test_input_order(a_shape, b_shape, c_shape): | ||
class InputOrder(nn.Module): | ||
def __init__(self): | ||
super().__init__() | ||
|
||
def forward(self, a, b, c): | ||
x = torch.matmul(a, b) | ||
x = torch.matmul(x, c) | ||
|
||
return x | ||
|
||
a = torch.rand(*a_shape) | ||
b = torch.rand(*b_shape) | ||
c = torch.rand(*c_shape) | ||
|
||
framework_model = InputOrder() | ||
fw_out = framework_model(a, b, c) | ||
|
||
compiled_model = forge.compile(framework_model, sample_inputs=[a, b, c]) | ||
co_out = compiled_model(a, b, c) | ||
|
||
assert compare_with_golden_pcc(golden=fw_out, calculated=co_out, pcc=0.99) | ||
|
||
|
||
@pytest.mark.parametrize("a_shape, b_shape, c_shape", [ | ||
((1, 1, 32, 64), (1, 1, 64, 128), (1, 1, 128, 32)), | ||
]) | ||
def test_input_order_with_constants(a_shape, b_shape, c_shape): | ||
class InputOrderWithConstants(nn.Module): | ||
def __init__(self): | ||
super().__init__() | ||
self.const1 = torch.rand(1, 1, 32, 32) | ||
self.const2 = torch.rand(1, 1, 32, 32) | ||
|
||
def forward(self, a, b, c): | ||
x = torch.matmul(a, b) | ||
x = torch.matmul(x, c) | ||
x = x + self.const1 | ||
x = x * self.const2 | ||
return x | ||
|
||
a = torch.rand(*a_shape) | ||
b = torch.rand(*b_shape) | ||
c = torch.rand(*c_shape) | ||
|
||
framework_model = InputOrderWithConstants() | ||
fw_out = framework_model(a, b, c) | ||
|
||
compiled_model = forge.compile(framework_model, sample_inputs=[a, b, c]) | ||
co_out = compiled_model(a, b, c) | ||
|
||
assert compare_with_golden_pcc(golden=fw_out, calculated=co_out[0][0], pcc=0.99) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters