#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: GAN.py
# Author: Yuxin Wu <[email protected]>
import tensorflow as tf
import numpy as np
import time
from tensorpack import (FeedfreeTrainerBase, QueueInput, ModelDesc, DataFlow)
from tensorpack.tfutils.summary import add_moving_summary


class GANModelDesc(ModelDesc):
    def collect_variables(self, g_scope='gen', d_scope='discrim'):
        """
        Assign self.g_vars to the trainable parameters under scope `g_scope`,
        and similarly self.d_vars to those under `d_scope`.
        """
        self.g_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, g_scope)
        self.d_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, d_scope)

    def build_losses(self, logits_real, logits_fake):
        """
        D and G play a two-player minimax game with value function V(G, D):

            min_G max_D V(D, G) = E_{x ~ p_data}[log D(x)] + E_{z ~ p_z}[log(1 - D(G(z)))]

        Note that the generator loss built below uses the common non-saturating
        form (maximize log D(G(z))) rather than minimizing log(1 - D(G(z))).

        Args:
            logits_real (tf.Tensor): discriminator logits from real samples
            logits_fake (tf.Tensor): discriminator logits from fake samples produced by the generator
        """
        with tf.name_scope("GAN_loss"):
            score_real = tf.sigmoid(logits_real)
            score_fake = tf.sigmoid(logits_fake)
            tf.summary.histogram('score-real', score_real)
            tf.summary.histogram('score-fake', score_fake)

            with tf.name_scope("discrim"):
                # cross-entropy of D against the true labels: real -> 1, fake -> 0
                d_loss_pos = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
                    logits=logits_real, labels=tf.ones_like(logits_real)), name='loss_real')
                d_loss_neg = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
                    logits=logits_fake, labels=tf.zeros_like(logits_fake)), name='loss_fake')

                d_pos_acc = tf.reduce_mean(tf.cast(score_real > 0.5, tf.float32), name='accuracy_real')
                d_neg_acc = tf.reduce_mean(tf.cast(score_fake < 0.5, tf.float32), name='accuracy_fake')

                d_accuracy = tf.add(.5 * d_pos_acc, .5 * d_neg_acc, name='accuracy')
                self.d_loss = tf.add(.5 * d_loss_pos, .5 * d_loss_neg, name='loss')

            with tf.name_scope("gen"):
                # non-saturating generator loss: -log D(G(z))
                self.g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
                    logits=logits_fake, labels=tf.ones_like(logits_fake)), name='loss')
                g_accuracy = tf.reduce_mean(tf.cast(score_fake > 0.5, tf.float32), name='accuracy')

            add_moving_summary(self.g_loss, self.d_loss, d_accuracy, g_accuracy)
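

# Illustrative sketch (not part of the original file): a concrete GANModelDesc
# subclass would typically build both networks under the 'gen' and 'discrim'
# variable scopes, then call collect_variables() and build_losses(). The helper
# names `generator`/`discriminator` and the graph-building hook `_build_graph`
# below are assumptions about the surrounding tensorpack version, kept as a
# comment so this module's behavior is unchanged:
#
#   class MyGAN(GANModelDesc):
#       def _build_graph(self, inputs):
#           image, z = inputs
#           with tf.variable_scope('gen'):
#               fake_image = self.generator(z)                # user-defined
#           with tf.variable_scope('discrim'):
#               real_logits = self.discriminator(image)       # user-defined
#           with tf.variable_scope('discrim', reuse=True):
#               fake_logits = self.discriminator(fake_image)
#           self.collect_variables()                          # fills g_vars / d_vars
#           self.build_losses(real_logits, fake_logits)       # fills g_loss / d_loss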


class GANTrainer(FeedfreeTrainerBase):
    def __init__(self, config):
        self._input_method = QueueInput(config.dataflow)
        super(GANTrainer, self).__init__(config)

    def _setup(self):
        super(GANTrainer, self)._setup()
        self.build_train_tower()
        opt = self.model.get_optimizer()
        # by default, run one d_min after one g_min
        self.g_min = opt.minimize(self.model.g_loss, var_list=self.model.g_vars, name='g_op')
        with tf.control_dependencies([self.g_min]):
            self.d_min = opt.minimize(self.model.d_loss, var_list=self.model.d_vars, name='d_op')
        self.train_op = self.d_min
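

# Illustrative usage (assumptions about the caller, not defined in this file):
# like other tensorpack feedfree trainers, GANTrainer is driven by wrapping a
# model and a dataflow in a TrainConfig and calling train(). `MyGAN` and
# `my_dataflow` are placeholders; further TrainConfig arguments (callbacks,
# epoch settings) are omitted here:
#
#   from tensorpack import TrainConfig
#   config = TrainConfig(dataflow=my_dataflow, model=MyGAN())
#   GANTrainer(config).train()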


class SeparateGANTrainer(FeedfreeTrainerBase):
    """ A GAN trainer which runs the two optimization ops with a configurable ratio
    (e.g. several discriminator updates per generator update). """
    def __init__(self, config, d_period=1, g_period=1):
        """
        Args:
            d_period(int): run the d_opt once every `d_period` steps
            g_period(int): run the g_opt once every `g_period` steps
        """
        self._input_method = QueueInput(config.dataflow)
        self._d_period = int(d_period)
        self._g_period = int(g_period)
        # at least one of the two ops must run on every step
        assert min(d_period, g_period) == 1
        super(SeparateGANTrainer, self).__init__(config)

    def _setup(self):
        super(SeparateGANTrainer, self)._setup()
        self.build_train_tower()
        opt = self.model.get_optimizer()
        self.d_min = opt.minimize(
            self.model.d_loss, var_list=self.model.d_vars, name='d_min')
        self.g_min = opt.minimize(
            self.model.g_loss, var_list=self.model.g_vars, name='g_min')
        self._cnt = 1

    def run_step(self):
        if self._cnt % self._d_period == 0:
            self.hooked_sess.run(self.d_min)
        if self._cnt % self._g_period == 0:
            self.hooked_sess.run(self.g_min)
        self._cnt += 1
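

# Illustrative usage (the `config` object and train() call are placeholders, as
# above): with d_period=1 and g_period=3, run_step() runs the discriminator op
# on every step and the generator op only on every third step:
#
#   SeparateGANTrainer(config, d_period=1, g_period=3).train()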


class RandomZData(DataFlow):
    def __init__(self, shape):
        super(RandomZData, self).__init__()
        self.shape = shape

    def get_data(self):
        # infinite stream of uniform noise in [-1, 1) with the given shape
        while True:
            yield [np.random.uniform(-1, 1, size=self.shape)]
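

# Illustrative usage (the shape values are placeholders): RandomZData provides
# an endless source of latent vectors, e.g. to sample from a trained generator:
#
#   df = RandomZData((64, 100))        # batches of 64 z-vectors of dimension 100
#   z_batch = next(df.get_data())[0]   # np.ndarray with shape (64, 100)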