diff --git a/pybuda/test/mlir/mnist/__init__.py b/pybuda/test/mlir/mnist/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pybuda/test/mlir/mnist/test_inference.py b/pybuda/test/mlir/mnist/test_inference.py new file mode 100644 index 000000000..0722e626e --- /dev/null +++ b/pybuda/test/mlir/mnist/test_inference.py @@ -0,0 +1,17 @@ +import torch +from torch import nn + +from .utils import * + + +def test_mnist_inference(): + inputs = [torch.rand(1, 784)] + + framework_model = MNISTLinear() + fw_out = framework_model(*inputs) + + compiled_model = torch.compile(framework_model.to("tt"), backend="tt") + co_out = compiled_model(*[i.to("tt") for i in inputs]) + + co_out = [co.to("cpu") for co in co_out] + assert [torch.allclose(fo, co) for fo, co in zip(fw_out, co_out)] diff --git a/pybuda/test/mlir/mnist/test_training.py b/pybuda/test/mlir/mnist/test_training.py new file mode 100644 index 000000000..0a03792e4 --- /dev/null +++ b/pybuda/test/mlir/mnist/test_training.py @@ -0,0 +1,65 @@ +import torch +from torch import nn + +import pybuda +from .utils import * + +def test_mnist_training(): + torch.manual_seed(0) + + # Config + num_epochs = 9 + batch_size = 64 + learning_rate = 0.005 + + # Load dataset + test_loader, train_loader = load_dataset(batch_size) + + # Load TensorBoard writer (for logging) + writer = load_tb_writer() + + # Define model and instruct it to compile and run on TT device + framework_model = MNISTLinear() + tt_model = pybuda.compile(framework_model) + tt_model.to("tt") + + # Create a torch loss and leave on CPU + loss = torch.nn.L1Loss() + + # Define optimizer and instruct it to compile and run on TT device + framework_optimizer = torch.optim.SGD(framework_model.parameters(), lr=learning_rate) + tt_optimizer = pybuda.compile(framework_optimizer) + tt_optimizer.to("tt") + + for epoch_idx in range(num_epochs): + for batch_idx, (data, target) in enumerate(train_loader): + # Put inputs on device + data = data.to("tt") + + # Create target tensor and leave on CPU + target = nn.functional.one_hot(target, num_classes=10).float() + + # Reset gradients (every batch) + tt_optimizer.zero_grad() + + # Forward pass (prediction) on device + pred = tt_model(data) + + # Pull output back to CPU + pred = pred.to("cpu") + + # Compute loss on CPU + loss = tt_loss(pred, target) + + # RUn backward pass on device + loss.backward() + + # Adjust weights (on device) + tt_optimizer.step() + + # Log gradients + for name, param in tt_model.named_parameters(): + writer.add_histogram(f"{name}.grad", param.grad, batch_idx) + + # Log loss + writer.add_scalar("Loss", loss.item(), batch_idx) diff --git a/pybuda/test/mlir/mnist/utils.py b/pybuda/test/mlir/mnist/utils.py new file mode 100644 index 000000000..68c679ef5 --- /dev/null +++ b/pybuda/test/mlir/mnist/utils.py @@ -0,0 +1,60 @@ +from datetime import datetime + +import torch +from torch import nn +from torch.utils.data import DataLoader +import torchvision.transforms as transforms +from torch.utils.tensorboard import SummaryWriter +from torchvision.datasets import MNIST as mnist_dataset + + +# Model definition +class MNISTLinear(nn.Module): + def __init__(self, input_size=784, output_size=10, hidden_size=256): + super(MNISTLinear, self).__init__() + self.l1 = nn.Linear(input_size, hidden_size) + self.relu = nn.ReLU() + self.l2 = nn.Linear(hidden_size, output_size) + + def forward(self, x): + x = self.l1(x) + x = self.relu(x) + x = self.l2(x) + + return nn.functional.softmax(x) + + +def load_tb_writer(): + """ + Load TensorBoard writer for logging + """ + current_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + log_dir = f"runs/gradient_visualization/{current_time}/" + writer = SummaryWriter(log_dir) + + return writer + + +def load_dataset(batch_size): + """ + Load and normalize MNIST dataset + """ + transform = transforms.Compose( + [ + transforms.ToTensor(), + transforms.Normalize((0.1307,), (0.3081,)), # Mean and std for MNIST + transforms.Lambda(lambda x: x.view(-1)), # Flatten image + ] + ) + + train_dataset = mnist_dataset( + root="./data", train=True, download=True, transform=transform + ) + test_dataset = mnist_dataset( + root="./data", train=False, download=True, transform=transform + ) + + train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) + test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False) + + return test_loader, train_loader