Skip to content

Commit

Permalink
cifar10/general_child.py shape fix for tf1.5 and higher resolves melo…
Browse files Browse the repository at this point in the history
  • Loading branch information
funasoul committed Nov 4, 2022
1 parent c9ada2b commit 840fabc
Showing 1 changed file with 8 additions and 7 deletions.
15 changes: 8 additions & 7 deletions src/cifar10/general_child.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def _factorized_reduction(self, x, out_filters, stride, is_training):
w = create_weight("w", [1, 1, inp_c, out_filters // 2])
path1 = tf.nn.conv2d(path1, w, [1, 1, 1, 1], "SAME",
data_format=self.data_format)

# Skip path 2
# First pad with 0"s on the right and bottom, then shift the filter to
# include those 0"s that were added.
Expand All @@ -160,15 +160,15 @@ def _factorized_reduction(self, x, out_filters, stride, is_training):
pad_arr = [[0, 0], [0, 0], [0, 1], [0, 1]]
path2 = tf.pad(x, pad_arr)[:, :, 1:, 1:]
concat_axis = 1

path2 = tf.nn.avg_pool(
path2, [1, 1, 1, 1], stride_spec, "VALID", data_format=self.data_format)
with tf.variable_scope("path2_conv"):
inp_c = self._get_C(path2)
w = create_weight("w", [1, 1, inp_c, out_filters // 2])
path2 = tf.nn.conv2d(path2, w, [1, 1, 1, 1], "SAME",
data_format=self.data_format)

# Concat and apply BN
final_path = tf.concat(values=[path1, path2], axis=concat_axis)
final_path = batch_norm(final_path, is_training,
Expand Down Expand Up @@ -291,13 +291,14 @@ def _enas_layer(self, layer_id, prev_layers, start_idx, out_filters, is_training
y = self._pool_branch(inputs, is_training, out_filters, "max",
start_idx=0)
branches[tf.equal(count, 5)] = lambda: y
out = tf.case(branches, default=lambda: tf.constant(0, tf.float32),
exclusive=True)

if self.data_format == "NHWC":
out.set_shape([None, inp_h, inp_w, out_filters])
out_shape = [self.batch_size, inp_h, inp_w, out_filters]
elif self.data_format == "NCHW":
out.set_shape([None, out_filters, inp_h, inp_w])
out_shape = [self.batch_size, out_filters, inp_h, inp_w]

out = tf.case(branches, default=lambda: tf.constant(0, tf.float32, shape=out_shape),
exclusive=True)
else:
count = self.sample_arc[start_idx:start_idx + 2 * self.num_branches]
branches = []
Expand Down

0 comments on commit 840fabc

Please sign in to comment.