Hello,
I'm using the attention-GRU-piecewise-linear RUL Prediction.ipynb notebook. There is an error in the AdditiveAttentionForSeq class, specifically in the call function at the line concat = tf.concat((state_rep, encoder_outputs), axis=-1): it tries to concatenate a 4D tensor with a 3D tensor, which fails because the ranks do not match.
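For context, here is a minimal sketch of how I understand the mismatch to arise, assuming state is the list of per-layer hidden states returned by a stacked GRU; the shapes and the way state_rep is built are my assumptions for illustration, not copied from the notebook:

import tensorflow as tf

# Assumed shapes, for illustration only: 2 GRU layers, batch of 4,
# sequence length 30, hidden size 64.
state = [tf.zeros((4, 64)), tf.zeros((4, 64))]   # list of per-layer states
encoder_outputs = tf.zeros((4, 30, 64))
seq_len = encoder_outputs.shape[1]

# The list of states is implicitly stacked into a (num_layers, batch, hidden)
# tensor, so after expand_dims + repeat the result is 4D.
state_rep = tf.repeat(tf.expand_dims(state, axis=1), repeats=seq_len, axis=1)
print(state_rep.shape)        # (2, 30, 4, 64)
print(encoder_outputs.shape)  # (4, 30, 64)

# tf.concat((state_rep, encoder_outputs), axis=-1) then fails,
# because a 4D tensor cannot be concatenated with a 3D one.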
I modified the call function as follows:
def call(self, state, encoder_outputs):
    seq_len = encoder_outputs.shape[1]
    averaged_state = tf.reduce_mean(tf.stack(state, axis=1), axis=1)
    state_rep = tf.repeat(tf.expand_dims(averaged_state, axis=1), repeats=seq_len, axis=1)
    shape = tf.shape(state_rep)
    seq_len = shape[1]
    batch_size = shape[2]
    hidden_dims2 = shape[3]
    state_rep = tf.reshape(state_rep, (batch_size, seq_len, hidden_dims2))
    concat = tf.concat((state_rep, encoder_outputs), axis=-1)
    scores = tf.nn.tanh(self.attention(concat))
    attention_weights = tf.nn.softmax(tf.reduce_sum(scores, axis=-1), axis=-1)
    return tf.matmul(tf.expand_dims(attention_weights, axis=1), encoder_outputs)
However, I'm not entirely sure whether this is the correct approach. Could you please verify this fix or suggest the correct way to handle the issue?
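For comparison, here is the alternative I also considered: use only the top-layer GRU state as the attention query instead of averaging across layers, which keeps state_rep 3D and avoids the reshape. This is just a sketch; it assumes state is a list of per-layer (batch, hidden) states and that self.attention is a Dense scoring layer, which may not match the notebook exactly.

def call(self, state, encoder_outputs):
    seq_len = encoder_outputs.shape[1]
    # query: hidden state of the top GRU layer, shape (batch, hidden)
    query = state[-1]
    # repeat the query along the time axis -> (batch, seq_len, hidden)
    state_rep = tf.repeat(tf.expand_dims(query, axis=1), repeats=seq_len, axis=1)
    # additive attention scores over the concatenated query / encoder features
    concat = tf.concat((state_rep, encoder_outputs), axis=-1)
    scores = tf.nn.tanh(self.attention(concat))
    attention_weights = tf.nn.softmax(tf.reduce_sum(scores, axis=-1), axis=-1)
    # context vector, shape (batch, 1, hidden)
    return tf.matmul(tf.expand_dims(attention_weights, axis=1), encoder_outputs)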