Skip to content

Commit

Permalink
Updating Metal version and removing data type overrides from MNIST model
Browse files Browse the repository at this point in the history
  • Loading branch information
sdjordjevicTT committed Aug 9, 2024
1 parent 48db3d3 commit d982513
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 18 deletions.
13 changes: 1 addition & 12 deletions pybuda/test/mlir/mnist/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,12 @@

# SPDX-License-Identifier: Apache-2.0

from pybuda._C import DataFormat
from pybuda.config import _get_global_compiler_config
import torch
from torch import nn

from .utils import *

import pybuda


def test_mnist_inference():
compiler_cfg = _get_global_compiler_config()
df = DataFormat.Float16_b
compiler_cfg.default_df_override = df
compiler_cfg.default_accumulate_df = df

inputs = [torch.rand(1, 784, dtype=torch.bfloat16)]
inputs = [torch.rand(1, 784)]

framework_model = MNISTLinear()
fw_out = framework_model(*inputs)
Expand Down
10 changes: 5 additions & 5 deletions pybuda/test/mlir/mnist/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@
class MNISTLinear(nn.Module):
def __init__(self, input_size=784, output_size=10, hidden_size=256):
super(MNISTLinear, self).__init__()
self.l1 = nn.Linear(input_size, hidden_size, bias=False, dtype=torch.bfloat16)
self.b1 = nn.Parameter(torch.ones(1, hidden_size, dtype=torch.bfloat16))
self.l1 = nn.Linear(input_size, hidden_size, bias=False)
self.b1 = nn.Parameter(torch.ones(1, hidden_size))
self.relu = nn.ReLU()
self.b2 = nn.Parameter(torch.ones(1, output_size, dtype=torch.bfloat16))
self.l2 = nn.Linear(hidden_size, output_size, bias=False, dtype=torch.bfloat16)
self.b2 = nn.Parameter(torch.ones(1, output_size))
self.l2 = nn.Linear(hidden_size, output_size, bias=False)

def forward(self, x):
x = self.l1(x)
Expand All @@ -29,7 +29,7 @@ def forward(self, x):
x = self.l2(x)
x = x + self.b2

return nn.functional.softmax(x, dtype=torch.bfloat16)
return nn.functional.softmax(x)



Expand Down

0 comments on commit d982513

Please sign in to comment.