Skip to content

Commit

Permalink
Added ability to load lexicon.
Browse files: browse the repository at this point in the history
  • Loading branch information
danijel3 committed Oct 12, 2022
1 parent 5353386 commit 67b07b2
Show file tree
Hide file tree
Showing 9 changed files with 138 additions and 73 deletions.
Binary file removed KaldiAligner.exe
Binary file not shown.
3 changes: 2 additions & 1 deletion include/lex.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
class Lexicon {
private:
G2P *g2p;
std::map<std::string, std::vector<std::string>> g2p_cache;

double sil_prob;
std::string sil_phone;
Expand All @@ -23,7 +24,7 @@ class Lexicon {
fst::StdVectorFst lexicon_fst;

public:
explicit Lexicon(const std::string &g2p_model_file, const std::string &phone_list_file);
explicit Lexicon(const std::string &lex_file, const std::string &g2p_model_file, const std::string &phone_list_file);
~Lexicon();

void load_file(const std::string &filename);
Expand Down
3 changes: 2 additions & 1 deletion src/g2p.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ class PhonetisaurusModel : public G2P {
for (auto u : d.Uniques) {
w += osyms->Find(u) + " ";
}
ret.push_back(w);
if(!w.empty())
ret.push_back(w);
}
return ret;
}
Expand Down
73 changes: 39 additions & 34 deletions src/kaldi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
KaldiProcess::KaldiProcess(std::string model_dir) {
std::string use_gpu = "yes";
std::string g2p_model_file = model_dir + "/g2p/model.fst";
std::string g2p_cache_file = model_dir + "/g2p/lexicon.txt";
std::string mfcc_config = model_dir + "/conf/mfcc_hires.conf";
std::string model_file = model_dir + "/nnet3/final.mdl";
std::string tree_file = model_dir + "/nnet3/tree";
Expand Down Expand Up @@ -37,7 +38,7 @@ KaldiProcess::KaldiProcess(std::string model_dir) {
CuDevice::Instantiate().SelectGpuId(use_gpu);
#endif

lexicon = std::unique_ptr<Lexicon>(new Lexicon(g2p_model_file, phone_list_file));
lexicon = std::unique_ptr<Lexicon>(new Lexicon(g2p_cache_file, g2p_model_file, phone_list_file));

ParseOptions po("");
MfccOptions mfcc_opts;
Expand All @@ -62,23 +63,23 @@ KaldiProcess::KaldiProcess(std::string model_dir) {
CollapseModel(CollapseModelConfig(), &(am_nnet.GetNnet()));

compiler = std::unique_ptr<CachingOptimizingCompiler>(new CachingOptimizingCompiler(
am_nnet.GetNnet(), decodable_opts.optimize_config));
am_nnet.GetNnet(), decodable_opts.optimize_config));

WordBoundaryInfoNewOpts opts;

word_boundary_info = std::unique_ptr<WordBoundaryInfo>(new WordBoundaryInfo(opts, word_boundary_file));

// std::ofstream trans_debug("trans.txt");
// trans_model.Print(trans_debug, lexicon->get_phonelist());
// trans_debug.close();
// std::ofstream trans_debug("trans.txt");
// trans_model.Print(trans_debug, lexicon->get_phonelist());
// trans_debug.close();

}

void KaldiProcess::MakeLatticeFromLinear(const std::vector<int32> &ali,
const std::vector<int32> &words,
BaseFloat lm_cost,
BaseFloat ac_cost,
Lattice *lat_out) {
void KaldiProcess::MakeLatticeFromLinear(const std::vector<int32>& ali,
const std::vector<int32>& words,
BaseFloat lm_cost,
BaseFloat ac_cost,
Lattice* lat_out) {
typedef LatticeArc::StateId StateId;
typedef LatticeArc::Weight Weight;
typedef LatticeArc::Label Label;
Expand All @@ -90,7 +91,7 @@ void KaldiProcess::MakeLatticeFromLinear(const std::vector<int32> &ali,
Label olabel = (i < words.size() ? words[i] : 0);
StateId next_state = lat_out->AddState();
lat_out->AddArc(cur_state,
LatticeArc(ilabel, olabel, Weight::One(), next_state));
LatticeArc(ilabel, olabel, Weight::One(), next_state));
cur_state = next_state;
}
lat_out->SetFinal(cur_state, Weight(lm_cost, ac_cost));
Expand All @@ -108,13 +109,13 @@ Result KaldiProcess::process(std::string wav_file, std::string trans_file) {
SubVector<BaseFloat> waveform(wave_data.Data(), 0);
mfcc->ComputeFeatures(waveform, wave_data.SampFreq(), 1.0, &features);

// BaseFloatMatrixWriter feat_writer("ark:feats.ark");
// feat_writer.Write("sent001", features);
// feat_writer.Close();
// BaseFloatMatrixWriter feat_writer("ark:feats.ark");
// feat_writer.Write("sent001", features);
// feat_writer.Close();

// Matrix<double> cmvn_stats;
// InitCmvnStats(features.NumCols(), &cmvn_stats);
// AccCmvnStats(features, nullptr, &cmvn_stats);
// Matrix<double> cmvn_stats;
// InitCmvnStats(features.NumCols(), &cmvn_stats);
// AccCmvnStats(features, nullptr, &cmvn_stats);

OnlineIvectorExtractionInfo ivector_info(ivector_config);
OnlineIvectorExtractorAdaptationState adaptation_state(ivector_info);
Expand All @@ -126,8 +127,8 @@ Result KaldiProcess::process(std::string wav_file, std::string trans_file) {
ivector_feature.SetAdaptationState(adaptation_state);

int32 T = features.NumRows(),
n = ivector_config.ivector_period,
num_ivectors = (T + n - 1) / n;
n = ivector_config.ivector_period,
num_ivectors = (T + n - 1) / n;

Matrix<BaseFloat> ivectors(num_ivectors, ivector_feature.Dim());

Expand All @@ -137,9 +138,9 @@ Result KaldiProcess::process(std::string wav_file, std::string trans_file) {
ivector_feature.GetFrame(t, &ivector);
}

// BaseFloatMatrixWriter ivector_writer("ark:ivectors.ark");
// ivector_writer.Write("sent001", ivectors);
// ivector_writer.Close();
// BaseFloatMatrixWriter ivector_writer("ark:ivectors.ark");
// ivector_writer.Write("sent001", ivectors);
// ivector_writer.Close();

TrainingGraphCompilerOptions gopts;
gopts.transition_scale = 0.0; // Change the default to 0.0 since we will generally add the
Expand All @@ -150,7 +151,7 @@ Result KaldiProcess::process(std::string wav_file, std::string trans_file) {

lexicon->load_file(trans_file);

auto *lex_fst = new fst::StdVectorFst(lexicon->get_fst());
auto* lex_fst = new fst::StdVectorFst(lexicon->get_fst());
TrainingGraphCompiler gc(trans_model, tree, lex_fst, disambig_syms, gopts);

std::vector<int32_t> transcription = lexicon->load_transcript(trans_file);
Expand All @@ -159,16 +160,16 @@ Result KaldiProcess::process(std::string wav_file, std::string trans_file) {

gc.CompileGraphFromText(transcription, &graph);

// TableWriter<fst::VectorFstHolder> fst_writer("ark:graph.fst");
// fst_writer.Write("sent001", graph);
// fst_writer.Close();
// TableWriter<fst::VectorFstHolder> fst_writer("ark:graph.fst");
// fst_writer.Write("sent001", graph);
// fst_writer.Close();

AddTransitionProbs(trans_model, disambig_syms, transition_scale, self_loop_scale, &graph);

DecodableAmNnetSimple nnet_decodable(
decodable_opts, trans_model, am_nnet,
features, NULL, &ivectors,
ivector_config.ivector_period, compiler.get());
decodable_opts, trans_model, am_nnet,
features, NULL, &ivectors,
ivector_config.ivector_period, compiler.get());

FasterDecoderOptions decode_opts;
decode_opts.beam = beam;
Expand All @@ -188,9 +189,9 @@ Result KaldiProcess::process(std::string wav_file, std::string trans_file) {

GetLinearSymbolSequence(decoded, &alignment, &words, &weight);

// Int32VectorWriter ali_writer("ark,t:ali.txt");
// ali_writer.Write("sent001", alignment);
// ali_writer.Close();
// Int32VectorWriter ali_writer("ark,t:ali.txt");
// ali_writer.Write("sent001", alignment);
// ali_writer.Close();

Lattice lat;
MakeLatticeFromLinear(alignment, transcription, 0, 0, &lat);
Expand All @@ -212,8 +213,12 @@ Result KaldiProcess::process(std::string wav_file, std::string trans_file) {
CompactLatticeToWordAlignment(aligned_clat, &words, &times, &lengths);

auto phone_text = lexicon->int2phones(words);
for (int i = 0; i < phone_text.size(); i++)
ret.phones.emplace_back(phone_text[i], times[i] / 100.0, (times[i] + lengths[i]) / 100.0);
for (int i = 0; i < phone_text.size(); i++) {
auto ph = phone_text[i];
if (ph != "sil" && ph != "sp") {
ret.phones.emplace_back(ph, times[i] / 100.0, (times[i] + lengths[i]) / 100.0);
}
}

return ret;
}
Loading

0 comments on commit 67b07b2

Please sign in to comment.