diff --git a/pybuda/test/mlir/mnist/test_inference.py b/pybuda/test/mlir/mnist/test_inference.py index b224e559a..66af4d36f 100644 --- a/pybuda/test/mlir/mnist/test_inference.py +++ b/pybuda/test/mlir/mnist/test_inference.py @@ -2,23 +2,12 @@ # SPDX-License-Identifier: Apache-2.0 -from pybuda._C import DataFormat -from pybuda.config import _get_global_compiler_config import torch -from torch import nn - from .utils import * - import pybuda - def test_mnist_inference(): - compiler_cfg = _get_global_compiler_config() - df = DataFormat.Float16_b - compiler_cfg.default_df_override = df - compiler_cfg.default_accumulate_df = df - - inputs = [torch.rand(1, 784, dtype=torch.bfloat16)] + inputs = [torch.rand(1, 784)] framework_model = MNISTLinear() fw_out = framework_model(*inputs) diff --git a/pybuda/test/mlir/mnist/utils.py b/pybuda/test/mlir/mnist/utils.py index 260840e09..22807ae20 100644 --- a/pybuda/test/mlir/mnist/utils.py +++ b/pybuda/test/mlir/mnist/utils.py @@ -16,11 +16,11 @@ class MNISTLinear(nn.Module): def __init__(self, input_size=784, output_size=10, hidden_size=256): super(MNISTLinear, self).__init__() - self.l1 = nn.Linear(input_size, hidden_size, bias=False, dtype=torch.bfloat16) - self.b1 = nn.Parameter(torch.ones(1, hidden_size, dtype=torch.bfloat16)) + self.l1 = nn.Linear(input_size, hidden_size, bias=False) + self.b1 = nn.Parameter(torch.ones(1, hidden_size)) self.relu = nn.ReLU() - self.b2 = nn.Parameter(torch.ones(1, output_size, dtype=torch.bfloat16)) - self.l2 = nn.Linear(hidden_size, output_size, bias=False, dtype=torch.bfloat16) + self.b2 = nn.Parameter(torch.ones(1, output_size)) + self.l2 = nn.Linear(hidden_size, output_size, bias=False) def forward(self, x): x = self.l1(x) @@ -29,7 +29,7 @@ def forward(self, x): x = self.l2(x) x = x + self.b2 - return nn.functional.softmax(x, dtype=torch.bfloat16) + return nn.functional.softmax(x) diff --git a/third_party/tt-mlir b/third_party/tt-mlir index 3b627dd87..83c705cb6 160000 --- a/third_party/tt-mlir +++ b/third_party/tt-mlir @@ -1 +1 @@ -Subproject commit 3b627dd87d50b5a8ecef6e85802430376f411a70 +Subproject commit 83c705cb61729f9b23ad3e1c1839023eea259711