diff --git a/src/ops/rnn.rs b/src/ops/rnn.rs index 0a6ad13c..12fc8ff9 100644 --- a/src/ops/rnn.rs +++ b/src/ops/rnn.rs @@ -7,8 +7,8 @@ use rten_tensor::{Tensor, TensorView}; use crate::check_dims; use crate::gemm::{GemmExecutor, GemmInputA, GemmInputB}; use crate::ops::{ - add_in_place, mul_in_place, sigmoid_in_place, tanh, tanh_in_place, InputList, IntoOpResult, - OpError, Operator, Output, + add_in_place, mul_in_place, sigmoid, sigmoid_in_place, tanh, tanh_in_place, InputList, + IntoOpResult, OpError, Operator, Output, }; use crate::tensor_pool::{AutoReturn, TensorPool}; @@ -265,12 +265,18 @@ pub fn gru( hidden_scratch_reset_update_gates.as_dyn(), ); - // nb. This is slower than it should be because it falls back to - // the slow path for non-contiguous tensors. - sigmoid_in_place(update_reset_gates.as_dyn_mut()); + // Copy gates before applying activation because `sigmoid_in_place` + // and `tanh_in_place` are slow with non-contiguous tensors. + // See https://github.com/robertknight/rten/issues/192. + // + // Note `gate_range` can be still used because the update and reset + // gates are in the same positions in the `update_reset_gates` slice + // as `gates`. + let update_reset_gates = sigmoid(pool, update_reset_gates.as_dyn()).auto_return(pool); + let update_gate = update_reset_gates.slice::<2, _>((.., gate_range(UPDATE_GATE))); + let reset_gate = update_reset_gates.slice::<2, _>((.., gate_range(RESET_GATE))); // Combine inputs for hidden gate and apply activation. - let reset_gate = gates.slice::<2, _>((.., gate_range(RESET_GATE))); let mut hidden_gate_recurrent = hidden_scratch.slice_mut::<2, _>((.., gate_range(HIDDEN_GATE))); mul_in_place(hidden_gate_recurrent.as_dyn_mut(), reset_gate.as_dyn()); @@ -278,13 +284,11 @@ pub fn gru( let mut hidden_gate = gates.slice_mut::<2, _>((.., gate_range(HIDDEN_GATE))); add_in_place(hidden_gate.as_dyn_mut(), hidden_gate_recurrent.as_dyn()); - // Copy the hidden gate because `tanh_in_place` is slow with - // non-contiguous tensors. + // See note above about `sigmoid_in_place`. let hidden_gate = tanh(pool, hidden_gate.as_dyn()).auto_return(pool); // Compute next hidden state let mut hidden_item = hidden.slice_mut::<2, _>([dir]); - let update_gate = gates.slice::<2, _>((.., gate_range(UPDATE_GATE))); for (hidden, update, hidden_gate) in zip3( hidden_item.iter_mut(),