fix: enhance the tokenizer's handling of Unicode (#120)
leejet authored Dec 20, 2023
1 parent 9842a3f commit 8f6b4a3
Showing 6 changed files with 43,819 additions and 80,106 deletions.
17 changes: 3 additions & 14 deletions model.cpp
@@ -1192,20 +1192,9 @@ ggml_type ModelLoader::get_sd_wtype() {
     return GGML_TYPE_COUNT;
 }
 
-bool ModelLoader::load_vocab(on_new_token_cb_t on_new_token_cb) {
-    char* vocab_buffer = reinterpret_cast<char*>(vocab_json);
-    nlohmann::json vocab = nlohmann::json::parse(vocab_buffer);
-    std::map<char, int> decoder = unicode_to_byte();
-    for (auto& it : vocab.items()) {
-        int token_id = it.value();
-        std::string token_str = it.key();
-        std::string token = "";
-        for (char c : token_str) {
-            token += decoder[c];
-        }
-        on_new_token_cb(token, token_id);
-    }
-    return true;
+std::string ModelLoader::load_merges() {
+    std::string merges_utf8_str(reinterpret_cast<const char*>(merges_utf8_c_str), sizeof(merges_utf8_c_str));
+    return merges_utf8_str;
 }
 
 bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb) {
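
The old load_vocab walked an embedded vocab.json and pushed decoded tokens through a callback; load_merges instead returns the embedded merges.txt (the ordered list of BPE merge rules) as a single UTF-8 string, and the tokenizer now derives its whole vocabulary from that. A minimal sketch of how the returned blob gets consumed downstream (split_merge_lines is an illustrative helper, not part of this patch; the real parsing happens in CLIPTokenizer::load_from_merges below):

    #include <string>
    #include <vector>

    // Split the merges blob into lines. The first line is a header that the
    // tokenizer skips; every other line is one merge rule, e.g. "i n</w>",
    // ordered from highest to lowest priority.
    std::vector<std::string> split_merge_lines(const std::string& merges) {
        std::vector<std::string> lines;
        size_t start = 0;
        size_t pos;
        while ((pos = merges.find('\n', start)) != std::string::npos) {
            lines.push_back(merges.substr(start, pos - start));
            start = pos + 1;
        }
        return lines;
    }
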
2 changes: 1 addition & 1 deletion model.h
@@ -115,7 +115,7 @@ class ModelLoader {
     bool init_from_file(const std::string& file_path, const std::string& prefix = "");
     SDVersion get_sd_version();
     ggml_type get_sd_wtype();
-    bool load_vocab(on_new_token_cb_t on_new_token_cb);
+    std::string load_merges();
     bool load_tensors(on_new_tensor_cb_t on_new_tensor_cb);
     int64_t cal_mem_size(ggml_backend_t backend);
     ~ModelLoader() = default;
125 changes: 103 additions & 22 deletions stable-diffusion.cpp
@@ -493,12 +493,40 @@ const int BOS_TOKEN_ID = 49406;
 const int EOS_TOKEN_ID = 49407;
 const int PAD_TOKEN_ID = 49407;
 
+std::vector<std::pair<int, std::u32string>> bytes_to_unicode() {
+    std::vector<std::pair<int, std::u32string>> byte_unicode_pairs;
+    std::set<int> byte_set;
+    for (int b = static_cast<int>('!'); b <= static_cast<int>('~'); ++b) {
+        byte_set.insert(b);
+        byte_unicode_pairs.push_back(std::pair<int, std::u32string>(b, unicode_value_to_utf32(b)));
+    }
+    for (int b = 161; b <= 172; ++b) {
+        byte_set.insert(b);
+        byte_unicode_pairs.push_back(std::pair<int, std::u32string>(b, unicode_value_to_utf32(b)));
+    }
+    for (int b = 174; b <= 255; ++b) {
+        byte_set.insert(b);
+        byte_unicode_pairs.push_back(std::pair<int, std::u32string>(b, unicode_value_to_utf32(b)));
+    }
+    int n = 0;
+    for (int b = 0; b < 256; ++b) {
+        if (byte_set.find(b) == byte_set.end()) {
+            byte_unicode_pairs.push_back(std::pair<int, std::u32string>(b, unicode_value_to_utf32(n + 256)));
+            ++n;
+        }
+    }
+    // LOG_DEBUG("byte_unicode_pairs %d", byte_unicode_pairs.size());
+    return byte_unicode_pairs;
+}
+
 // Ref: https://github.com/openai/CLIP/blob/main/clip/simple_tokenizer.py
 // TODO: implement bpe
 class CLIPTokenizer {
 private:
     SDVersion version = VERSION_1_x;
-    std::map<std::string, int32_t> encoder;
+    std::map<int, std::u32string> byte_encoder;
+    std::map<std::u32string, int> encoder;
+    std::map<std::pair<std::u32string, std::u32string>, int> bpe_ranks;
     std::regex pat;
 
     static std::string strip(const std::string& str) {
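
bytes_to_unicode builds the reversible byte-to-codepoint table used by GPT-2-style byte-level BPE: the three ranges of printable bytes ('!'..'~', 0xA1..0xAC, 0xAE..0xFF) map to themselves, and the remaining 68 bytes are shifted up into codepoints 256 and above, so every byte becomes a visible, space-free character. For instance, the space byte 0x20 is preceded by the 32 unmapped bytes 0x00..0x1F, so it lands on 256 + 32 = U+0120, the "Ġ" familiar from GPT-2 vocabularies. A hypothetical spot-check (not in the patch; assumes the functions from this commit are linked in):

    #include <cstdio>
    #include <map>
    #include <string>

    int main() {
        auto pairs = bytes_to_unicode();
        std::map<int, std::u32string> byte_encoder(pairs.begin(), pairs.end());
        printf("%s\n", utf32_to_utf8(byte_encoder['A']).c_str());  // "A": printable bytes map to themselves
        printf("%s\n", utf32_to_utf8(byte_encoder[' ']).c_str());  // "Ġ": 0x20 -> 256 + 32 = U+0120
        return 0;
    }
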
@@ -521,19 +549,61 @@ class CLIPTokenizer {

 public:
     CLIPTokenizer(SDVersion version = VERSION_1_x)
-        : version(version){};
-    std::string bpe(std::string token) {
-        std::string word = token + "</w>";
+        : version(version) {}
+
+    void load_from_merges(const std::string& merges_utf8_str) {
+        auto byte_unicode_pairs = bytes_to_unicode();
+        byte_encoder = std::map<int, std::u32string>(byte_unicode_pairs.begin(), byte_unicode_pairs.end());
+        // for (auto & pair: byte_unicode_pairs) {
+        //     std::cout << pair.first << ": " << pair.second << std::endl;
+        // }
+        std::vector<std::u32string> merges;
+        size_t start = 0;
+        size_t pos;
+        std::u32string merges_utf32_str = utf8_to_utf32(merges_utf8_str);
+        while ((pos = merges_utf32_str.find('\n', start)) != std::string::npos) {
+            merges.push_back(merges_utf32_str.substr(start, pos - start));
+            start = pos + 1;
+        }
+        merges = std::vector<std::u32string>(merges.begin() + 1, merges.begin() + 49152 - 256 - 2 + 1);
+        std::vector<std::pair<std::u32string, std::u32string>> merge_pairs;
+        for (const auto& merge : merges) {
+            size_t space_pos = merge.find(' ');
+            merge_pairs.emplace_back(merge.substr(0, space_pos), merge.substr(space_pos + 1));
+            // LOG_DEBUG("%s", utf32_to_utf8(merge.substr(space_pos + 1)).c_str());
+        }
+        std::vector<std::u32string> vocab;
+        for (const auto& pair : byte_unicode_pairs) {
+            vocab.push_back(pair.second);
+        }
+        for (const auto& pair : byte_unicode_pairs) {
+            vocab.push_back(pair.second + utf8_to_utf32("</w>"));
+        }
+        for (const auto& merge : merge_pairs) {
+            vocab.push_back(merge.first + merge.second);
+        }
+        vocab.push_back(utf8_to_utf32("<|startoftext|>"));
+        vocab.push_back(utf8_to_utf32("<|endoftext|>"));
+        LOG_DEBUG("vocab size: %llu", vocab.size());
+        int i = 0;
+        for (const auto& token : vocab) {
+            encoder[token] = i++;
+        }
+
+        int rank = 0;
+        for (const auto& merge : merge_pairs) {
+            bpe_ranks[merge] = rank++;
+        }
+    };
+
+    std::u32string bpe(std::u32string token) {
+        std::u32string word = token + utf8_to_utf32("</w>");
         if (encoder.find(word) != encoder.end()) {
             return word;
         } else if (encoder.find(token) != encoder.end()) {
             return token;
         }
-        return UNK_TOKEN;
-    }
-
-    void add_token(std::string token, int32_t token_id) {
-        encoder[token] = token_id;
+        return utf8_to_utf32(UNK_TOKEN);
     }
 
     std::vector<int> tokenize(std::string text, size_t max_length = 0, bool padding = false) {
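
load_from_merges reconstructs CLIP's vocabulary deterministically from the merge list alone, in the same order as openai/CLIP's simple_tokenizer: 256 single-byte tokens, 256 byte tokens carrying the end-of-word marker </w>, the merged pairs, then the two specials. The slice bound explains the magic number:

    merges kept: 49152 - 256 - 2 = 48894   (after dropping the header line)
    vocab size:  256 + 256 + 48894 + 2 = 49408

which is consistent with <|startoftext|> and <|endoftext|> receiving ids 49406 and 49407, the BOS/EOS constants at the top of this file. Note that bpe() here is still a dictionary lookup with an UNK fallback; bpe_ranks is populated but the full iterative merge loop is not yet applied (hence the remaining "TODO: implement bpe").
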
@@ -571,13 +641,25 @@
         std::vector<std::string> token_strs;
         while (std::regex_search(str, matches, pat)) {
             for (auto& token : matches) {
-                std::istringstream iss(bpe(token));
-                std::vector<std::string> tokens{std::istream_iterator<std::string>{iss},
-                                                std::istream_iterator<std::string>{}};
-                for (const auto& bpe_token : tokens) {
-                    bpe_tokens.push_back(encoder[bpe_token]);
-                    token_strs.push_back(bpe_token);
+                std::string token_str = token.str();
+                std::u32string utf32_token;
+                for (int i = 0; i < token_str.length(); i++) {
+                    char b = token_str[i];
+                    utf32_token += byte_encoder[b];
+                }
+                auto bpe_strs = bpe(utf32_token);
+                size_t start = 0;
+                size_t pos;
+                while ((pos = bpe_strs.find(' ', start)) != std::u32string::npos) {
+                    auto bpe_str = bpe_strs.substr(start, pos - start);
+                    bpe_tokens.push_back(encoder[bpe_str]);
+                    token_strs.push_back(utf32_to_utf8(bpe_str));
+
+                    start = pos + 1;
                 }
+                auto bpe_str = bpe_strs.substr(start, bpe_strs.size() - start);
+                bpe_tokens.push_back(encoder[bpe_str]);
+                token_strs.push_back(utf32_to_utf8(bpe_str));
             }
             str = matches.suffix();
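
The tokenize loop now has three stages per regex match: map each raw byte through byte_encoder into the unicode alphabet, hand the result to bpe(), which returns the subwords joined by spaces in a single std::u32string, and split that string back apart to look up ids. One caveat worth noting: char is signed on most platforms, so a byte >= 0x80 indexes byte_encoder with a negative key unless cast to unsigned char first. The space-splitting step, pulled out as a sketch (split_subwords is illustrative, not part of the patch):

    #include <string>
    #include <vector>

    // Split bpe()'s space-joined output into individual subwords,
    // mirroring the inline loop in tokenize() above.
    std::vector<std::u32string> split_subwords(const std::u32string& bpe_strs) {
        std::vector<std::u32string> subwords;
        size_t start = 0;
        size_t pos;
        while ((pos = bpe_strs.find(U' ', start)) != std::u32string::npos) {
            subwords.push_back(bpe_strs.substr(start, pos - start));
            start = pos + 1;
        }
        subwords.push_back(bpe_strs.substr(start));  // trailing subword
        return subwords;
    }
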
}
Expand Down Expand Up @@ -4323,15 +4405,14 @@ class StableDiffusionGGML {
LOG_INFO("Stable Diffusion weight type: %s", ggml_type_name(model_data_type));

LOG_DEBUG("loading vocab");
auto add_token = [&](const std::string& token, int32_t token_id) {
cond_stage_model.tokenizer.add_token(token, token_id);
};
bool success = model_loader.load_vocab(add_token);
if (!success) {
LOG_ERROR("get vocab from file failed: '%s'", model_path.c_str());
std::string merges_utf8_str = model_loader.load_merges();
if (merges_utf8_str.size() == 0) {
LOG_ERROR("get merges failed: '%s'", model_path.c_str());
return false;
}

cond_stage_model.tokenizer.load_from_merges(merges_utf8_str);

// create the ggml context for network params
LOG_DEBUG("ggml tensor size = %d bytes", (int)sizeof(ggml_tensor));

Expand Down Expand Up @@ -4431,7 +4512,7 @@ class StableDiffusionGGML {

// print_ggml_tensor(alphas_cumprod_tensor);

success = model_loader.load_tensors(on_new_tensor_cb);
bool success = model_loader.load_tensors(on_new_tensor_cb);
if (!success) {
LOG_ERROR("load tensors from file failed");
ggml_free(ctx);
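
Taken together, the new load path in StableDiffusionGGML runs: embedded merges text -> tokenizer vocabulary -> token ids. Condensed into a free-standing sketch (the surrounding init and error handling from the diff are elided; 77 is CLIP's usual context length):

    ModelLoader model_loader;
    // ... model_loader.init_from_file(model_path) ...
    std::string merges_utf8_str = model_loader.load_merges();

    CLIPTokenizer tokenizer;
    tokenizer.load_from_merges(merges_utf8_str);
    std::vector<int> ids = tokenizer.tokenize("a photo of a cat", 77, true);
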
17 changes: 17 additions & 0 deletions util.cpp
@@ -1,7 +1,9 @@
#include "util.h"

#include <stdarg.h>
#include <codecvt>
#include <fstream>
#include <locale>
#include <thread>
#include <unordered_set>
#include <vector>
@@ -119,6 +121,21 @@ int32_t get_num_physical_cores() {
     return n_threads > 0 ? (n_threads <= 4 ? n_threads : n_threads / 2) : 4;
 }
 
+std::u32string utf8_to_utf32(const std::string& utf8_str) {
+    std::wstring_convert<std::codecvt_utf8<char32_t>, char32_t> converter;
+    return converter.from_bytes(utf8_str);
+}
+
+std::string utf32_to_utf8(const std::u32string& utf32_str) {
+    std::wstring_convert<std::codecvt_utf8<char32_t>, char32_t> converter;
+    return converter.to_bytes(utf32_str);
+}
+
+std::u32string unicode_value_to_utf32(int unicode_value) {
+    std::u32string utf32_string = {static_cast<char32_t>(unicode_value)};
+    return utf32_string;
+}
+
 std::string basename(const std::string& path) {
     size_t pos = path.find_last_of('/');
     if (pos != std::string::npos) {
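
One portability note on these helpers: std::wstring_convert and std::codecvt_utf8 do the job here but have been deprecated since C++17, so newer toolchains will warn. A self-contained alternative is a small hand-rolled decoder; a minimal sketch (utf8_to_utf32_manual is hypothetical, assumes well-formed UTF-8, and does no error handling):

    #include <string>

    // Decode well-formed UTF-8 into UTF-32 without <codecvt>.
    std::u32string utf8_to_utf32_manual(const std::string& s) {
        std::u32string out;
        size_t i = 0;
        while (i < s.size()) {
            unsigned char c = s[i];
            char32_t cp;
            int len;
            if (c < 0x80) {        // 1-byte sequence (ASCII)
                cp = c;
                len = 1;
            } else if (c < 0xE0) { // 2-byte sequence, 5 payload bits in the lead byte
                cp = c & 0x1F;
                len = 2;
            } else if (c < 0xF0) { // 3-byte sequence, 4 payload bits
                cp = c & 0x0F;
                len = 3;
            } else {               // 4-byte sequence, 3 payload bits
                cp = c & 0x07;
                len = 4;
            }
            for (int k = 1; k < len; ++k) {
                cp = (cp << 6) | (s[i + k] & 0x3F);  // 6 bits per continuation byte
            }
            out.push_back(cp);
            i += len;
        }
        return out;
    }
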
4 changes: 4 additions & 0 deletions util.h
@@ -14,6 +14,10 @@ void replace_all_chars(std::string& str, char target, char replacement);
 bool file_exists(const std::string& filename);
 bool is_directory(const std::string& path);
 
+std::u32string utf8_to_utf32(const std::string& utf8_str);
+std::string utf32_to_utf8(const std::u32string& utf32_str);
+std::u32string unicode_value_to_utf32(int unicode_value);
+
 std::string basename(const std::string& path);
 
 std::string path_join(const std::string& p1, const std::string& p2);
[Diff for the sixth changed file — the large generated header embedding the tokenizer data (vocab_json replaced by merges_utf8_c_str) — was not rendered.]

0 comments on commit 8f6b4a3
