# trainGAN.py
# Forked from ttchengab/MnistGAN.
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
"""
Import necessary libraries to create a generative adversarial network
The code is mainly developed using the PyTorch library
"""
import time
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import transforms
from model import discriminator, generator
import numpy as np
import matplotlib.pyplot as plt
# Device selection: use CUDA when torch reports it as available, else the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Hyperparameter settings
epochs = 150         # number of passes over the training loader
lr = 2e-4            # learning rate used by both optimizers
batch_size = 64      # batch size handed to the DataLoader
loss = nn.BCELoss()  # binary cross-entropy criterion for both networks

# Model: generator and discriminator come from the project's `model` module
# and are moved to the selected device.
G = generator().to(device)
D = discriminator().to(device)

# One Adam optimizer per network; both use the same learning rate and betas.
_adam_settings = {'lr': lr, 'betas': (0.5, 0.999)}
G_optimizer = optim.Adam(G.parameters(), **_adam_settings)
D_optimizer = optim.Adam(D.parameters(), **_adam_settings)
# Image transformation and dataloader creation.
# This script trains a generator rather than a classifier, so only the
# training split is loaded.

# Transform: convert each image to a tensor, then normalize it with
# mean 0.5 and std 0.5.
_preprocessing_steps = [
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
]
transform = transforms.Compose(_preprocessing_steps)

# Load data: MNIST training split rooted at 'mnist/' (download requested),
# served in shuffled batches of `batch_size`.
train_set = datasets.MNIST(
    root='mnist/',
    train=True,
    transform=transform,
    download=True,
)
train_loader = DataLoader(dataset=train_set, batch_size=batch_size, shuffle=True)
# Network training procedure.
# Every iteration performs one discriminator update followed by one generator
# update:
#   * the discriminator is trained to output 1 for real MNIST images and 0 for
#     generated images;
#   * the generator is trained so that the discriminator outputs 1 for its
#     images.
latent_dim = 128  # length of each noise vector fed to the generator


def _sample_noise(count):
    """Return `count` noise vectors, uniform in [-1, 1), created on `device`."""
    return (torch.rand(count, latent_dim, device=device) - 0.5) / 0.5


num_batches = len(train_loader)
for epoch in range(epochs):
    # Batches are numbered from 1 so the progress message and the
    # "last batch" check below read naturally.
    for idx, (imgs, _) in enumerate(train_loader, start=1):
        real_inputs = imgs.to(device)
        current_batch = real_inputs.shape[0]  # the last batch may be smaller

        # ---- Training the discriminator ----
        # Real inputs are actual images of the MNIST dataset (target 1);
        # fake inputs come from the generator (target 0).
        real_outputs = D(real_inputs)
        real_label = torch.ones(current_batch, 1, device=device)

        # Detach the generated images: this step only updates D, so there is
        # no reason to backpropagate through G (those gradients would be
        # thrown away by G_optimizer.zero_grad() below anyway).
        fake_inputs = G(_sample_noise(current_batch)).detach()
        fake_outputs = D(fake_inputs)
        fake_label = torch.zeros(fake_inputs.shape[0], 1, device=device)

        outputs = torch.cat((real_outputs, fake_outputs), 0)
        targets = torch.cat((real_label, fake_label), 0)
        D_loss = loss(outputs, targets)

        D_optimizer.zero_grad()
        D_loss.backward()
        D_optimizer.step()

        # ---- Training the generator ----
        # The goal is to make the discriminator believe everything is 1, so
        # freshly generated images are scored against all-ones targets. No
        # detach here: the gradient must flow through D back into G.
        fake_inputs = G(_sample_noise(current_batch))
        fake_outputs = D(fake_inputs)
        fake_targets = torch.ones(fake_inputs.shape[0], 1, device=device)
        G_loss = loss(fake_outputs, fake_targets)

        G_optimizer.zero_grad()
        G_loss.backward()
        G_optimizer.step()

        # Report every 100 batches and on the final batch of the epoch.
        if idx % 100 == 0 or idx == num_batches:
            print('Epoch {} Iteration {}: discriminator_loss {:.3f} generator_loss {:.3f}'.format(epoch, idx, D_loss.item(), G_loss.item()))

    # Checkpoint after every 10th epoch. The condition depends only on
    # `epoch`, so it runs once per epoch rather than once per batch. The file
    # name carries the zero-based epoch index (e.g. Generator_epoch_9.pth
    # after the 10th epoch).
    # NOTE(review): this saves the whole generator object, not a state_dict;
    # confirm that whatever loads these checkpoints expects that.
    if (epoch + 1) % 10 == 0:
        torch.save(G, 'Generator_epoch_{}.pth'.format(epoch))
        print('Model saved.')