-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added pytorch and pybuda mnist training scripts.
- Loading branch information
1 parent
dc92b2b
commit 1fb823b
Showing
3 changed files
with
211 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,127 @@ | ||
import torch | ||
from torchvision import datasets, transforms | ||
from torch.utils.tensorboard import SummaryWriter | ||
|
||
import pybuda | ||
from pybuda import ( | ||
CPUDevice, | ||
PyTorchModule, | ||
) | ||
from utils import ( | ||
MNISTLinear, | ||
Identity, | ||
load_tb_writer, | ||
load_dataset, | ||
) | ||
from pybuda.config import _get_global_compiler_config | ||
|
||
class FeedForward(torch.nn.Module): | ||
def __init__(self, input_size, hidden_size, output_size): | ||
super(FeedForward, self).__init__() | ||
self.fc1 = torch.nn.Linear(input_size, hidden_size) | ||
self.relu = torch.nn.ReLU() | ||
self.fc2 = torch.nn.Linear(hidden_size, output_size) | ||
|
||
def forward(self, x): | ||
x = self.fc1(x) | ||
x = self.relu(x) | ||
x = self.fc2(x) | ||
return x | ||
|
||
def train(loss_on_cpu=True): | ||
torch.manual_seed(777) | ||
transform=transforms.Compose([ | ||
transforms.ToTensor(), | ||
transforms.Normalize((0.1307,), (0.3081,)), | ||
transforms.Lambda(lambda x: x.view(-1)) | ||
]) | ||
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True) | ||
test_dataset = datasets.MNIST(root='./data', train=False, transform=transform, download=True) | ||
|
||
writer = SummaryWriter() | ||
|
||
num_epochs = 2 | ||
input_size = 784 | ||
hidden_size = 256 | ||
output_size = 10 | ||
batch_size = 3 | ||
learning_rate = 0.001 | ||
sequential = True | ||
|
||
framework_model = FeedForward(input_size, hidden_size, output_size) | ||
tt_model = pybuda.PyTorchModule(f"mnist_linear_{batch_size}", framework_model) | ||
tt_optimizer = pybuda.optimizers.SGD( | ||
learning_rate=learning_rate, device_params=True | ||
) | ||
tt0 = pybuda.TTDevice("tt0", module=tt_model, optimizer=tt_optimizer) | ||
|
||
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True) | ||
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False) | ||
# Dataset sample input | ||
first_sample = test_loader.dataset[0] | ||
sample_input = (first_sample[0].repeat(1, batch_size, 1),) | ||
sample_target = ( | ||
torch.nn.functional.one_hot(torch.tensor(first_sample[1]), num_classes=output_size) | ||
.float() | ||
.repeat(1, batch_size, 1) | ||
) | ||
|
||
if loss_on_cpu: | ||
cpu0 = CPUDevice("cpu0", module=PyTorchModule("identity", Identity())) | ||
cpu0.place_loss_module(pybuda.PyTorchModule(f"loss_{batch_size}", torch.nn.CrossEntropyLoss())) | ||
else: | ||
tt_loss = pybuda.PyTorchModule(f"loss_{batch_size}", torch.nn.CrossEntropyLoss()) | ||
tt0.place_loss_module(tt_loss) | ||
|
||
compiler_cfg = _get_global_compiler_config() | ||
compiler_cfg.enable_auto_fusing = False | ||
|
||
if not loss_on_cpu: | ||
sample_target = (sample_target,) | ||
|
||
checkpoint_queue = pybuda.initialize_pipeline( | ||
training=True, | ||
sample_inputs=sample_input, | ||
sample_targets=sample_target, | ||
_sequential=sequential, | ||
) | ||
|
||
best_accuracy = 0.0 | ||
best_checkpoint = None | ||
|
||
for epoch in range(num_epochs): | ||
for batch_idx, (images, labels) in enumerate(train_loader): | ||
|
||
images = (images.unsqueeze(0),) | ||
tt0.push_to_inputs(images) | ||
|
||
targets = ( | ||
torch.nn.functional.one_hot(labels, num_classes=output_size) | ||
.float() | ||
.unsqueeze(0) | ||
) | ||
if loss_on_cpu: | ||
cpu0.push_to_target_inputs(targets) | ||
else: | ||
tt0.push_to_target_inputs(targets) | ||
|
||
pybuda.run_forward(input_count=1, _sequential=sequential) | ||
pybuda.run_backward(input_count=1, zero_grad=True, _sequential=sequential) | ||
pybuda.run_optimizer(checkpoint=True, _sequential=sequential) | ||
|
||
loss_q = pybuda.run.get_loss_queue() | ||
|
||
step = 0 | ||
loss = loss_q.get()[0] | ||
print(loss) | ||
# while not loss_q.empty(): | ||
# if loss_on_cpu: | ||
# writer.add_scalar("Loss/PyBuda/overfit", loss_q.get()[0], step) | ||
# else: | ||
# writer.add_scalar("Loss/PyBuda/overfit", loss_q.get()[0].value()[0], step) | ||
# step += 1 | ||
|
||
writer.close() | ||
|
||
if __name__ == "__main__": | ||
train() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
import torch | ||
from torchvision import datasets, transforms | ||
from torch.utils.tensorboard import SummaryWriter | ||
|
||
class FeedForward(torch.nn.Module): | ||
def __init__(self, input_size, hidden_size, output_size): | ||
super(FeedForward, self).__init__() | ||
self.fc1 = torch.nn.Linear(input_size, hidden_size) | ||
self.relu = torch.nn.ReLU() | ||
self.fc2 = torch.nn.Linear(hidden_size, output_size) | ||
|
||
def forward(self, x): | ||
x = self.fc1(x) | ||
x = self.relu(x) | ||
x = self.fc2(x) | ||
return x | ||
|
||
def train(): | ||
torch.manual_seed(777) | ||
transform=transforms.Compose([ | ||
transforms.ToTensor(), | ||
transforms.Normalize((0.1307,), (0.3081,)), | ||
transforms.Lambda(lambda x: x.view(-1)) | ||
]) | ||
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True) | ||
test_dataset = datasets.MNIST(root='./data', train=False, transform=transform, download=True) | ||
|
||
writer = SummaryWriter() | ||
|
||
num_epochs = 10 | ||
input_size = 784 | ||
hidden_size = 256 | ||
output_size = 10 | ||
model = FeedForward(input_size, hidden_size, output_size) | ||
|
||
optimizer = torch.optim.SGD(model.parameters(), lr=0.01) | ||
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True) | ||
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=16, shuffle=False) | ||
|
||
best_accuracy = 0.0 | ||
best_checkpoint = None | ||
|
||
for epoch in range(num_epochs): | ||
for batch_idx, (images, labels) in enumerate(train_loader): | ||
outputs = model(images) | ||
loss = torch.nn.CrossEntropyLoss()(outputs, labels) | ||
optimizer.zero_grad() | ||
loss.backward() | ||
optimizer.step() | ||
|
||
if (batch_idx+1) % 100 == 0: | ||
print(f'Epoch [{epoch+1}/{num_epochs}], Step [{batch_idx+1}/{len(train_loader)}], Loss: {loss.item():.4f}') | ||
writer.add_scalar('Loss/train', loss.item(), epoch * len(train_loader) + batch_idx) | ||
|
||
total_correct = 0 | ||
total_samples = 0 | ||
with torch.no_grad(): | ||
for images, labels in test_loader: | ||
outputs = model(images) | ||
_, predicted = torch.max(outputs, dim=1) | ||
total_samples += labels.size(0) | ||
total_correct += (predicted == labels).sum().item() | ||
|
||
accuracy = 100.0 * total_correct / total_samples | ||
print(f'Epoch [{epoch+1}/{num_epochs}], Validation Accuracy: {accuracy:.2f}%') | ||
|
||
if accuracy > best_accuracy: | ||
best_accuracy = accuracy | ||
best_checkpoint = { | ||
'epoch': epoch, | ||
'model_state_dict': model.state_dict(), | ||
'optimizer_state_dict': optimizer.state_dict(), | ||
'accuracy': accuracy | ||
} | ||
|
||
if best_checkpoint is not None: | ||
model.load_state_dict(best_checkpoint['model_state_dict']) | ||
optimizer.load_state_dict(best_checkpoint['optimizer_state_dict']) | ||
print(f'Reverted to checkpoint with highest validation accuracy: {best_checkpoint["accuracy"]:.2f}%') | ||
|
||
writer.close() | ||
|
||
if __name__ == "__main__": | ||
train() |
File renamed without changes.