-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'nvukobrat/mnist_placeholder' into 'main'
[Training] MNIST Linear: Inference sample using torch.compile and training... See merge request nvukobrat/pybuda-mlir-integration!8
- Loading branch information
Showing
4 changed files
with
142 additions
and
0 deletions.
There are no files selected for viewing
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
import torch | ||
from torch import nn | ||
|
||
from .utils import * | ||
|
||
|
||
def test_mnist_inference(): | ||
inputs = [torch.rand(1, 784)] | ||
|
||
framework_model = MNISTLinear() | ||
fw_out = framework_model(*inputs) | ||
|
||
compiled_model = torch.compile(framework_model.to("tt"), backend="tt") | ||
co_out = compiled_model(*[i.to("tt") for i in inputs]) | ||
|
||
co_out = [co.to("cpu") for co in co_out] | ||
assert [torch.allclose(fo, co) for fo, co in zip(fw_out, co_out)] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
import torch | ||
from torch import nn | ||
|
||
import pybuda | ||
from .utils import * | ||
|
||
def test_mnist_training(): | ||
torch.manual_seed(0) | ||
|
||
# Config | ||
num_epochs = 9 | ||
batch_size = 64 | ||
learning_rate = 0.005 | ||
|
||
# Load dataset | ||
test_loader, train_loader = load_dataset(batch_size) | ||
|
||
# Load TensorBoard writer (for logging) | ||
writer = load_tb_writer() | ||
|
||
# Define model and instruct it to compile and run on TT device | ||
framework_model = MNISTLinear() | ||
tt_model = pybuda.compile(framework_model) | ||
tt_model.to("tt") | ||
|
||
# Create a torch loss and leave on CPU | ||
loss = torch.nn.L1Loss() | ||
|
||
# Define optimizer and instruct it to compile and run on TT device | ||
framework_optimizer = torch.optim.SGD(framework_model.parameters(), lr=learning_rate) | ||
tt_optimizer = pybuda.compile(framework_optimizer) | ||
tt_optimizer.to("tt") | ||
|
||
for epoch_idx in range(num_epochs): | ||
for batch_idx, (data, target) in enumerate(train_loader): | ||
# Put inputs on device | ||
data = data.to("tt") | ||
|
||
# Create target tensor and leave on CPU | ||
target = nn.functional.one_hot(target, num_classes=10).float() | ||
|
||
# Reset gradients (every batch) | ||
tt_optimizer.zero_grad() | ||
|
||
# Forward pass (prediction) on device | ||
pred = tt_model(data) | ||
|
||
# Pull output back to CPU | ||
pred = pred.to("cpu") | ||
|
||
# Compute loss on CPU | ||
loss = tt_loss(pred, target) | ||
|
||
# RUn backward pass on device | ||
loss.backward() | ||
|
||
# Adjust weights (on device) | ||
tt_optimizer.step() | ||
|
||
# Log gradients | ||
for name, param in tt_model.named_parameters(): | ||
writer.add_histogram(f"{name}.grad", param.grad, batch_idx) | ||
|
||
# Log loss | ||
writer.add_scalar("Loss", loss.item(), batch_idx) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
from datetime import datetime | ||
|
||
import torch | ||
from torch import nn | ||
from torch.utils.data import DataLoader | ||
import torchvision.transforms as transforms | ||
from torch.utils.tensorboard import SummaryWriter | ||
from torchvision.datasets import MNIST as mnist_dataset | ||
|
||
|
||
# Model definition | ||
class MNISTLinear(nn.Module): | ||
def __init__(self, input_size=784, output_size=10, hidden_size=256): | ||
super(MNISTLinear, self).__init__() | ||
self.l1 = nn.Linear(input_size, hidden_size) | ||
self.relu = nn.ReLU() | ||
self.l2 = nn.Linear(hidden_size, output_size) | ||
|
||
def forward(self, x): | ||
x = self.l1(x) | ||
x = self.relu(x) | ||
x = self.l2(x) | ||
|
||
return nn.functional.softmax(x) | ||
|
||
|
||
def load_tb_writer(): | ||
""" | ||
Load TensorBoard writer for logging | ||
""" | ||
current_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") | ||
log_dir = f"runs/gradient_visualization/{current_time}/" | ||
writer = SummaryWriter(log_dir) | ||
|
||
return writer | ||
|
||
|
||
def load_dataset(batch_size): | ||
""" | ||
Load and normalize MNIST dataset | ||
""" | ||
transform = transforms.Compose( | ||
[ | ||
transforms.ToTensor(), | ||
transforms.Normalize((0.1307,), (0.3081,)), # Mean and std for MNIST | ||
transforms.Lambda(lambda x: x.view(-1)), # Flatten image | ||
] | ||
) | ||
|
||
train_dataset = mnist_dataset( | ||
root="./data", train=True, download=True, transform=transform | ||
) | ||
test_dataset = mnist_dataset( | ||
root="./data", train=False, download=True, transform=transform | ||
) | ||
|
||
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) | ||
test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False) | ||
|
||
return test_loader, train_loader |