Skip to content

Commit

Permalink
Decoding works, but results are not perfect. For the first few frames, it…
Browse files Browse the repository at this point in the history
…s good, then incorrect predictions.
  • Loading branch information
sangeet2020 committed May 29, 2024
1 parent e1613b6 commit f9633f6
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 18 deletions.
1 change: 1 addition & 0 deletions sherpa-onnx/csrc/online-recognizer-transducer-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ static OnlineRecognizerResult Convert(const OnlineTransducerDecoderResult &src,
r.timestamps.reserve(src.tokens.size());

for (auto i : src.tokens) {
if (i == -1) continue;
auto sym = sym_table[i];

r.text.append(sym);
Expand Down
14 changes: 6 additions & 8 deletions sherpa-onnx/csrc/online-recognizer-transducer-nemo-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {

std::vector<OnlineTransducerDecoderResult> result(n);
std::vector<float> features_vec(n * chunk_size * feature_dim);
std::vector<std::vector<Ort::Value>> states_vec(n);
std::vector<std::vector<Ort::Value>> encoder_states(n);

for (int32_t i = 0; i != n; ++i) {
const auto num_processed_frames = ss[i]->GetNumProcessedFrames();
Expand All @@ -167,7 +167,7 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
features_vec.data() + i * chunk_size * feature_dim);

result[i] = std::move(ss[i]->GetResult());
states_vec[i] = std::move(ss[i]->GetStates());
encoder_states[i] = std::move(ss[i]->GetStates());

}

Expand All @@ -181,8 +181,8 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
x_shape.size());

// Batch size is 1
auto states = std::move(states_vec[0]);
int32_t num_states = states.size();
auto states = std::move(encoder_states[0]);
int32_t num_states = states.size(); // num_states = 3
auto t = model_->RunEncoder(std::move(x), std::move(states));
// t[0] encoder_out, float tensor, (batch_size, dim, T)
// t[1] next states
Expand All @@ -203,14 +203,12 @@ class OnlineRecognizerTransducerNeMoImpl : public OnlineRecognizerImpl {
// Subsequent decoder states (for each chunk) are updated inside the Decode method.
// This returns the decoder state from the LAST chunk. We probably don't need it, so we can discard it.
decoder_states = decoder_->Decode(std::move(encoder_out),
std::move(decoder_states),
std::move(decoder_states),
&result, ss, n);


ss[0]->SetResult(result[0]);

// We probably dont need it. Will discard it.
ss[0]->SetStates(std::move(decoder_states));
ss[0]->SetStates(std::move(out_states));
}

void InitOnlineStream(OnlineStream *stream) const {
Expand Down
18 changes: 9 additions & 9 deletions sherpa-onnx/csrc/online-transducer-greedy-search-nemo-decoder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -93,13 +93,13 @@ std::vector<Ort::Value> DecodeOne(
// decoder_output_pair.second returns the next decoder state
std::pair<Ort::Value, std::vector<Ort::Value>> decoder_output_pair =
model->RunDecoder(std::move(decoder_input_pair.first),
std::move(decoder_states));
std::move(decoder_states)); // here decoder_states = {len=0, cap=0}. But decoder_output_pair= {first, second: {len=2, cap=2}} // ATTN

std::array<int64_t, 3> encoder_shape{1, num_cols, 1};

decoder_states = std::move(decoder_output_pair.second);

// start with each chunks in the input sequence. Is this loop really meant for that?
// TODO: Inside this loop, I need to do frame-wise decoding.
for (int32_t t = 0; t != num_rows; ++t) {
Ort::Value cur_encoder_out = Ort::Value::CreateTensor(
memory_info, const_cast<float *>(encoder_out) + t * num_cols, num_cols,
Expand All @@ -117,7 +117,7 @@ std::vector<Ort::Value> DecodeOne(
static_cast<const float *>(p_logit),
std::max_element(static_cast<const float *>(p_logit),
static_cast<const float *>(p_logit) + vocab_size)));

SHERPA_ONNX_LOGE("y=%d", y);
if (y != blank_id) {
r.tokens.push_back(y);
r.timestamps.push_back(t + r.frame_offset);
Expand All @@ -128,14 +128,14 @@ std::vector<Ort::Value> DecodeOne(
decoder_output_pair =
model->RunDecoder(std::move(decoder_input_pair.first),
std::move(decoder_states));

// Update the decoder states for the next chunk
decoder_states = std::move(decoder_output_pair.second);
}

// Update the decoder states for the next chunk
decoder_states = std::move(decoder_output_pair.second);
}

decoder_out = std::move(decoder_output_pair.first);
// UpdateCachedDecoderOut(model->Allocator(), &decoder_out, result);
// UpdateCachedDecoderOut(model->Allocator(), &decoder_out, result);

// Update frame_offset
for (auto &r : *result) {
Expand Down Expand Up @@ -163,8 +163,8 @@ std::vector<Ort::Value> OnlineTransducerGreedySearchNeMoDecoder::Decode(
}

int32_t batch_size = static_cast<int32_t>(shape[0]); // bs = 1
int32_t dim1 = static_cast<int32_t>(shape[1]);
int32_t dim2 = static_cast<int32_t>(shape[2]);
int32_t dim1 = static_cast<int32_t>(shape[1]); // 2
int32_t dim2 = static_cast<int32_t>(shape[2]); // 512

// Define and initialize encoder_out_length
Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
Expand Down
2 changes: 1 addition & 1 deletion sherpa-onnx/csrc/online-transducer-nemo-model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ class OnlineTransducerNeMoModel::Impl {

std::vector<Ort::Value> RunEncoder(Ort::Value features,
std::vector<Ort::Value> states) {
Ort::Value &cache_last_channel = states[0];
Ort::Value &cache_last_channel = states[0];
Ort::Value &cache_last_time = states[1];
Ort::Value &cache_last_channel_len = states[2];

Expand Down

0 comments on commit f9633f6

Please sign in to comment.