Skip to content

Commit

Permalink
Merge branch 'nvukobrat/mnist_placeholder' into 'main'
Browse files Browse the repository at this point in the history
[Training] MNIST Linear: Inference sample using torch.compile and training...

See merge request nvukobrat/pybuda-mlir-integration!8
  • Loading branch information
nvukobratTT committed Jul 18, 2024
2 parents bd15c9b + 20fe913 commit 179bbda
Show file tree
Hide file tree
Showing 4 changed files with 142 additions and 0 deletions.
Empty file.
17 changes: 17 additions & 0 deletions pybuda/test/mlir/mnist/test_inference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import torch
from torch import nn

from .utils import *


def test_mnist_inference():
inputs = [torch.rand(1, 784)]

framework_model = MNISTLinear()
fw_out = framework_model(*inputs)

compiled_model = torch.compile(framework_model.to("tt"), backend="tt")
co_out = compiled_model(*[i.to("tt") for i in inputs])

co_out = [co.to("cpu") for co in co_out]
assert [torch.allclose(fo, co) for fo, co in zip(fw_out, co_out)]
65 changes: 65 additions & 0 deletions pybuda/test/mlir/mnist/test_training.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import torch
from torch import nn

import pybuda
from .utils import *

def test_mnist_training():
torch.manual_seed(0)

# Config
num_epochs = 9
batch_size = 64
learning_rate = 0.005

# Load dataset
test_loader, train_loader = load_dataset(batch_size)

# Load TensorBoard writer (for logging)
writer = load_tb_writer()

# Define model and instruct it to compile and run on TT device
framework_model = MNISTLinear()
tt_model = pybuda.compile(framework_model)
tt_model.to("tt")

# Create a torch loss and leave on CPU
loss = torch.nn.L1Loss()

# Define optimizer and instruct it to compile and run on TT device
framework_optimizer = torch.optim.SGD(framework_model.parameters(), lr=learning_rate)
tt_optimizer = pybuda.compile(framework_optimizer)
tt_optimizer.to("tt")

for epoch_idx in range(num_epochs):
for batch_idx, (data, target) in enumerate(train_loader):
# Put inputs on device
data = data.to("tt")

# Create target tensor and leave on CPU
target = nn.functional.one_hot(target, num_classes=10).float()

# Reset gradients (every batch)
tt_optimizer.zero_grad()

# Forward pass (prediction) on device
pred = tt_model(data)

# Pull output back to CPU
pred = pred.to("cpu")

# Compute loss on CPU
loss = tt_loss(pred, target)

# RUn backward pass on device
loss.backward()

# Adjust weights (on device)
tt_optimizer.step()

# Log gradients
for name, param in tt_model.named_parameters():
writer.add_histogram(f"{name}.grad", param.grad, batch_idx)

# Log loss
writer.add_scalar("Loss", loss.item(), batch_idx)
60 changes: 60 additions & 0 deletions pybuda/test/mlir/mnist/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from datetime import datetime

import torch
from torch import nn
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter
from torchvision.datasets import MNIST as mnist_dataset


# Model definition
class MNISTLinear(nn.Module):
def __init__(self, input_size=784, output_size=10, hidden_size=256):
super(MNISTLinear, self).__init__()
self.l1 = nn.Linear(input_size, hidden_size)
self.relu = nn.ReLU()
self.l2 = nn.Linear(hidden_size, output_size)

def forward(self, x):
x = self.l1(x)
x = self.relu(x)
x = self.l2(x)

return nn.functional.softmax(x)


def load_tb_writer():
"""
Load TensorBoard writer for logging
"""
current_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
log_dir = f"runs/gradient_visualization/{current_time}/"
writer = SummaryWriter(log_dir)

return writer


def load_dataset(batch_size):
"""
Load and normalize MNIST dataset
"""
transform = transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,)), # Mean and std for MNIST
transforms.Lambda(lambda x: x.view(-1)), # Flatten image
]
)

train_dataset = mnist_dataset(
root="./data", train=True, download=True, transform=transform
)
test_dataset = mnist_dataset(
root="./data", train=False, download=True, transform=transform
)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False)

return test_loader, train_loader

0 comments on commit 179bbda

Please sign in to comment.