Commit 8b7a37b
Updating Metal version and removing data type overrides from MNIST model
sdjordjevicTT committed Aug 9, 2024
1 parent 48db3d3
Showing 3 changed files with 7 additions and 12 deletions.
pybuda/test/mlir/mnist/test_inference.py: 1 addition & 6 deletions

@@ -13,12 +13,7 @@


 def test_mnist_inference():
-    compiler_cfg = _get_global_compiler_config()
-    df = DataFormat.Float16_b
-    compiler_cfg.default_df_override = df
-    compiler_cfg.default_accumulate_df = df
-
-    inputs = [torch.rand(1, 784, dtype=torch.bfloat16)]
+    inputs = [torch.rand(1, 784)]
 
     framework_model = MNISTLinear()
     fw_out = framework_model(*inputs)
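For context, a minimal standalone sketch of what the updated test exercises, assuming stock PyTorch. The import path is hypothetical, and the PyBuda compile-and-compare steps that follow in the real file are elided:

```python
import torch

from utils import MNISTLinear  # hypothetical import path for this sketch


def test_mnist_inference():
    # Inputs now stay in the default torch.float32; no DataFormat
    # override on the compiler config is needed anymore.
    inputs = [torch.rand(1, 784)]

    framework_model = MNISTLinear()
    fw_out = framework_model(*inputs)

    # Sanity checks on the framework output (the real test goes on to
    # compile the model and compare outputs, which is elided here).
    assert fw_out.shape == (1, 10)
    assert fw_out.dtype == torch.float32
```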
pybuda/test/mlir/mnist/utils.py: 5 additions & 5 deletions

@@ -16,11 +16,11 @@
 class MNISTLinear(nn.Module):
     def __init__(self, input_size=784, output_size=10, hidden_size=256):
         super(MNISTLinear, self).__init__()
-        self.l1 = nn.Linear(input_size, hidden_size, bias=False, dtype=torch.bfloat16)
-        self.b1 = nn.Parameter(torch.ones(1, hidden_size, dtype=torch.bfloat16))
+        self.l1 = nn.Linear(input_size, hidden_size, bias=False)
+        self.b1 = nn.Parameter(torch.ones(1, hidden_size))
         self.relu = nn.ReLU()
-        self.b2 = nn.Parameter(torch.ones(1, output_size, dtype=torch.bfloat16))
-        self.l2 = nn.Linear(hidden_size, output_size, bias=False, dtype=torch.bfloat16)
+        self.b2 = nn.Parameter(torch.ones(1, output_size))
+        self.l2 = nn.Linear(hidden_size, output_size, bias=False)
 
     def forward(self, x):
         x = self.l1(x)
@@ -29,7 +29,7 @@ def forward(self, x):
         x = self.l2(x)
         x = x + self.b2
 
-        return nn.functional.softmax(x, dtype=torch.bfloat16)
+        return nn.functional.softmax(x)



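For reference, a self-contained sketch of the model as it reads after this commit, assuming stock PyTorch. The middle of forward is elided in the diff above and is reconstructed here under the assumption that b1 and relu are applied between the two linear layers; the explicit dim=-1 on softmax is also an illustrative addition (the committed code calls softmax without it, which triggers the implicit-dimension deprecation warning on recent PyTorch):

```python
import torch
import torch.nn as nn


class MNISTLinear(nn.Module):
    # Post-commit version: parameters and activations stay in the default
    # torch.float32 instead of being forced to torch.bfloat16.
    def __init__(self, input_size=784, output_size=10, hidden_size=256):
        super(MNISTLinear, self).__init__()
        self.l1 = nn.Linear(input_size, hidden_size, bias=False)
        self.b1 = nn.Parameter(torch.ones(1, hidden_size))
        self.relu = nn.ReLU()
        self.b2 = nn.Parameter(torch.ones(1, output_size))
        self.l2 = nn.Linear(hidden_size, output_size, bias=False)

    def forward(self, x):
        x = self.l1(x)
        x = x + self.b1   # assumed from context; this line is elided in the diff
        x = self.relu(x)  # assumed from context; this line is elided in the diff
        x = self.l2(x)
        x = x + self.b2

        # The commit returns nn.functional.softmax(x); dim=-1 is made
        # explicit here only to silence the implicit-dim deprecation warning.
        return nn.functional.softmax(x, dim=-1)


if __name__ == "__main__":
    model = MNISTLinear()
    out = model(torch.rand(1, 784))
    print(out.shape, out.dtype)  # torch.Size([1, 10]) torch.float32
```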
