From 2811ad1e397eb6287d0ac2b785987671df72eefb Mon Sep 17 00:00:00 2001 From: Yang Song Date: Sun, 18 Mar 2018 00:40:52 -0700 Subject: [PATCH] Fix ACGAN training The ACGAN training procedure is not optimal --- ACGAN.py | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/ACGAN.py b/ACGAN.py index 208cea54..805e0d84 100644 --- a/ACGAN.py +++ b/ACGAN.py @@ -143,23 +143,22 @@ def build_model(self): # get information loss self.q_loss = q_fake_loss + q_real_loss + + self.d_loss += q_fake_loss + q_real_loss + self.g_loss += q_fake_loss """ Training """ # divide trainable variables into a group for D and a group for G - t_vars = tf.trainable_variables() - d_vars = [var for var in t_vars if 'd_' in var.name] - g_vars = [var for var in t_vars if 'g_' in var.name] - q_vars = [var for var in t_vars if ('d_' in var.name) or ('c_' in var.name) or ('g_' in var.name)] - + d_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='discriminator') + g_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator') + # optimizers with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)): self.d_optim = tf.train.AdamOptimizer(self.learning_rate, beta1=self.beta1) \ .minimize(self.d_loss, var_list=d_vars) self.g_optim = tf.train.AdamOptimizer(self.learning_rate * 5, beta1=self.beta1) \ .minimize(self.g_loss, var_list=g_vars) - self.q_optim = tf.train.AdamOptimizer(self.learning_rate * 5, beta1=self.beta1) \ - .minimize(self.q_loss, var_list=q_vars) - + """" Testing """ # for test self.fake_images = self.generator(self.z, self.y, is_training=False, reuse=True) @@ -225,8 +224,8 @@ def train(self): self.writer.add_summary(summary_str, counter) # update G & Q network - _, summary_str_g, g_loss, _, summary_str_q, q_loss = self.sess.run( - [self.g_optim, self.g_sum, self.g_loss, self.q_optim, self.q_sum, self.q_loss], + _, summary_str_g, g_loss, summary_str_q, q_loss = self.sess.run( + [self.g_optim, self.g_sum, self.g_loss, self.q_sum, self.q_loss], feed_dict={self.z: batch_z, self.y: batch_codes, self.inputs: batch_images}) self.writer.add_summary(summary_str_g, counter) self.writer.add_summary(summary_str_q, counter) @@ -334,4 +333,4 @@ def load(self, checkpoint_dir): return True, counter else: print(" [*] Failed to find a checkpoint") - return False, 0 \ No newline at end of file + return False, 0