From 840fabc660fddfc2896512de3e12759a6f9baf78 Mon Sep 17 00:00:00 2001 From: Akira Funahashi Date: Sat, 5 Nov 2022 04:34:32 +0900 Subject: [PATCH] cifar10/general_child.py shape fix for tf1.5 and higher resolves #4 #29 taken from https://github.com/melodyguan/enas/pull/29/commits/ad9ec6c85a0c7b7e97bea839528660cbde2554b2 --- src/cifar10/general_child.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/src/cifar10/general_child.py b/src/cifar10/general_child.py index c7d998f..623e3a2 100644 --- a/src/cifar10/general_child.py +++ b/src/cifar10/general_child.py @@ -148,7 +148,7 @@ def _factorized_reduction(self, x, out_filters, stride, is_training): w = create_weight("w", [1, 1, inp_c, out_filters // 2]) path1 = tf.nn.conv2d(path1, w, [1, 1, 1, 1], "SAME", data_format=self.data_format) - + # Skip path 2 # First pad with 0"s on the right and bottom, then shift the filter to # include those 0"s that were added. @@ -160,7 +160,7 @@ def _factorized_reduction(self, x, out_filters, stride, is_training): pad_arr = [[0, 0], [0, 0], [0, 1], [0, 1]] path2 = tf.pad(x, pad_arr)[:, :, 1:, 1:] concat_axis = 1 - + path2 = tf.nn.avg_pool( path2, [1, 1, 1, 1], stride_spec, "VALID", data_format=self.data_format) with tf.variable_scope("path2_conv"): @@ -168,7 +168,7 @@ def _factorized_reduction(self, x, out_filters, stride, is_training): w = create_weight("w", [1, 1, inp_c, out_filters // 2]) path2 = tf.nn.conv2d(path2, w, [1, 1, 1, 1], "SAME", data_format=self.data_format) - + # Concat and apply BN final_path = tf.concat(values=[path1, path2], axis=concat_axis) final_path = batch_norm(final_path, is_training, @@ -291,13 +291,14 @@ def _enas_layer(self, layer_id, prev_layers, start_idx, out_filters, is_training y = self._pool_branch(inputs, is_training, out_filters, "max", start_idx=0) branches[tf.equal(count, 5)] = lambda: y - out = tf.case(branches, default=lambda: tf.constant(0, tf.float32), - exclusive=True) if self.data_format == "NHWC": - out.set_shape([None, inp_h, inp_w, out_filters]) + out_shape = [self.batch_size, inp_h, inp_w, out_filters] elif self.data_format == "NCHW": - out.set_shape([None, out_filters, inp_h, inp_w]) + out_shape = [self.batch_size, out_filters, inp_h, inp_w] + + out = tf.case(branches, default=lambda: tf.constant(0, tf.float32, shape=out_shape), + exclusive=True) else: count = self.sample_arc[start_idx:start_idx + 2 * self.num_branches] branches = []