Skip to content

Commit

Permalink
Changed all text_hash values from std::string to std::size_t for better performance.
Browse files Browse the repository at this point in the history
Also fixed a bug caused by extra code that was accidentally pasted in the previous commit.
  • Loading branch information
Your Name committed Jan 25, 2025
1 parent 521e900 commit 1291ed2
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 68 deletions.
82 changes: 43 additions & 39 deletions sherpa-onnx/csrc/offline-tts-cache-mechanism.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <filesystem>
#include <iostream>
#include <limits>
#include <cstddef> // for std::size_t

#include "sherpa-onnx/csrc/file-utils.h"
#include "sherpa-onnx/csrc/macros.h"
Expand Down Expand Up @@ -60,14 +61,14 @@ OfflineTtsCacheMechanism::~OfflineTtsCacheMechanism() {
}

void OfflineTtsCacheMechanism::AddWavFile(
const std::string &text_hash,
const std::size_t &text_hash,
const std::vector<float> &samples,
const int32_t sample_rate) {
std::lock_guard<std::recursive_mutex> lock(mutex_);

if (cache_mechanism_inited_ == false) return;

std::string file_path = cache_dir_ + "/" + text_hash + ".wav";
std::string file_path = cache_dir_ + "/" + std::to_string(text_hash) + ".wav";

// Check if the file physically exists in the cache directory
bool file_exists = std::filesystem::exists(file_path);
Expand All @@ -92,15 +93,15 @@ void OfflineTtsCacheMechanism::AddWavFile(
}

std::vector<float> OfflineTtsCacheMechanism::GetWavFile(
const std::string &text_hash,
const std::size_t &text_hash,
int32_t *sample_rate) {
std::lock_guard<std::recursive_mutex> lock(mutex_);

std::vector<float> samples;

if (cache_mechanism_inited_ == false) return samples;

std::string file_path = cache_dir_ + "/" + text_hash + ".wav";
std::string file_path = cache_dir_ + "/" + std::to_string(text_hash) + ".wav";

if (std::filesystem::exists(file_path)) {
bool is_ok = false;
Expand All @@ -119,12 +120,12 @@ std::vector<float> OfflineTtsCacheMechanism::GetWavFile(
}

// Save the repeat counts (NOTE(review): the 10-minute throttle below is commented out, so this currently runs on every call)
auto now = std::chrono::steady_clock::now();
if (std::chrono::duration_cast<std::chrono::seconds>(
now - last_save_time_).count() >= 10 * 60) {
//auto now = std::chrono::steady_clock::now();
//if (std::chrono::duration_cast<std::chrono::seconds>(
//now - last_save_time_).count() >= 10 * 60) {
SaveRepeatCounts();
last_save_time_ = now;
}
//last_save_time_ = now;
//}

return samples;
}
Expand Down Expand Up @@ -168,7 +169,7 @@ void OfflineTtsCacheMechanism::ClearCache() {
repeat_counts_.clear();
cache_vector_.clear();

// Remove repeat counts also in the repeat_counts.txt
// Remove repeat counts also in the repeat_counts file
SaveRepeatCounts();
}

Expand All @@ -183,58 +184,60 @@ int32_t OfflineTtsCacheMechanism::GetTotalUsedCacheSize() const {
// Private functions ///////////////////////////////////////////////////

// Load the persisted repeat counts from "<cache_dir_>/repeat_counts.bin" into
// repeat_counts_.
//
// File layout, as produced by SaveRepeatCounts() (raw in-memory bytes, so
// the file is only valid for the platform that wrote it):
//   std::size_t num_entries
//   num_entries x { std::size_t text_hash; int32_t count; }
//
// A missing file is not an error: nothing is loaded. If the file is empty,
// truncated or otherwise unreadable, loading stops at the first record that
// cannot be read completely; records read before that point are kept and
// no partially-read values are ever inserted.
void OfflineTtsCacheMechanism::LoadRepeatCounts() {
  std::string repeat_count_file = cache_dir_ + "/repeat_counts.bin";

  // Check if the file exists
  if (!std::filesystem::exists(repeat_count_file)) {
    return;  // Skip loading if the file doesn't exist
  }

  // Open the file for reading in binary mode
  std::ifstream ifs(repeat_count_file, std::ios::binary);
  if (!ifs.is_open()) {
    SHERPA_ONNX_LOGE("Failed to open repeat count file: %s",
                     repeat_count_file.c_str());
    return;  // Skip loading if the file cannot be opened
  }

  // Read the number of entries. The read must be checked: on a short read
  // the destination is left untouched, and using it as a loop bound would
  // iterate over garbage.
  std::size_t num_entries = 0;
  if (!ifs.read(reinterpret_cast<char *>(&num_entries),
                sizeof(num_entries))) {
    SHERPA_ONNX_LOGE("Failed to read the number of entries from: %s",
                     repeat_count_file.c_str());
    return;
  }

  // Read each entry, stopping at the first incomplete record so that a
  // corrupt header (e.g. a huge num_entries) cannot cause a runaway loop.
  for (std::size_t i = 0; i < num_entries; ++i) {
    std::size_t text_hash = 0;
    int32_t count = 0;
    if (!ifs.read(reinterpret_cast<char *>(&text_hash), sizeof(text_hash)) ||
        !ifs.read(reinterpret_cast<char *>(&count), sizeof(count))) {
      SHERPA_ONNX_LOGE("Repeat count file is truncated or corrupt: %s",
                       repeat_count_file.c_str());
      return;
    }
    repeat_counts_[text_hash] = count;
  }
}

// Persist repeat_counts_ to "<cache_dir_>/repeat_counts.bin", replacing any
// previous contents of that file.
//
// File layout (raw in-memory bytes, so the file is only valid for the
// platform that wrote it):
//   std::size_t num_entries
//   num_entries x { std::size_t text_hash; int32_t count; }
//
// Failures to open or write the file are logged and the function returns
// early; it never throws on its own.
void OfflineTtsCacheMechanism::SaveRepeatCounts() {
  std::string repeat_count_file = cache_dir_ + "/repeat_counts.bin";

  // Open the file for writing in binary mode
  std::ofstream ofs(repeat_count_file, std::ios::binary);
  if (!ofs.is_open()) {
    SHERPA_ONNX_LOGE("Failed to open repeat count file for writing: %s",
                     repeat_count_file.c_str());
    return;  // Skip saving if the file cannot be opened
  }

  // Write the number of entries
  const std::size_t num_entries = repeat_counts_.size();
  ofs.write(reinterpret_cast<const char *>(&num_entries),
            sizeof(num_entries));

  // Write each entry
  for (const auto &entry : repeat_counts_) {
    ofs.write(reinterpret_cast<const char *>(&entry.first),
              sizeof(entry.first));
    ofs.write(reinterpret_cast<const char *>(&entry.second),
              sizeof(entry.second));
    if (!ofs) {
      SHERPA_ONNX_LOGE("Failed to write repeat count for text hash: %zu",
                       entry.first);
      return;  // Stop writing if an error occurs
    }
  }

  // Flush explicitly so that a failure while draining the buffer is
  // reported here instead of being lost in the destructor.
  ofs.flush();
  if (!ofs) {
    SHERPA_ONNX_LOGE("Failed to write repeat count file: %s",
                     repeat_count_file.c_str());
  }
}

void OfflineTtsCacheMechanism::RemoveWavFile(const std::string &text_hash) {
std::string file_path = cache_dir_ + "/" + text_hash + ".wav";
void OfflineTtsCacheMechanism::RemoveWavFile(const std::size_t &text_hash) {
std::string file_path = cache_dir_ + "/"
+ std::to_string(text_hash) + ".wav";
if (std::filesystem::exists(file_path)) {
// Subtract the size of the removed WAV file from the total cache size
std::ifstream file(file_path, std::ios::binary | std::ios::ate);
Expand All @@ -259,7 +262,8 @@ void OfflineTtsCacheMechanism::UpdateCacheVector() {

for (const auto &entry : std::filesystem::directory_iterator(cache_dir_)) {
if (entry.path().extension() == ".wav") {
std::string text_hash = entry.path().stem().string();
std::string text_hash_str = entry.path().stem().string();
std::size_t text_hash = std::stoull(text_hash_str);
if (repeat_counts_.find(text_hash) == repeat_counts_.end()) {
// Remove the file if it's not in the repeat count file (orphaned file)
std::filesystem::remove(entry.path());
Expand All @@ -282,14 +286,14 @@ void OfflineTtsCacheMechanism::EnsureCacheLimit() {
while (used_cache_size_bytes_> 0
&& used_cache_size_bytes_ > target_cache_size) {
// Cache is full, remove the least repeated file
std::string least_repeated_file = GetLeastRepeatedFile();
std::size_t least_repeated_file = GetLeastRepeatedFile();
RemoveWavFile(least_repeated_file);
}
}
}

std::string OfflineTtsCacheMechanism::GetLeastRepeatedFile() {
std::string least_repeated_file;
std::size_t OfflineTtsCacheMechanism::GetLeastRepeatedFile() {
std::size_t least_repeated_file = 0;
int32_t min_count = std::numeric_limits<int32_t>::max();

for (const auto &entry : repeat_counts_) {
Expand Down
14 changes: 7 additions & 7 deletions sherpa-onnx/csrc/offline-tts-cache-mechanism.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,26 +9,26 @@
#include <vector>
#include <unordered_map>
#include <mutex> // NOLINT
#include <cstddef> // for std::size_t

#include "sherpa-onnx/csrc/offline-tts-cache-mechanism-config.h"

namespace sherpa_onnx {

class OfflineTtsCacheMechanism {
public:

explicit OfflineTtsCacheMechanism(const OfflineTtsCacheMechanismConfig &config);
~OfflineTtsCacheMechanism();

// Add a new wav file to the cache
void AddWavFile(
const std::string &text_hash,
const std::size_t &text_hash,
const std::vector<float> &samples,
const int32_t sample_rate);

// Get the cached wav file if it exists
std::vector<float> GetWavFile(
const std::string &text_hash,
const std::size_t &text_hash,
int32_t *sample_rate);

// Get the current cache size in bytes
Expand All @@ -51,7 +51,7 @@ class OfflineTtsCacheMechanism {
void SaveRepeatCounts();

// Remove a wav file from the cache
void RemoveWavFile(const std::string &text_hash);
void RemoveWavFile(const std::size_t &text_hash);

// Update the cache vector with the actual files in the cache folder
void UpdateCacheVector();
Expand All @@ -60,7 +60,7 @@ class OfflineTtsCacheMechanism {
void EnsureCacheLimit();

// Get the least repeated file in the cache
std::string GetLeastRepeatedFile();
std::size_t GetLeastRepeatedFile();

// Data directory where the cache folder is located
std::string cache_dir_;
Expand All @@ -72,10 +72,10 @@ class OfflineTtsCacheMechanism {
int32_t used_cache_size_bytes_;

// Map of text hash to repeat count
std::unordered_map<std::string, int32_t> repeat_counts_;
std::unordered_map<std::size_t, int32_t> repeat_counts_;

// Vector of cached file names
std::vector<std::string> cache_vector_;
std::vector<std::size_t> cache_vector_;

// Mutex for thread safety (recursive to avoid deadlocks)
mutable std::recursive_mutex mutex_;
Expand Down
6 changes: 2 additions & 4 deletions sherpa-onnx/csrc/offline-tts.cc
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,7 @@ GeneratedAudio OfflineTts::Generate(
GeneratedAudioCallback callback /*= nullptr*/) const {
// Generate a hash for the text
std::hash<std::string> hasher;
std::string text_hash = std::to_string(hasher(text));
// SHERPA_ONNX_LOGE("Generated text hash: %s", text_hash.c_str());
std::size_t text_hash = hasher(text);

// Check if the cache mechanism is active and if the audio is already cached
if (cache_mechanism_) {
Expand All @@ -122,7 +121,7 @@ GeneratedAudio OfflineTts::Generate(
= cache_mechanism_->GetWavFile(text_hash, &sample_rate);

if (!samples.empty()) {
SHERPA_ONNX_LOGE("Returning cached audio for hash:%s", text_hash.c_str());
SHERPA_ONNX_LOGE("Returning cached audio for hash: %zu", text_hash);

// If a callback is provided, call it with the cached audio
if (callback) {
Expand All @@ -146,7 +145,6 @@ GeneratedAudio OfflineTts::Generate(
// Cache the generated audio if the cache mechanism is active
if (cache_mechanism_) {
cache_mechanism_->AddWavFile(text_hash, audio.samples, audio.sample_rate);
// SHERPA_ONNX_LOGE("Cached audio for text hash: %s", text_hash.c_str());
}

return audio;
Expand Down
18 changes: 0 additions & 18 deletions sherpa-onnx/jni/offline-tts.cc
Original file line number Diff line number Diff line change
Expand Up @@ -196,24 +196,6 @@ static OfflineTtsCacheMechanismConfig GetOfflineTtsCacheConfig(JNIEnv *env, jobj
return ans;
}

// Get data directory from config
jfieldID model_fid = env->GetFieldID(cls, "model", "Lcom/k2fsa/sherpa/onnx/OfflineTtsModelConfig;");
jobject model_config = env->GetObjectField(config, model_fid);
jclass model_cls = env->GetObjectClass(model_config);

jfieldID vits_fid = env->GetFieldID(model_cls, "vits", "Lcom/k2fsa/sherpa/onnx/OfflineTtsVitsModelConfig;");
jobject vits_config = env->GetObjectField(model_config, vits_fid);

fid = env->GetFieldID(vits_cls, "dataDir", "Ljava/lang/String;");
jstring data_dir = (jstring)env->GetObjectField(vits_config, fid);
const char *p_data_dir = env->GetStringUTFChars(data_dir, nullptr);

// Convert data directory to cache directory
std::string cache_dir = std::string(p_data_dir) + "/../cache";
ans.cache_dir = cache_dir;

env->ReleaseStringUTFChars(data_dir, p_data_dir);

} // namespace sherpa_onnx

SHERPA_ONNX_EXTERN_C
Expand Down

0 comments on commit 1291ed2

Please sign in to comment.