From 78ad76f3f49a15692a550e6f73ad3a5765ffdb25 Mon Sep 17 00:00:00 2001 From: leejet Date: Fri, 29 Dec 2023 00:16:10 +0800 Subject: [PATCH] feat: add SDXL support (#117) * add SDXL support * fix the issue with generating large images --- .gitmodules | 2 +- README.md | 5 +- ggml | 2 +- model.cpp | 48 ++- stable-diffusion.cpp | 959 ++++++++++++++++++++++++++++--------------- 5 files changed, 669 insertions(+), 347 deletions(-) diff --git a/.gitmodules b/.gitmodules index 0b8fe290..d5788ea4 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +1,3 @@ [submodule "ggml"] path = ggml - url = https://github.com/FSSRepo/ggml.git + url = https://github.com/leejet/ggml.git diff --git a/README.md b/README.md index cc6938bf..feec44ad 100644 --- a/README.md +++ b/README.md @@ -10,8 +10,8 @@ Inference of [Stable Diffusion](https://github.com/CompVis/stable-diffusion) in - Plain C/C++ implementation based on [ggml](https://github.com/ggerganov/ggml), working in the same way as [llama.cpp](https://github.com/ggerganov/llama.cpp) - Super lightweight and without external dependencies -- SD1.x and SD2.x support -- [SD-Turbo](https://huggingface.co/stabilityai/sd-turbo) support +- SD1.x, SD2.x and SDXL support +- [SD-Turbo](https://huggingface.co/stabilityai/sd-turbo) and [SDXL-Turbo](https://huggingface.co/stabilityai/sdxl-turbo) support - 16-bit, 32-bit float support - 4-bit, 5-bit and 8-bit integer quantization support - Accelerated memory-efficient CPU inference @@ -302,3 +302,4 @@ Thank you to all the people who have already contributed to stable-diffusion.cpp - [stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui) - [k-diffusion](https://github.com/crowsonkb/k-diffusion) - [latent-consistency-model](https://github.com/luosiallen/latent-consistency-model) +- [generative-models](https://github.com/Stability-AI/generative-models/) diff --git a/ggml b/ggml index a0c2ec77..e5d3412f 160000 --- a/ggml +++ b/ggml @@ -1 +1 @@ -Subproject commit a0c2ec77a5ef8e630aff65bc535d13b9805cb929 +Subproject commit e5d3412fa2ea3de8c4a696c03dce73c470442dc1 diff --git a/model.cpp b/model.cpp index f8f0752c..01b89030 100644 --- a/model.cpp +++ b/model.cpp @@ -78,8 +78,9 @@ const char* unused_tensors[] = { "cond_stage_model.transformer.text_model.embeddings.position_ids", "cond_stage_model.model.logit_scale", "cond_stage_model.model.text_projection", + "conditioner.embedders.0.transformer.text_model.embeddings.position_ids", "conditioner.embedders.0.model.logit_scale", - "conditioner.embedders.0.model.text_projection", + "conditioner.embedders.1.model.logit_scale", "model.diffusion_model.time_embedding.cond_proj.weight", "unet.time_embedding.cond_proj.weight", "model_ema.decay", @@ -100,11 +101,11 @@ bool is_unused_tensor(std::string name) { } std::unordered_map open_clip_to_hf_clip_model = { - {"cond_stage_model.model.ln_final.bias", "cond_stage_model.transformer.text_model.final_layer_norm.bias"}, - {"cond_stage_model.model.ln_final.weight", "cond_stage_model.transformer.text_model.final_layer_norm.weight"}, - {"cond_stage_model.model.positional_embedding", "cond_stage_model.transformer.text_model.embeddings.position_embedding.weight"}, - {"cond_stage_model.model.token_embedding.weight", "cond_stage_model.transformer.text_model.embeddings.token_embedding.weight"}, - + {"model.ln_final.bias", "transformer.text_model.final_layer_norm.bias"}, + {"model.ln_final.weight", "transformer.text_model.final_layer_norm.weight"}, + {"model.positional_embedding", 
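// ---------------------------------------------------------------------------
// A minimal usage sketch (not part of the patch itself; it links against the
// convert_open_clip_to_hf_clip() helper defined just below): the names the
// remapping is expected to produce for an SDXL checkpoint's two text encoders.
#include <cassert>
#include <string>
std::string convert_open_clip_to_hf_clip(const std::string& name);  // defined below
static void remap_sketch() {
    // embedder 0 (OpenAI CLIP ViT-L/14) keeps the plain "cond_stage_model." prefix
    assert(convert_open_clip_to_hf_clip("conditioner.embedders.0.model.ln_final.weight") ==
           "cond_stage_model.transformer.text_model.final_layer_norm.weight");
    // embedder 1 (OpenCLIP ViT-bigG/14) is mapped under "cond_stage_model.1."
    assert(convert_open_clip_to_hf_clip("conditioner.embedders.1.model.ln_final.weight") ==
           "cond_stage_model.1.transformer.text_model.final_layer_norm.weight");
}
// ---------------------------------------------------------------------------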
"transformer.text_model.embeddings.position_embedding.weight"}, + {"model.token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"}, + {"model.text_projection", "transformer.text_model.text_projection"}, }; std::unordered_map open_clip_to_hk_clip_resblock = { @@ -133,11 +134,21 @@ std::unordered_map vae_decoder_name_map = { std::string convert_open_clip_to_hf_clip(const std::string& name) { std::string new_name = name; + std::string prefix; if (starts_with(new_name, "conditioner.embedders.0.")) { - new_name = "cond_stage_model." + new_name.substr(strlen("conditioner.embedders.0.")); + prefix = "cond_stage_model."; + new_name = new_name.substr(strlen("conditioner.embedders.0.")); + } else if (starts_with(new_name, "conditioner.embedders.1.")) { + prefix = "cond_stage_model.1."; + new_name = new_name.substr(strlen("conditioner.embedders.0.")); + } else if (starts_with(new_name, "cond_stage_model.")) { + prefix = "cond_stage_model."; + new_name = new_name.substr(strlen("cond_stage_model.")); + } else { + return new_name; } - std::string open_clip_resblock_prefix = "cond_stage_model.model.transformer.resblocks."; - std::string hf_clip_resblock_prefix = "cond_stage_model.transformer.text_model.encoder.layers."; + std::string open_clip_resblock_prefix = "model.transformer.resblocks."; + std::string hf_clip_resblock_prefix = "transformer.text_model.encoder.layers."; if (open_clip_to_hf_clip_model.find(new_name) != open_clip_to_hf_clip_model.end()) { new_name = open_clip_to_hf_clip_model[new_name]; @@ -156,7 +167,7 @@ std::string convert_open_clip_to_hf_clip(const std::string& name) { } } - return new_name; + return prefix + new_name; } std::string convert_vae_decoder_name(const std::string& name) { @@ -358,7 +369,7 @@ std::string convert_diffusers_name_to_compvis(const std::string& key, char seq) std::string convert_tensor_name(const std::string& name) { std::string new_name; - if (starts_with(name, "cond_stage_model.model") || starts_with(name, "conditioner.embedders.0.model")) { + if (starts_with(name, "cond_stage_model.") || starts_with(name, "conditioner.embedders.")) { new_name = convert_open_clip_to_hf_clip(name); } else if (starts_with(name, "first_stage_model.decoder")) { new_name = convert_vae_decoder_name(name); @@ -419,7 +430,7 @@ void preprocess_tensor(TensorStorage tensor_storage, tensor_storage.name = new_name; - if (starts_with(new_name, "cond_stage_model.transformer.text_model.encoder.layers.") && + if (new_name.find("transformer.text_model.encoder.layers.") != std::string::npos && ends_with(new_name, "attn.in_proj_weight")) { size_t prefix_size = new_name.find("attn.in_proj_weight"); std::string prefix = new_name.substr(0, prefix_size); @@ -431,7 +442,7 @@ void preprocess_tensor(TensorStorage tensor_storage, processed_tensor_storages.insert(processed_tensor_storages.end(), chunks.begin(), chunks.end()); - } else if (starts_with(new_name, "cond_stage_model.transformer.text_model.encoder.layers.") && + } else if (new_name.find("transformer.text_model.encoder.layers.") != std::string::npos && ends_with(new_name, "attn.in_proj_bias")) { size_t prefix_size = new_name.find("attn.in_proj_bias"); std::string prefix = new_name.substr(0, prefix_size); @@ -1163,15 +1174,20 @@ bool ModelLoader::init_from_ckpt_file(const std::string& file_path, const std::s } SDVersion ModelLoader::get_sd_version() { + // return VERSION_1_x; TensorStorage token_embedding_weight; for (auto& tensor_storage : tensor_storages) { + if (tensor_storage.name.find("conditioner.embedders.1") 
!= std::string::npos) { + return VERSION_XL; + } if (tensor_storage.name == "cond_stage_model.transformer.text_model.embeddings.token_embedding.weight" || tensor_storage.name == "cond_stage_model.model.token_embedding.weight" || tensor_storage.name == "text_model.embeddings.token_embedding.weight" || tensor_storage.name == "te.text_model.embeddings.token_embedding.weight" || - tensor_storage.name == "conditioner.embedders.0.model.token_embedding.weight") { + tensor_storage.name == "conditioner.embedders.0.model.token_embedding.weight" || + tensor_storage.name == "conditioner.embedders.0.transformer.text_model.embeddings.token_embedding.weight") { token_embedding_weight = tensor_storage; - break; + // break; } } if (token_embedding_weight.ne[0] == 768) { @@ -1275,7 +1291,7 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend } for (auto& tensor_storage : processed_tensor_storages) { - // LOG_DEBUG("%s", name.c_str()); + // LOG_DEBUG("%s", tensor_storage.name.c_str()); ggml_tensor* dst_tensor = NULL; diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index 70cd79a7..c1ffdc80 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -35,8 +35,8 @@ #define EPS 1e-05f -#define UNET_GRAPH_SIZE 3328 -#define LORA_GRAPH_SIZE 4096 +#define UNET_GRAPH_SIZE 10240 +#define LORA_GRAPH_SIZE 10240 #define TIMESTEPS 1000 @@ -127,6 +127,9 @@ void ggml_tensor_set_f32(struct ggml_tensor* tensor, float value, int l, int k = } float ggml_tensor_get_f32(const ggml_tensor* tensor, int l, int k = 0, int j = 0, int i = 0) { + // float value; + // ggml_backend_tensor_get(tensor, &value, i * tensor->nb[3] + j * tensor->nb[2] + k * tensor->nb[1] + l * tensor->nb[0], sizeof(float)); + // return value; GGML_ASSERT(tensor->nb[0] == sizeof(float)); return *(float*)((char*)(tensor->data) + i * tensor->nb[3] + j * tensor->nb[2] + k * tensor->nb[1] + l * tensor->nb[0]); } @@ -276,7 +279,7 @@ void calculate_alphas_cumprod(float* alphas_cumprod, // Ref: https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/diffusionmodules/util.py#L151 void set_timestep_embedding(struct ggml_tensor* timesteps, struct ggml_tensor* embedding, int dim, int max_period = 10000) { // timesteps: [N,] - // embedding: [(dim + 1)/2, N] + // embedding: [dim, N] int half = dim / 2; std::vector freqs(half); for (int i = 0; i < half; ++i) { @@ -296,7 +299,7 @@ void set_timestep_embedding(struct ggml_tensor* timesteps, struct ggml_tensor* e struct ggml_tensor* new_timestep_embedding(struct ggml_context* ctx, struct ggml_allocr* allocr, struct ggml_tensor* timesteps, int dim, int max_period = 10000) { // timesteps: [N,] - // embedding: [(dim + 1)/2, N] + // embedding: [dim, N] int acutual_dim = dim; if (dim % 2 != 0) { acutual_dim = dim + 1; @@ -1147,7 +1150,7 @@ struct ResidualAttentionBlock { // mlp x = ggml_nn_linear(ctx, x, fc1_w, fc1_b); - if (hidden_size == 1024) { // SD 2.x + if (hidden_size == 1024 || hidden_size == 1280) { // SD 2.x x = ggml_gelu_inplace(ctx, x); } else { // SD 1.x x = ggml_gelu_quick_inplace(ctx, x); @@ -1161,21 +1164,30 @@ struct ResidualAttentionBlock { } }; -// VERSION_1_x.x: https://huggingface.co/openai/clip-vit-large-patch14/blob/main/config.json -// VERSION_2_x.x: https://huggingface.co/laion/CLIP-ViT-H-14-laion2B-s32B-b79K/blob/main/config.json -// VERSION_XL: https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k/blob/main/config.json (CLIPTextModelWithProjection) +// OPENAI_CLIP_VIT_L_14: 
https://huggingface.co/openai/clip-vit-large-patch14/blob/main/config.json +// OPEN_CLIP_VIT_H_14: https://huggingface.co/laion/CLIP-ViT-H-14-laion2B-s32B-b79K/blob/main/config.json +// OPEN_CLIP_VIT_BIGG_14: https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k/blob/main/config.json (CLIPTextModelWithProjection) // SDXL CLIPModel // CLIPTextModelWithProjection seems optional + +enum CLIPVersion { + OPENAI_CLIP_VIT_L_14, // SD 1.x and SDXL + OPEN_CLIP_VIT_H_14, // SD 2.x + OPEN_CLIP_VIT_BIGG_14, // SDXL +}; + struct CLIPTextModel { - SDVersion version = VERSION_1_x; + CLIPVersion version = OPENAI_CLIP_VIT_L_14; // network hparams int32_t vocab_size = 49408; int32_t max_position_embeddings = 77; - int32_t hidden_size = 768; // 1024 for SD 2.x - int32_t intermediate_size = 3072; // 4096 for SD 2.x - int32_t n_head = 12; // num_attention_heads, 16 for SD 2.x - int32_t num_hidden_layers = 12; // 24 for SD 2.x - int32_t clip_skip = 1; + int32_t hidden_size = 768; // 1024 for OPEN_CLIP_VIT_H_14 + int32_t intermediate_size = 3072; // 4096 for OPEN_CLIP_VIT_H_14 + int32_t n_head = 12; // num_attention_heads, 16 for OPEN_CLIP_VIT_H_14 + int32_t num_hidden_layers = 12; // 24 for OPEN_CLIP_VIT_H_14 + int32_t layer_idx = 11; + int32_t projection_dim = 1280; // only for OPEN_CLIP_VIT_BIGG_14 + bool with_final_ln = true; // embeddings struct ggml_tensor* position_ids; @@ -1187,31 +1199,24 @@ struct CLIPTextModel { struct ggml_tensor* final_ln_w; struct ggml_tensor* final_ln_b; - // context and memory buffers - struct ggml_context* ctx; - ggml_backend_buffer_t params_buffer; - ggml_backend_buffer_t compute_buffer; // for compute - struct ggml_allocr* compute_alloc = NULL; - size_t compute_memory_buffer_size = -1; - - size_t memory_buffer_size = 0; - ggml_type wtype; - ggml_backend_t backend = NULL; - ggml_tensor* work_output = NULL; + struct ggml_tensor* text_projection; - CLIPTextModel(SDVersion version = VERSION_1_x, bool has_pool = false) - : version(version) { - if (version == VERSION_2_x) { + CLIPTextModel(CLIPVersion version = OPENAI_CLIP_VIT_L_14, + int clip_skip = 1, + bool with_final_ln = true) + : version(version), with_final_ln(with_final_ln) { + if (version == OPEN_CLIP_VIT_H_14) { hidden_size = 1024; intermediate_size = 4096; n_head = 16; num_hidden_layers = 24; - } else if (version == VERSION_XL && has_pool) { // CLIPTextModelWithProjection + } else if (version == OPEN_CLIP_VIT_BIGG_14) { // CLIPTextModelWithProjection hidden_size = 1280; intermediate_size = 5120; n_head = 20; num_hidden_layers = 32; } + layer_idx = num_hidden_layers - clip_skip; resblocks.resize(num_hidden_layers); set_resblocks_hp_params(); } @@ -1226,42 +1231,7 @@ struct CLIPTextModel { } } - bool initialize(ggml_backend_t backend_, ggml_type wtype_) { - backend = backend_; - wtype = wtype_; - memory_buffer_size = 1 * 1024 * 1024; // 1 MB, for padding - memory_buffer_size += calculate_mem_size(); - - int num_tensors = (3 + 2 + 37 * num_hidden_layers); - LOG_DEBUG("clip params backend buffer size = % 6.2f MB (%i tensors)", memory_buffer_size / (1024.0 * 1024.0), num_tensors); - - struct ggml_init_params params; - params.mem_size = static_cast(num_tensors * ggml_tensor_overhead()); - params.mem_buffer = NULL; - params.no_alloc = true; - - ctx = ggml_init(params); - if (!ctx) { - LOG_ERROR("ggml_init() failed"); - return false; - } - params_buffer = ggml_backend_alloc_buffer(backend, memory_buffer_size); - return true; - } - - void destroy() { - if (ctx != NULL) { - ggml_free(ctx); - ctx = NULL; - } - - if 
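// ---------------------------------------------------------------------------
// Illustrative sketch (not from the patch): how clip_skip selects the hidden
// layer actually used, via layer_idx = num_hidden_layers - clip_skip above,
// with the per-encoder layer counts listed in the constructor.
#include <cstdio>
static void clip_skip_sketch() {
    struct Encoder { const char* name; int num_hidden_layers; };
    const Encoder encoders[] = {
        {"OPENAI_CLIP_VIT_L_14", 12},   // SD 1.x, SDXL embedder 0
        {"OPEN_CLIP_VIT_H_14", 24},     // SD 2.x
        {"OPEN_CLIP_VIT_BIGG_14", 32},  // SDXL embedder 1
    };
    const int clip_skip = 2;  // the SD 2.x / SDXL default set later in this patch
    for (const Encoder& e : encoders) {
        // forward() runs resblocks 0..layer_idx inclusive, then stops
        printf("%s: run layers 0..%d of %d\n", e.name, e.num_hidden_layers - clip_skip, e.num_hidden_layers);
    }
}
// ---------------------------------------------------------------------------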
(params_buffer != NULL) { - ggml_backend_buffer_free(params_buffer); - params_buffer = NULL; - } - } - - size_t calculate_mem_size() { + size_t calculate_mem_size(ggml_type wtype) { double mem_size = 0; mem_size += hidden_size * max_position_embeddings * ggml_type_sizef(GGML_TYPE_I32); // position_ids mem_size += hidden_size * vocab_size * ggml_type_sizef(wtype); // token_embed_weight @@ -1270,45 +1240,10 @@ struct CLIPTextModel { mem_size += resblocks[i].calculate_mem_size(wtype); } mem_size += 2 * hidden_size * ggml_type_sizef(GGML_TYPE_F32); // final_ln_w/b - return static_cast(mem_size); - } - - void alloc_params() { - ggml_allocr* alloc = ggml_allocr_new_from_buffer(params_buffer); - position_ids = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, max_position_embeddings); - - token_embed_weight = ggml_new_tensor_2d(ctx, wtype, hidden_size, vocab_size); - - position_embed_weight = ggml_new_tensor_2d(ctx, wtype, hidden_size, max_position_embeddings); - - for (int i = 0; i < num_hidden_layers; i++) { - resblocks[i].init_params(ctx, alloc, wtype); - } - - final_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hidden_size); - - final_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hidden_size); - - // alloc all tensors linked to this context - for (struct ggml_tensor* t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { - if (t->data == NULL) { - ggml_allocr_alloc(alloc, t); - } - } - - if (ggml_backend_is_cpu(backend)) { - for (int i = 0; i < max_position_embeddings; i++) { - ggml_set_i32_1d(position_ids, i, i); - } - } else { - std::vector pos_temp; - for (int i = 0; i < max_position_embeddings; i++) { - pos_temp.push_back(i); - } - ggml_backend_tensor_set(position_ids, pos_temp.data(), 0, ggml_nbytes(position_ids)); + if (version == OPEN_CLIP_VIT_BIGG_14) { + mem_size += hidden_size * projection_dim * ggml_type_sizef(GGML_TYPE_F32); // text_projection } - - ggml_allocr_free(alloc); + return static_cast(mem_size); } void map_by_name(std::map& tensors, const std::string prefix) { @@ -1317,11 +1252,15 @@ struct CLIPTextModel { tensors[prefix + "final_layer_norm.weight"] = final_ln_w; tensors[prefix + "final_layer_norm.bias"] = final_ln_b; for (int i = 0; i < num_hidden_layers; i++) { + std::string name = prefix + "encoder.layers." + std::to_string(i) + "."; resblocks[i].map_by_name(tensors, prefix + "encoder.layers." 
+ std::to_string(i) + "."); } + if (version == OPEN_CLIP_VIT_BIGG_14) { + tensors[prefix + "text_projection"] = text_projection; + } } - struct ggml_tensor* forward(struct ggml_context* ctx0, struct ggml_tensor* input_ids) { + struct ggml_tensor* forward(struct ggml_context* ctx0, struct ggml_tensor* input_ids, uint32_t max_token_idx = 0, bool return_pooled = false) { // input_ids: [N, n_token] GGML_ASSERT(input_ids->ne[0] <= position_ids->ne[0]); @@ -1334,101 +1273,68 @@ struct CLIPTextModel { ggml_view_1d(ctx0, position_ids, input_ids->ne[0], 0))); // [N, n_token, hidden_size] // transformer - int layer_idx = num_hidden_layers - clip_skip; for (int i = 0; i < num_hidden_layers; i++) { - if (i == layer_idx + 1) { + if (!return_pooled && i == layer_idx + 1) { + // LOG_DEBUG("layer %d", i); break; } x = resblocks[i].forward(ctx0, x); // [N, n_token, hidden_size] } // final layer norm - x = ggml_nn_layer_norm(ctx0, x, final_ln_w, final_ln_b); + if (return_pooled || with_final_ln) { + x = ggml_nn_layer_norm(ctx0, x, final_ln_w, final_ln_b); + } + + if (return_pooled) { + // ggml_tensor* idx = ggml_argmax(ctx0, input_ids); + // ggml_tensor* pooled = ggml_get_rows(ctx0, x, idx); + // LOG_DEBUG("max_token_idx: %u %u", max_token_idx, x->nb[1]); + ggml_tensor* pooled = ggml_view_1d(ctx0, x, hidden_size, x->nb[1] * max_token_idx); + pooled = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, text_projection)), pooled); + return pooled; + } return x; // [N, n_token, hidden_size] } - struct ggml_cgraph* build_graph(struct ggml_allocr* allocr, std::vector tokens) { - // since we are using ggml-alloc, this buffer only needs enough space to hold the ggml_tensor and ggml_cgraph structs, but not the tensor data - static size_t buf_size = ggml_tensor_overhead() * GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead(); - static std::vector buf(buf_size); + void alloc_params(ggml_context* ctx, ggml_backend_t backend, ggml_type wtype, ggml_allocr* alloc) { + position_ids = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, max_position_embeddings); - struct ggml_init_params params = { - /*.mem_size =*/buf_size, - /*.mem_buffer =*/buf.data(), - /*.no_alloc =*/true, // the tensors will be allocated later by ggml_allocr_alloc_graph() - }; - - struct ggml_context* ctx0 = ggml_init(params); - - struct ggml_cgraph* gf = ggml_new_graph(ctx0); + token_embed_weight = ggml_new_tensor_2d(ctx, wtype, hidden_size, vocab_size); - struct ggml_tensor* input_ids = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, tokens.size()); - ggml_allocr_alloc(allocr, input_ids); + position_embed_weight = ggml_new_tensor_2d(ctx, wtype, hidden_size, max_position_embeddings); - if (!ggml_allocr_is_measure(allocr)) { - ggml_backend_tensor_set(input_ids, tokens.data(), 0, tokens.size() * ggml_element_size(input_ids)); + for (int i = 0; i < num_hidden_layers; i++) { + resblocks[i].init_params(ctx, alloc, wtype); } - struct ggml_tensor* hidden_states = forward(ctx0, input_ids); - - ggml_build_forward_expand(gf, hidden_states); - ggml_free(ctx0); + final_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hidden_size); - return gf; - } + final_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hidden_size); - void begin(ggml_context* work_ctx, int max_tokens) { - if (work_output == NULL) { - work_output = ggml_new_tensor_2d(work_ctx, GGML_TYPE_F32, hidden_size, max_position_embeddings); + if (version == OPEN_CLIP_VIT_BIGG_14) { + text_projection = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, projection_dim, hidden_size); } - // calculate the amount of memory required - if 
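// ---------------------------------------------------------------------------
// Standalone restatement (illustrative, mirrors the return_pooled path above):
// SDXL's pooled text embedding is the hidden state at the EOS position,
// projected with text_projection. Assumes batch size 1, as in the patch.
#include "ggml/ggml.h"
#include <cstdint>
static struct ggml_tensor* pooled_sketch(struct ggml_context* ctx0,
                                         struct ggml_tensor* x,                // [N, n_token, hidden_size]
                                         struct ggml_tensor* text_projection,  // as allocated in alloc_params()
                                         int hidden_size,
                                         uint32_t max_token_idx) {
    // slice out the hidden state of the EOS token
    struct ggml_tensor* pooled = ggml_view_1d(ctx0, x, hidden_size, x->nb[1] * max_token_idx);
    // project it to projection_dim (1280 for ViT-bigG)
    pooled = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, text_projection)), pooled);
    return pooled;
}
// ---------------------------------------------------------------------------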
(compute_memory_buffer_size == -1) { - compute_alloc = ggml_allocr_new_measure_from_backend(backend); - - struct ggml_cgraph* gf = build_graph(compute_alloc, std::vector(max_tokens)); - // compute the required memory - compute_memory_buffer_size = ggml_allocr_alloc_graph(compute_alloc, gf); - // recreate the allocator with the required memory - ggml_allocr_free(compute_alloc); - - LOG_DEBUG("learned condition compute buffer size: %.2f MB", compute_memory_buffer_size / 1024.0 / 1024.0); + // alloc all tensors linked to this context + for (struct ggml_tensor* t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { + if (t->data == NULL) { + ggml_allocr_alloc(alloc, t); + } } - compute_buffer = ggml_backend_alloc_buffer(backend, compute_memory_buffer_size); - compute_alloc = ggml_allocr_new_from_buffer(compute_buffer); - } - - struct ggml_tensor* compute(const int n_threads, std::vector tokens) { - struct ggml_cgraph* gf = build_graph(compute_alloc, tokens); - - ggml_allocr_alloc_graph(compute_alloc, gf); if (ggml_backend_is_cpu(backend)) { - ggml_backend_cpu_set_n_threads(backend, n_threads); - } - -#ifdef SD_USE_METAL - if (ggml_backend_is_metal(backend)) { - ggml_backend_metal_set_n_cb(backend, n_threads); + for (int i = 0; i < max_position_embeddings; i++) { + ggml_set_i32_1d(position_ids, i, i); + } + } else { + std::vector pos_temp; + for (int i = 0; i < max_position_embeddings; i++) { + pos_temp.push_back(i); + } + ggml_backend_tensor_set(position_ids, pos_temp.data(), 0, ggml_nbytes(position_ids)); } -#endif - - ggml_backend_graph_compute(backend, gf); - -#ifdef GGML_PERF - ggml_graph_print(gf); -#endif - ggml_backend_tensor_get_and_sync(backend, gf->nodes[gf->n_nodes - 1], work_output->data, 0, ggml_nbytes(work_output)); - return work_output; - } - - void end() { - ggml_allocr_free(compute_alloc); - ggml_backend_buffer_free(compute_buffer); - compute_alloc = NULL; - compute_memory_buffer_size = -1; - work_output = NULL; } }; @@ -1451,9 +1357,90 @@ struct FrozenCLIPEmbedderWithCustomWords { SDVersion version = VERSION_1_x; CLIPTokenizer tokenizer; CLIPTextModel text_model; + CLIPTextModel text_model2; + + // context and memory buffers + struct ggml_context* ctx; + ggml_backend_buffer_t params_buffer; + ggml_backend_buffer_t compute_buffer; // for compute + struct ggml_allocr* compute_alloc = NULL; + size_t compute_memory_buffer_size = -1; + + size_t memory_buffer_size = 0; + ggml_type wtype; + ggml_backend_t backend = NULL; + ggml_tensor* hidden_state_output = NULL; + ggml_tensor* pooled_output = NULL; + + FrozenCLIPEmbedderWithCustomWords(SDVersion version = VERSION_1_x, int clip_skip = -1) + : version(version), tokenizer(version) { + if (clip_skip <= 0) { + clip_skip = 1; + if (version == VERSION_2_x || version == VERSION_XL) { + clip_skip = 2; + } + } + if (version == VERSION_1_x) { + text_model = CLIPTextModel(OPENAI_CLIP_VIT_L_14, clip_skip); + } else if (version == VERSION_2_x) { + text_model = CLIPTextModel(OPEN_CLIP_VIT_H_14, clip_skip); + } else if (version == VERSION_XL) { + text_model = CLIPTextModel(OPENAI_CLIP_VIT_L_14, clip_skip, false); + text_model2 = CLIPTextModel(OPEN_CLIP_VIT_BIGG_14, clip_skip, false); + } + } + + size_t calculate_mem_size() { + size_t mem_size = text_model.calculate_mem_size(wtype); + if (version == VERSION_XL) { + mem_size += text_model2.calculate_mem_size(wtype); + } + return mem_size; + } + + void map_by_name(std::map& tensors, const std::string prefix) { + text_model.map_by_name(tensors, prefix + 
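// ---------------------------------------------------------------------------
// Quick arithmetic sketch (illustrative): SDXL's UNet context_dim of 2048 is
// the per-token concatenation of both encoders' hidden states, which the
// forward() below performs with ggml_concat.
static int sdxl_context_dim_sketch() {
    const int hidden_size_vit_l = 768;   // OPENAI_CLIP_VIT_L_14
    const int hidden_size_bigg  = 1280;  // OPEN_CLIP_VIT_BIGG_14
    return hidden_size_vit_l + hidden_size_bigg;  // = 2048
}
// ---------------------------------------------------------------------------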
"transformer.text_model."); + if (version == VERSION_XL) { + text_model2.map_by_name(tensors, prefix + "1.transformer.text_model."); + } + } + + struct ggml_tensor* forward(struct ggml_context* ctx0, struct ggml_tensor* input_ids, struct ggml_tensor* input_ids2, uint32_t max_token_idx = 0, bool return_pooled = false) { + if (return_pooled) { + return text_model2.forward(ctx0, input_ids2, max_token_idx, return_pooled); + } + auto hidden_states = text_model.forward(ctx0, input_ids); // [N, n_token, hidden_size] + // LOG_DEBUG("hidden_states: %d %d %d %d %d", hidden_states->n_dims, hidden_states->ne[0], hidden_states->ne[1], hidden_states->ne[2], hidden_states->ne[3]); + if (version == VERSION_XL) { + hidden_states = ggml_reshape_4d(ctx0, + hidden_states, + hidden_states->ne[0], + hidden_states->ne[1], + hidden_states->ne[2], + hidden_states->ne[3]); + hidden_states = ggml_cont(ctx0, ggml_permute(ctx0, hidden_states, 2, 0, 1, 3)); + + auto hidden_states2 = text_model2.forward(ctx0, input_ids2); // [N, n_token, hidden_size2] + hidden_states2 = ggml_reshape_4d(ctx0, + hidden_states2, + hidden_states2->ne[0], + hidden_states2->ne[1], + hidden_states2->ne[2], + hidden_states2->ne[3]); + hidden_states2 = ggml_cont(ctx0, ggml_permute(ctx0, hidden_states2, 2, 0, 1, 3)); - FrozenCLIPEmbedderWithCustomWords(SDVersion version = VERSION_1_x) - : version(version), tokenizer(version), text_model(version) {} + hidden_states = ggml_concat(ctx0, hidden_states, hidden_states2); // [N, n_token, hidden_size + hidden_size2] + + hidden_states = ggml_cont(ctx0, ggml_permute(ctx0, hidden_states, 1, 2, 0, 3)); + } + // LOG_DEBUG("hidden_states: %d %d %d %d", hidden_states->ne[0], hidden_states->ne[1], hidden_states->ne[2], hidden_states->ne[3]); + return hidden_states; + } + + std::pair, std::vector> tokenize(std::string text, + bool padding = false) { + return tokenize(text, text_model.max_position_embeddings, padding); + } std::pair, std::vector> tokenize(std::string text, size_t max_length = 0, @@ -1509,6 +1496,187 @@ struct FrozenCLIPEmbedderWithCustomWords { return {tokens, weights}; } + + bool initialize(ggml_backend_t backend_, ggml_type wtype_) { + backend = backend_; + wtype = wtype_; + memory_buffer_size = 1 * 1024 * 1024; // 1 MB, for padding + memory_buffer_size += calculate_mem_size(); + + int num_tensors = (3 + 2 + 37 * text_model.num_hidden_layers); + if (version == VERSION_XL) { + num_tensors += (3 + 2 + 37 * text_model2.num_hidden_layers); + } + LOG_DEBUG("clip params backend buffer size = % 6.2f MB (%i tensors)", memory_buffer_size / (1024.0 * 1024.0), num_tensors); + + struct ggml_init_params params; + params.mem_size = static_cast(num_tensors * ggml_tensor_overhead()); + params.mem_buffer = NULL; + params.no_alloc = true; + + ctx = ggml_init(params); + if (!ctx) { + LOG_ERROR("ggml_init() failed"); + return false; + } + params_buffer = ggml_backend_alloc_buffer(backend, memory_buffer_size); + return true; + } + + void destroy() { + if (ctx != NULL) { + ggml_free(ctx); + ctx = NULL; + } + + if (params_buffer != NULL) { + ggml_backend_buffer_free(params_buffer); + params_buffer = NULL; + } + } + + void alloc_params() { + ggml_allocr* alloc = ggml_allocr_new_from_buffer(params_buffer); + text_model.alloc_params(ctx, backend, wtype, alloc); + if (version == VERSION_XL) { + text_model2.alloc_params(ctx, backend, wtype, alloc); + } + ggml_allocr_free(alloc); + } + + struct ggml_cgraph* build_graph(struct ggml_allocr* allocr, std::vector tokens, bool return_pooled = false) { + // since we are using 
ggml-alloc, this buffer only needs enough space to hold the ggml_tensor and ggml_cgraph structs, but not the tensor data + static size_t buf_size = ggml_tensor_overhead() * GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead(); + static std::vector buf(buf_size); + + struct ggml_init_params params = { + /*.mem_size =*/buf_size, + /*.mem_buffer =*/buf.data(), + /*.no_alloc =*/true, // the tensors will be allocated later by ggml_allocr_alloc_graph() + }; + + struct ggml_context* ctx0 = ggml_init(params); + + struct ggml_cgraph* gf = ggml_new_graph(ctx0); + + struct ggml_tensor* input_ids = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, tokens.size()); + ggml_allocr_alloc(allocr, input_ids); + + if (!ggml_allocr_is_measure(allocr)) { + ggml_backend_tensor_set(input_ids, tokens.data(), 0, tokens.size() * ggml_element_size(input_ids)); + } + + struct ggml_tensor* input_ids2 = NULL; + size_t max_token_idx = 0; + if (version == VERSION_XL) { + input_ids2 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, tokens.size()); + ggml_allocr_alloc(allocr, input_ids2); + + auto it = std::find(tokens.begin(), tokens.end(), EOS_TOKEN_ID); + if (it != tokens.end()) { + std::fill(std::next(it), tokens.end(), 0); + } + + max_token_idx = std::min(std::distance(tokens.begin(), it), tokens.size() - 1); + + // for (int i = 0; i < tokens.size(); i++) { + // printf("%d ", tokens[i]); + // } + // printf("\n"); + + if (!ggml_allocr_is_measure(allocr)) { + ggml_backend_tensor_set(input_ids2, tokens.data(), 0, tokens.size() * ggml_element_size(input_ids2)); + } + } + + struct ggml_tensor* hidden_states = forward(ctx0, input_ids, input_ids2, max_token_idx, return_pooled); + + ggml_build_forward_expand(gf, hidden_states); + ggml_free(ctx0); + + return gf; + } + + void begin(ggml_context* work_ctx, int max_tokens) { + if (hidden_state_output == NULL) { + size_t total_hidden_size = text_model.hidden_size; + if (version == VERSION_XL) { + total_hidden_size += text_model2.hidden_size; + } + hidden_state_output = ggml_new_tensor_2d(work_ctx, GGML_TYPE_F32, total_hidden_size, text_model.max_position_embeddings); + pooled_output = ggml_new_tensor_1d(work_ctx, GGML_TYPE_F32, text_model2.projection_dim); + } + // calculate the amount of memory required + if (compute_memory_buffer_size == -1) { + compute_alloc = ggml_allocr_new_measure_from_backend(backend); + + bool return_pooled = false; + if (version == VERSION_XL) { + return_pooled = true; + } + struct ggml_cgraph* gf = build_graph(compute_alloc, std::vector(max_tokens), return_pooled); + // compute the required memory + compute_memory_buffer_size = ggml_allocr_alloc_graph(compute_alloc, gf) + 1024 * 1024; + + // recreate the allocator with the required memory + ggml_allocr_free(compute_alloc); + + LOG_DEBUG("learned condition compute buffer size: %.2f MB", compute_memory_buffer_size / 1024.0 / 1024.0); + } + compute_buffer = ggml_backend_alloc_buffer(backend, compute_memory_buffer_size); + compute_alloc = ggml_allocr_new_from_buffer(compute_buffer); + } + + std::pair compute(const int n_threads, std::vector tokens) { + struct ggml_cgraph* gf = build_graph(compute_alloc, tokens); + + ggml_allocr_alloc_graph(compute_alloc, gf); + + if (ggml_backend_is_cpu(backend)) { + ggml_backend_cpu_set_n_threads(backend, n_threads); + } + +#ifdef SD_USE_METAL + if (ggml_backend_is_metal(backend)) { + ggml_backend_metal_set_n_cb(backend, n_threads); + } +#endif + + ggml_backend_graph_compute(backend, gf); + +#ifdef GGML_PERF + ggml_graph_print(gf); +#endif + ggml_backend_tensor_get(gf->nodes[gf->n_nodes - 1], 
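// ---------------------------------------------------------------------------
// Sketch of the token preparation in build_graph() above (illustrative): for
// the second text encoder, everything after the first EOS token is zero-padded
// and the EOS position is remembered so the pooled embedding can be read from it.
#include <algorithm>
#include <vector>
static size_t prepare_tokens2_sketch(std::vector<int>& tokens, int eos_token_id) {
    auto it = std::find(tokens.begin(), tokens.end(), eos_token_id);
    if (it != tokens.end()) {
        std::fill(std::next(it), tokens.end(), 0);  // zero out everything past EOS
    }
    // clamp to the last index in case no EOS was found
    return std::min<size_t>(std::distance(tokens.begin(), it), tokens.size() - 1);
}
// ---------------------------------------------------------------------------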
hidden_state_output->data, 0, ggml_nbytes(hidden_state_output)); + + if (version == VERSION_XL) { + struct ggml_cgraph* gf = build_graph(compute_alloc, tokens, true); + + ggml_allocr_alloc_graph(compute_alloc, gf); + + if (ggml_backend_is_cpu(backend)) { + ggml_backend_cpu_set_n_threads(backend, n_threads); + } + + ggml_backend_graph_compute(backend, gf); + +#ifdef GGML_PERF + ggml_graph_print(gf); +#endif + ggml_backend_tensor_get(gf->nodes[gf->n_nodes - 1], pooled_output->data, 0, ggml_nbytes(pooled_output)); + return {hidden_state_output, pooled_output}; + } + return {hidden_state_output, NULL}; + } + + void end() { + ggml_allocr_free(compute_alloc); + ggml_backend_buffer_free(compute_buffer); + compute_alloc = NULL; + compute_memory_buffer_size = -1; + hidden_state_output = NULL; + pooled_output = NULL; + } }; /*==================================================== UnetModel =====================================================*/ @@ -1637,7 +1805,7 @@ struct SpatialTransformer { int n_head; // num_heads int d_head; // in_channels // n_heads int depth = 1; // 1 - int context_dim = 768; // hidden_size, 1024 for VERSION_2_x.x + int context_dim = 768; // hidden_size, 1024 for VERSION_2_x // group norm struct ggml_tensor* norm_w; // [in_channels,] @@ -1648,8 +1816,7 @@ struct SpatialTransformer { struct ggml_tensor* proj_in_b; // [in_channels,] // transformer - struct - { + struct Transformer { // layer norm 1 struct ggml_tensor* norm1_w; // [in_channels, ] struct ggml_tensor* norm1_b; // [in_channels, ] @@ -1684,7 +1851,9 @@ struct SpatialTransformer { struct ggml_tensor* ff_2_w; // [in_channels, in_channels * 4] struct ggml_tensor* ff_2_b; // [in_channels,] - } transformer; // supposes depth = 1, this need to be a list + }; + + std::vector transformers; struct ggml_tensor* attn_scale; @@ -1692,6 +1861,15 @@ struct SpatialTransformer { struct ggml_tensor* proj_out_w; // [in_channels, in_channels, 1, 1] struct ggml_tensor* proj_out_b; // [in_channels,] + SpatialTransformer(int depth = 1) + : depth(depth) { + transformers.resize(depth); + } + + size_t get_num_tensors() { + return depth * 20 + 7; + } + size_t calculate_mem_size(ggml_type wtype) { double mem_size = 0; mem_size += 2 * in_channels * ggml_type_sizef(GGML_TYPE_F32); // norm_w/norm_b @@ -1700,7 +1878,7 @@ struct SpatialTransformer { mem_size += 1 * ggml_type_sizef(GGML_TYPE_F32); // attn_scale // transformer - { + for (auto& transformer : transformers) { mem_size += 6 * in_channels * ggml_type_sizef(GGML_TYPE_F32); // norm1-3_w/b mem_size += 6 * in_channels * in_channels * ggml_type_sizef(wtype); // attn1_q/k/v/out_w attn2_q/out_w mem_size += 2 * in_channels * context_dim * ggml_type_sizef(wtype); // attn2_k/v_w @@ -1727,34 +1905,36 @@ struct SpatialTransformer { ggml_backend_tensor_set(attn_scale, &scale, 0, sizeof(scale)); // transformer - transformer.norm1_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, in_channels); - transformer.norm1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, in_channels); + for (auto& transformer : transformers) { + transformer.norm1_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, in_channels); + transformer.norm1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, in_channels); - transformer.attn1_q_w = ggml_new_tensor_2d(ctx, wtype, in_channels, in_channels); - transformer.attn1_k_w = ggml_new_tensor_2d(ctx, wtype, in_channels, in_channels); - transformer.attn1_v_w = ggml_new_tensor_2d(ctx, wtype, in_channels, in_channels); + transformer.attn1_q_w = ggml_new_tensor_2d(ctx, wtype, in_channels, in_channels); + 
transformer.attn1_k_w = ggml_new_tensor_2d(ctx, wtype, in_channels, in_channels); + transformer.attn1_v_w = ggml_new_tensor_2d(ctx, wtype, in_channels, in_channels); - transformer.attn1_out_w = ggml_new_tensor_2d(ctx, wtype, in_channels, in_channels); - transformer.attn1_out_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, in_channels); + transformer.attn1_out_w = ggml_new_tensor_2d(ctx, wtype, in_channels, in_channels); + transformer.attn1_out_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, in_channels); - transformer.norm2_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, in_channels); - transformer.norm2_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, in_channels); + transformer.norm2_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, in_channels); + transformer.norm2_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, in_channels); - transformer.attn2_q_w = ggml_new_tensor_2d(ctx, wtype, in_channels, in_channels); - transformer.attn2_k_w = ggml_new_tensor_2d(ctx, wtype, context_dim, in_channels); - transformer.attn2_v_w = ggml_new_tensor_2d(ctx, wtype, context_dim, in_channels); + transformer.attn2_q_w = ggml_new_tensor_2d(ctx, wtype, in_channels, in_channels); + transformer.attn2_k_w = ggml_new_tensor_2d(ctx, wtype, context_dim, in_channels); + transformer.attn2_v_w = ggml_new_tensor_2d(ctx, wtype, context_dim, in_channels); - transformer.attn2_out_w = ggml_new_tensor_2d(ctx, wtype, in_channels, in_channels); - transformer.attn2_out_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, in_channels); + transformer.attn2_out_w = ggml_new_tensor_2d(ctx, wtype, in_channels, in_channels); + transformer.attn2_out_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, in_channels); - transformer.norm3_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, in_channels); - transformer.norm3_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, in_channels); + transformer.norm3_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, in_channels); + transformer.norm3_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, in_channels); - transformer.ff_0_proj_w = ggml_new_tensor_2d(ctx, wtype, in_channels, in_channels * 4 * 2); - transformer.ff_0_proj_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, in_channels * 4 * 2); + transformer.ff_0_proj_w = ggml_new_tensor_2d(ctx, wtype, in_channels, in_channels * 4 * 2); + transformer.ff_0_proj_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, in_channels * 4 * 2); - transformer.ff_2_w = ggml_new_tensor_2d(ctx, wtype, in_channels * 4, in_channels); - transformer.ff_2_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, in_channels); + transformer.ff_2_w = ggml_new_tensor_2d(ctx, wtype, in_channels * 4, in_channels); + transformer.ff_2_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, in_channels); + } } void map_by_name(std::map& tensors, const std::string prefix) { @@ -1764,8 +1944,9 @@ struct SpatialTransformer { tensors[prefix + "proj_in.bias"] = proj_in_b; // transformer - { - std::string transformer_prefix = prefix + "transformer_blocks.0."; // to admit depth > 1 this must be "transformer_blocks.%i" (SDXL) + for (int i = 0; i < transformers.size(); i++) { + auto& transformer = transformers[i]; + std::string transformer_prefix = prefix + "transformer_blocks." 
+ std::to_string(i) + "."; tensors[transformer_prefix + "attn1.to_q.weight"] = transformer.attn1_q_w; tensors[transformer_prefix + "attn1.to_k.weight"] = transformer.attn1_k_w; tensors[transformer_prefix + "attn1.to_v.weight"] = transformer.attn1_v_w; @@ -1813,7 +1994,7 @@ struct SpatialTransformer { const int64_t max_position = context->ne[1]; x = ggml_cont(ctx, ggml_permute(ctx, x, 1, 2, 0, 3)); // [N, h, w, in_channels] - { + for (auto& transformer : transformers) { auto r = x; // layer norm 1 x = ggml_reshape_2d(ctx, x, c, w * h * n); @@ -1954,6 +2135,7 @@ struct SpatialTransformer { // residual x = ggml_add(ctx, x, r); } + x = ggml_cont(ctx, ggml_permute(ctx, x, 2, 0, 1, 3)); // [N, in_channels, h, w] // proj_out @@ -2046,17 +2228,20 @@ struct UpSample { // ldm.modules.diffusionmodules.openaimodel.UNetModel struct UNetModel { + SDVersion version = VERSION_1_x; // network hparams - int in_channels = 4; - int model_channels = 320; - int out_channels = 4; - int num_res_blocks = 2; - int attention_resolutions[3] = {4, 2, 1}; - int channel_mult[4] = {1, 2, 4, 4}; - int time_embed_dim = 1280; // model_channels*4 - int num_heads = 8; - int num_head_channels = -1; // channels // num_heads - int context_dim = 768; // 1024 for VERSION_2_x.x + int in_channels = 4; + int model_channels = 320; + int out_channels = 4; + int num_res_blocks = 2; + std::vector attention_resolutions = {4, 2, 1}; + std::vector channel_mult = {1, 2, 4, 4}; + std::vector transformer_depth = {1, 1, 1, 1}; + int time_embed_dim = 1280; // model_channels*4 + int num_heads = 8; + int num_head_channels = -1; // channels // num_heads + int context_dim = 768; // 1024 for VERSION_2_x, 2048 for VERSION_XL + int adm_in_channels = 2816; // only for VERSION_XL // network params struct ggml_tensor* time_embed_0_w; // [time_embed_dim, model_channels] @@ -2065,6 +2250,12 @@ struct UNetModel { struct ggml_tensor* time_embed_2_w; // [time_embed_dim, time_embed_dim] struct ggml_tensor* time_embed_2_b; // [time_embed_dim, ] + struct ggml_tensor* label_embed_0_w; // [time_embed_dim, adm_in_channels] + struct ggml_tensor* label_embed_0_b; // [time_embed_dim, ] + // label_embed_1 is nn.SILU() + struct ggml_tensor* label_embed_2_w; // [time_embed_dim, time_embed_dim] + struct ggml_tensor* label_embed_2_b; // [time_embed_dim, ] + struct ggml_tensor* input_block_0_w; // [model_channels, in_channels, 3, 3] struct ggml_tensor* input_block_0_b; // [model_channels, ] @@ -2101,27 +2292,19 @@ struct UNetModel { ggml_type wtype; ggml_backend_t backend = NULL; - UNetModel(SDVersion version = VERSION_1_x) { - // transformer_depth size is the same of channel_mult size - // transformer_depth = {1, 1, 1, 0} - // transformer_depth[index of channel_mult] is applied to SpatialTransformer.depth var - // transformer_depth_middle = 1 default - - // adm_in_channels = -1 (none) + UNetModel(SDVersion version = VERSION_1_x) + : version(version) { if (version == VERSION_2_x) { context_dim = 1024; num_head_channels = 64; num_heads = -1; } else if (version == VERSION_XL) { - context_dim = 2048; - // attention_resolutions = {4, 2} - // channel_mult = {1, 2, 4} - // transformer_depth = {0, 2, 10} - // transformer_depth_middle = 10 - // adm_in_channels = 2816 - // requieres a Sequential phase as "time_embed": label_emb - num_head_channels = 64; - num_heads = -1; + context_dim = 2048; + attention_resolutions = {4, 2}; + channel_mult = {1, 2, 4}; + transformer_depth = {1, 2, 10}; + num_head_channels = 64; + num_heads = -1; } // set up hparams of blocks @@ -2131,7 +2314,7 @@ 
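// ---------------------------------------------------------------------------
// Where adm_in_channels = 2816 comes from (sketch; the 6-coordinate breakdown
// follows Stability's reference generative-models config, not this patch):
static int sdxl_adm_in_channels_sketch() {
    const int pooled_dim = 1280;  // projection_dim of OPEN_CLIP_VIT_BIGG_14
    const int coord_dim  = 256;   // Fourier-embedding size per conditioning scalar
    const int n_coords   = 6;     // orig_h/w, crop_top/left, target_h/w
    return pooled_dim + coord_dim * n_coords;  // = 2816
}
// ---------------------------------------------------------------------------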
struct UNetModel { int ch = model_channels; int ds = 1; - int len_mults = sizeof(channel_mult) / sizeof(int); + int len_mults = channel_mult.size(); for (int i = 0; i < len_mults; i++) { int mult = channel_mult[i]; for (int j = 0; j < num_res_blocks; j++) { @@ -2140,14 +2323,14 @@ struct UNetModel { input_res_blocks[i][j].out_channels = mult * model_channels; ch = mult * model_channels; - - if (ds == attention_resolutions[0] || ds == attention_resolutions[1] || ds == attention_resolutions[2]) { + if (std::find(attention_resolutions.begin(), attention_resolutions.end(), ds) != attention_resolutions.end()) { int n_head = num_heads; int d_head = ch / num_heads; if (num_head_channels != -1) { d_head = num_head_channels; n_head = ch / d_head; } + input_transformers[i][j] = SpatialTransformer(transformer_depth[i]); input_transformers[i][j].in_channels = ch; input_transformers[i][j].n_head = n_head; input_transformers[i][j].d_head = d_head; @@ -2175,6 +2358,7 @@ struct UNetModel { d_head = num_head_channels; n_head = ch / d_head; } + middle_block_1 = SpatialTransformer(transformer_depth[transformer_depth.size() - 1]); middle_block_1.in_channels = ch; middle_block_1.n_head = n_head; middle_block_1.d_head = d_head; @@ -2197,13 +2381,14 @@ struct UNetModel { ch = mult * model_channels; - if (ds == attention_resolutions[0] || ds == attention_resolutions[1] || ds == attention_resolutions[2]) { + if (std::find(attention_resolutions.begin(), attention_resolutions.end(), ds) != attention_resolutions.end()) { int n_head = num_heads; int d_head = ch / num_heads; if (num_head_channels != -1) { d_head = num_head_channels; n_head = ch / d_head; } + output_transformers[i][j] = SpatialTransformer(transformer_depth[i]); output_transformers[i][j].in_channels = ch; output_transformers[i][j].n_head = n_head; output_transformers[i][j].d_head = d_head; @@ -2227,16 +2412,23 @@ struct UNetModel { mem_size += time_embed_dim * time_embed_dim * ggml_type_sizef(wtype); // time_embed_2_w mem_size += time_embed_dim * ggml_type_sizef(GGML_TYPE_F32); // time_embed_2_b + if (version == VERSION_XL) { + mem_size += time_embed_dim * adm_in_channels * ggml_type_sizef(wtype); // label_embed_0_w + mem_size += time_embed_dim * ggml_type_sizef(GGML_TYPE_F32); // label_embed_0_b + mem_size += time_embed_dim * time_embed_dim * ggml_type_sizef(wtype); // label_embed_2_w + mem_size += time_embed_dim * ggml_type_sizef(GGML_TYPE_F32); // label_embed_2_b + } + mem_size += model_channels * in_channels * 3 * 3 * ggml_type_sizef(GGML_TYPE_F16); // input_block_0_w mem_size += model_channels * ggml_type_sizef(GGML_TYPE_F32); // input_block_0_b // input_blocks int ds = 1; - int len_mults = sizeof(channel_mult) / sizeof(int); + int len_mults = channel_mult.size(); for (int i = 0; i < len_mults; i++) { for (int j = 0; j < num_res_blocks; j++) { mem_size += input_res_blocks[i][j].calculate_mem_size(wtype); - if (ds == attention_resolutions[0] || ds == attention_resolutions[1] || ds == attention_resolutions[2]) { + if (std::find(attention_resolutions.begin(), attention_resolutions.end(), ds) != attention_resolutions.end()) { mem_size += input_transformers[i][j].calculate_mem_size(wtype); } } @@ -2256,7 +2448,7 @@ struct UNetModel { for (int j = 0; j < num_res_blocks + 1; j++) { mem_size += output_res_blocks[i][j].calculate_mem_size(wtype); - if (ds == attention_resolutions[0] || ds == attention_resolutions[1] || ds == attention_resolutions[2]) { + if (std::find(attention_resolutions.begin(), attention_resolutions.end(), ds) != 
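// ---------------------------------------------------------------------------
// Illustrative sketch of the downsample bookkeeping above: ds doubles after
// every level except the last, and a level gets SpatialTransformers only when
// ds appears in attention_resolutions (values shown are SDXL's).
#include <algorithm>
#include <cstdio>
#include <vector>
static void attention_levels_sketch() {
    const std::vector<int> attention_resolutions = {4, 2};
    const std::vector<int> channel_mult          = {1, 2, 4};
    int ds = 1;
    for (size_t i = 0; i < channel_mult.size(); i++) {
        bool has_attention = std::find(attention_resolutions.begin(), attention_resolutions.end(), ds) !=
                             attention_resolutions.end();
        printf("level %zu: ds = %d, attention = %s\n", i, ds, has_attention ? "yes" : "no");
        if (i != channel_mult.size() - 1) {
            ds *= 2;  // downsample between levels
        }
    }
}
// ---------------------------------------------------------------------------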
attention_resolutions.end()) { mem_size += output_transformers[i][j].calculate_mem_size(wtype); } @@ -2279,15 +2471,18 @@ struct UNetModel { int get_num_tensors() { // in int num_tensors = 6; + if (version == VERSION_XL) { + num_tensors += 4; + } // input blocks int ds = 1; - int len_mults = sizeof(channel_mult) / sizeof(int); + int len_mults = channel_mult.size(); for (int i = 0; i < len_mults; i++) { for (int j = 0; j < num_res_blocks; j++) { num_tensors += 12; - if (ds == attention_resolutions[0] || ds == attention_resolutions[1] || ds == attention_resolutions[2]) { - num_tensors += 27; + if (std::find(attention_resolutions.begin(), attention_resolutions.end(), ds) != attention_resolutions.end()) { + num_tensors += input_transformers[i][j].get_num_tensors(); } } if (i != len_mults - 1) { @@ -2297,15 +2492,16 @@ struct UNetModel { } // middle blocks - num_tensors += 13 * 3; + num_tensors += 13 * 2; + num_tensors += middle_block_1.get_num_tensors(); // output blocks for (int i = len_mults - 1; i >= 0; i--) { for (int j = 0; j < num_res_blocks + 1; j++) { num_tensors += 12; - if (ds == attention_resolutions[0] || ds == attention_resolutions[1] || ds == attention_resolutions[2]) { - num_tensors += 27; + if (std::find(attention_resolutions.begin(), attention_resolutions.end(), ds) != attention_resolutions.end()) { + num_tensors += output_transformers[i][j].get_num_tensors(); } if (i > 0 && j == num_res_blocks) { @@ -2324,16 +2520,17 @@ struct UNetModel { bool initialize(ggml_backend_t backend_, ggml_type wtype_) { backend = backend_; wtype = wtype_; - memory_buffer_size = 1 * 1024 * 1024; // 1 MB, for padding + memory_buffer_size = 10 * 1024 * 1024; // 10 MB, for padding memory_buffer_size += calculate_mem_size(); int num_tensors = get_num_tensors(); LOG_DEBUG("unet params backend buffer size = % 6.2f MB (%i tensors)", memory_buffer_size / (1024.0 * 1024.0), num_tensors); struct ggml_init_params params; - params.mem_size = static_cast(num_tensors * ggml_tensor_overhead()); + params.mem_size = static_cast(num_tensors * ggml_tensor_overhead()) + 1 * 1024 * 1024; params.mem_buffer = NULL; params.no_alloc = true; + // LOG_DEBUG("mem_size %u ", params.mem_size); ctx = ggml_init(params); if (!ctx) { @@ -2365,21 +2562,23 @@ struct UNetModel { time_embed_2_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, time_embed_dim); // SDXL - // label_embed_0_w = ggml_new_tensor_2d(ctx, wtype, time_embed_dim, adm_in_channels); - // label_embed_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, time_embed_dim); - // label_embed_2_w = ggml_new_tensor_2d(ctx, wtype, time_embed_dim, time_embed_dim); - // label_embed_2_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, time_embed_dim); + if (version == VERSION_XL) { + label_embed_0_w = ggml_new_tensor_2d(ctx, wtype, adm_in_channels, time_embed_dim); + label_embed_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, time_embed_dim); + label_embed_2_w = ggml_new_tensor_2d(ctx, wtype, time_embed_dim, time_embed_dim); + label_embed_2_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, time_embed_dim); + } // input_blocks input_block_0_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, in_channels, model_channels); input_block_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, model_channels); int ds = 1; - int len_mults = sizeof(channel_mult) / sizeof(int); + int len_mults = channel_mult.size(); for (int i = 0; i < len_mults; i++) { for (int j = 0; j < num_res_blocks; j++) { input_res_blocks[i][j].init_params(ctx, wtype); - if (ds == attention_resolutions[0] || ds == attention_resolutions[1] || ds == 
attention_resolutions[2]) { + if (std::find(attention_resolutions.begin(), attention_resolutions.end(), ds) != attention_resolutions.end()) { input_transformers[i][j].init_params(ctx, alloc, wtype); } } @@ -2399,7 +2598,7 @@ struct UNetModel { for (int j = 0; j < num_res_blocks + 1; j++) { output_res_blocks[i][j].init_params(ctx, wtype); - if (ds == attention_resolutions[0] || ds == attention_resolutions[1] || ds == attention_resolutions[2]) { + if (std::find(attention_resolutions.begin(), attention_resolutions.end(), ds) != attention_resolutions.end()) { output_transformers[i][j].init_params(ctx, alloc, wtype); } @@ -2431,22 +2630,28 @@ struct UNetModel { void map_by_name(std::map& tensors, const std::string prefix) { tensors[prefix + "time_embed.0.weight"] = time_embed_0_w; tensors[prefix + "time_embed.0.bias"] = time_embed_0_b; - tensors[prefix + "time_embed.2.weight"] = time_embed_2_w; tensors[prefix + "time_embed.2.bias"] = time_embed_2_b; + if (version == VERSION_XL) { + tensors[prefix + "label_emb.0.0.weight"] = label_embed_0_w; + tensors[prefix + "label_emb.0.0.bias"] = label_embed_0_b; + tensors[prefix + "label_emb.0.2.weight"] = label_embed_2_w; + tensors[prefix + "label_emb.0.2.bias"] = label_embed_2_b; + } + // input_blocks tensors[prefix + "input_blocks.0.0.weight"] = input_block_0_w; tensors[prefix + "input_blocks.0.0.bias"] = input_block_0_b; - int len_mults = sizeof(channel_mult) / sizeof(int); + int len_mults = channel_mult.size(); int input_block_idx = 0; int ds = 1; for (int i = 0; i < len_mults; i++) { for (int j = 0; j < num_res_blocks; j++) { input_block_idx += 1; input_res_blocks[i][j].map_by_name(tensors, prefix + "input_blocks." + std::to_string(input_block_idx) + ".0."); - if (ds == attention_resolutions[0] || ds == attention_resolutions[1] || ds == attention_resolutions[2]) { + if (std::find(attention_resolutions.begin(), attention_resolutions.end(), ds) != attention_resolutions.end()) { input_transformers[i][j].map_by_name(tensors, prefix + "input_blocks." + std::to_string(input_block_idx) + ".1."); } } @@ -2469,7 +2674,7 @@ struct UNetModel { output_res_blocks[i][j].map_by_name(tensors, prefix + "output_blocks." + std::to_string(output_block_idx) + ".0."); int up_sample_idx = 1; - if (ds == attention_resolutions[0] || ds == attention_resolutions[1] || ds == attention_resolutions[2]) { + if (std::find(attention_resolutions.begin(), attention_resolutions.end(), ds) != attention_resolutions.end()) { output_transformers[i][j].map_by_name(tensors, prefix + "output_blocks." 
+ std::to_string(output_block_idx) + ".1."); up_sample_idx++; } @@ -2494,11 +2699,13 @@ struct UNetModel { struct ggml_tensor* x, struct ggml_tensor* timesteps, struct ggml_tensor* context, - struct ggml_tensor* t_emb = NULL) { + struct ggml_tensor* t_emb = NULL, + struct ggml_tensor* y = NULL) { // x: [N, in_channels, h, w] // timesteps: [N, ] // t_emb: [N, model_channels] // context: [N, max_position, hidden_size]([N, 77, 768]) + // y: [adm_in_channels] if (t_emb == NULL && timesteps != NULL) { t_emb = new_timestep_embedding(ctx0, compute_alloc, timesteps, model_channels); // [N, model_channels] } @@ -2506,20 +2713,15 @@ struct UNetModel { // time_embed = nn.Sequential auto emb = ggml_nn_linear(ctx0, t_emb, time_embed_0_w, time_embed_0_b); emb = ggml_silu_inplace(ctx0, emb); - // Linear - emb = ggml_nn_linear(ctx0, emb, time_embed_2_w, time_embed_2_b); // [N, time_embed_dim] + emb = ggml_nn_linear(ctx0, emb, time_embed_2_w, time_embed_2_b); // [N, time_embed_dim] // SDXL - // label_emd = nn.Sequential - // Linear - // param y: an [N] Tensor of labels, if class-conditional. (clip g) - - // if(y != NULL) { - // auto y_emb = ggml_nn_linear(ctx, y, label_embed_0_w, label_embed_0_b); - // y_emb = ggml_silu_inplace(ctx, y_emb); - // y_emb = ggml_nn_linear(ctx, y_emb, label_embed_2_w, label_embed_2_b); - // emb = ggml_add(ctx, emb, y_emb); - // } + if (y != NULL) { + auto label_emb = ggml_nn_linear(ctx0, y, label_embed_0_w, label_embed_0_b); + label_emb = ggml_silu_inplace(ctx0, label_emb); + label_emb = ggml_nn_linear(ctx0, label_emb, label_embed_2_w, label_embed_2_b); + emb = ggml_add(ctx, emb, label_emb); // [N, time_embed_dim] + } // input_blocks std::vector hs; @@ -2530,13 +2732,13 @@ struct UNetModel { ggml_set_name(h, "bench-start"); hs.push_back(h); // input block 1-11 - int len_mults = sizeof(channel_mult) / sizeof(int); + int len_mults = channel_mult.size(); int ds = 1; for (int i = 0; i < len_mults; i++) { int mult = channel_mult[i]; for (int j = 0; j < num_res_blocks; j++) { h = input_res_blocks[i][j].forward(ctx0, h, emb); // [N, mult*model_channels, h, w] - if (ds == attention_resolutions[0] || ds == attention_resolutions[1] || ds == attention_resolutions[2]) { + if (std::find(attention_resolutions.begin(), attention_resolutions.end(), ds) != attention_resolutions.end()) { h = input_transformers[i][j].forward(ctx0, h, context); // [N, mult*model_channels, h, w] } hs.push_back(h); @@ -2563,7 +2765,7 @@ struct UNetModel { h = ggml_concat(ctx0, h, h_skip); h = output_res_blocks[i][j].forward(ctx0, h, emb); - if (ds == attention_resolutions[0] || ds == attention_resolutions[1] || ds == attention_resolutions[2]) { + if (std::find(attention_resolutions.begin(), attention_resolutions.end(), ds) != attention_resolutions.end()) { h = output_transformers[i][j].forward(ctx0, h, context); } @@ -2588,7 +2790,8 @@ struct UNetModel { struct ggml_cgraph* build_graph(struct ggml_tensor* x, struct ggml_tensor* timesteps, struct ggml_tensor* context, - struct ggml_tensor* t_emb = NULL) { + struct ggml_tensor* t_emb = NULL, + struct ggml_tensor* y = NULL) { // since we are using ggml-alloc, this buffer only needs enough space to hold the ggml_tensor and ggml_cgraph structs, but not the tensor data static size_t buf_size = ggml_tensor_overhead() * UNET_GRAPH_SIZE + ggml_graph_overhead(); static std::vector buf(buf_size); @@ -2598,6 +2801,7 @@ struct UNetModel { /*.mem_buffer =*/buf.data(), /*.no_alloc =*/true, // the tensors will be allocated later by ggml_allocr_alloc_graph() }; + // 
LOG_DEBUG("mem_size %u ", params.mem_size); struct ggml_context* ctx0 = ggml_init(params); @@ -2608,6 +2812,7 @@ struct UNetModel { struct ggml_tensor* timesteps_t = NULL; struct ggml_tensor* context_t = NULL; struct ggml_tensor* t_emb_t = NULL; + struct ggml_tensor* y_t = NULL; // it's performing a compute, check if backend isn't cpu if (!ggml_backend_is_cpu(backend)) { @@ -2624,6 +2829,10 @@ struct UNetModel { t_emb_t = ggml_dup_tensor(ctx0, t_emb); ggml_allocr_alloc(compute_alloc, t_emb_t); } + if (y != NULL) { + y_t = ggml_dup_tensor(ctx0, y); + ggml_allocr_alloc(compute_alloc, y_t); + } // pass data to device backend if (!ggml_allocr_is_measure(compute_alloc)) { ggml_backend_tensor_set(x_t, x->data, 0, ggml_nbytes(x)); @@ -2634,6 +2843,9 @@ struct UNetModel { if (t_emb_t != NULL) { ggml_backend_tensor_set(t_emb_t, t_emb->data, 0, ggml_nbytes(t_emb)); } + if (y != NULL) { + ggml_backend_tensor_set(y_t, y->data, 0, ggml_nbytes(y)); + } } } else { // if it's cpu backend just pass the same tensors @@ -2641,9 +2853,10 @@ struct UNetModel { timesteps_t = timesteps; context_t = context; t_emb_t = t_emb; + y_t = y; } - struct ggml_tensor* out = forward(ctx0, x_t, timesteps_t, context_t, t_emb_t); + struct ggml_tensor* out = forward(ctx0, x_t, timesteps_t, context_t, t_emb_t, y_t); ggml_build_forward_expand(gf, out); ggml_free(ctx0); @@ -2653,12 +2866,13 @@ struct UNetModel { void begin(struct ggml_tensor* x, struct ggml_tensor* context, - struct ggml_tensor* t_emb = NULL) { + struct ggml_tensor* t_emb = NULL, + struct ggml_tensor* y = NULL) { if (compute_memory_buffer_size == -1) { // alignment required by the backend compute_alloc = ggml_allocr_new_measure_from_backend(backend); - struct ggml_cgraph* gf = build_graph(x, NULL, context, t_emb); + struct ggml_cgraph* gf = build_graph(x, NULL, context, t_emb, y); // compute the required memory compute_memory_buffer_size = ggml_allocr_alloc_graph(compute_alloc, gf); @@ -2673,11 +2887,17 @@ struct UNetModel { compute_alloc = ggml_allocr_new_from_buffer(compute_buffer); } - void compute(struct ggml_tensor* work_latent, int n_threads, struct ggml_tensor* x, struct ggml_tensor* timesteps, struct ggml_tensor* context, struct ggml_tensor* t_emb = NULL) { + void compute(struct ggml_tensor* work_latent, + int n_threads, + struct ggml_tensor* x, + struct ggml_tensor* timesteps, + struct ggml_tensor* context, + struct ggml_tensor* t_emb = NULL, + struct ggml_tensor* y = NULL) { ggml_allocr_reset(compute_alloc); // compute - struct ggml_cgraph* gf = build_graph(x, timesteps, context, t_emb); + struct ggml_cgraph* gf = build_graph(x, timesteps, context, t_emb, y); ggml_allocr_alloc_graph(compute_alloc, gf); @@ -3357,6 +3577,7 @@ struct AutoEncoderKL { params.mem_size = static_cast(num_tensors * ggml_tensor_overhead()); params.mem_buffer = NULL; params.no_alloc = true; + // LOG_DEBUG("mem_size %u ", params.mem_size); params_buffer = ggml_backend_alloc_buffer(backend, memory_buffer_size); @@ -3430,7 +3651,7 @@ struct AutoEncoderKL { struct ggml_cgraph* build_graph(struct ggml_tensor* z, bool decode_graph) { // since we are using ggml-alloc, this buffer only needs enough space to hold the ggml_tensor and ggml_cgraph structs, but not the tensor data - static size_t buf_size = ggml_tensor_overhead() * GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead(); + static size_t buf_size = ggml_tensor_overhead() * UNET_GRAPH_SIZE + ggml_graph_overhead(); static std::vector buf(buf_size); struct ggml_init_params params = { @@ -3438,6 +3659,7 @@ struct AutoEncoderKL { 
            /*.mem_buffer =*/buf.data(),
            /*.no_alloc =*/true,  // the tensors will be allocated later by ggml_allocr_alloc_graph()
        };
+       // LOG_DEBUG("mem_size %u ", params.mem_size);

        struct ggml_context* ctx0 = ggml_init(params);
@@ -4001,6 +4223,7 @@ struct TinyAutoEncoder {
        params.mem_size = static_cast<size_t>(num_tensors * ggml_tensor_overhead());
        params.mem_buffer = NULL;
        params.no_alloc = true;
+       // LOG_DEBUG("mem_size %u ", params.mem_size);

        params_buffer = ggml_backend_alloc_buffer(backend, memory_buffer_size);
@@ -4121,6 +4344,7 @@ struct TinyAutoEncoder {
            /*.mem_buffer =*/buf.data(),
            /*.no_alloc =*/true,  // the tensors will be allocated later by ggml_allocr_alloc_graph()
        };
+       // LOG_DEBUG("mem_size %u ", params.mem_size);

        struct ggml_context* ctx0 = ggml_init(params);
@@ -4756,9 +4980,10 @@ struct LoraModel {
        }

        struct ggml_init_params params;
-       params.mem_size = static_cast<size_t>(1024 * ggml_tensor_overhead());
+       params.mem_size = static_cast<size_t>(LORA_GRAPH_SIZE * ggml_tensor_overhead());
        params.mem_buffer = NULL;
        params.no_alloc = true;
+       // LOG_DEBUG("mem_size %u ", params.mem_size);

        ctx = ggml_init(params);
        if (!ctx) {
@@ -4805,6 +5030,7 @@ struct LoraModel {
            /*.mem_buffer =*/buf.data(),
            /*.no_alloc =*/true,  // the tensors will be allocated later by ggml_allocr_alloc_graph()
        };
+       // LOG_DEBUG("mem_size %u ", params.mem_size);
        struct ggml_context* ctx0 = ggml_init(params);
        struct ggml_cgraph* gf = ggml_new_graph_custom(ctx0, LORA_GRAPH_SIZE, false);
@@ -5070,6 +5296,7 @@ struct CompVisVDenoiser : public Denoiser {

 class StableDiffusionGGML {
 public:
+    SDVersion version;
     bool vae_decode_only = false;
     bool free_params_immediately = false;
@@ -5123,7 +5350,7 @@ class StableDiffusionGGML {
    }

    ~StableDiffusionGGML() {
-       cond_stage_model.text_model.destroy();
+       cond_stage_model.destroy();
        diffusion_model.destroy();
        if (!use_tiny_autoencoder) {
            first_stage_model.destroy();
@@ -5171,20 +5398,17 @@ class StableDiffusionGGML {
            }
        }

-       SDVersion version = model_loader.get_sd_version();
+       version = model_loader.get_sd_version();
        if (version == VERSION_COUNT) {
            LOG_ERROR("get sd version from file failed: '%s'", model_path.c_str());
            return false;
        }
-       if (clip_skip <= 0) {
-           clip_skip = 1;
-           if (version == VERSION_2_x) {
-               clip_skip = 2;
-           }
+       if (version == VERSION_XL) {
+           scale_factor = 0.13025f;
        }
-       cond_stage_model = FrozenCLIPEmbedderWithCustomWords(version);
-       cond_stage_model.text_model.clip_skip = clip_skip;
-       diffusion_model = UNetModel(version);
+       cond_stage_model = FrozenCLIPEmbedderWithCustomWords(version, clip_skip);
+       diffusion_model = UNetModel(version);
+       LOG_INFO("Stable Diffusion %s ", model_version_to_str[version]);

        if (wtype == GGML_TYPE_COUNT) {
            model_data_type = model_loader.get_sd_wtype();
@@ -5206,12 +5430,17 @@ class StableDiffusionGGML {
        LOG_DEBUG("ggml tensor size = %d bytes", (int)sizeof(ggml_tensor));

        if (
-           !cond_stage_model.text_model.initialize(backend, model_data_type) ||
+           !cond_stage_model.initialize(backend, model_data_type) ||
            !diffusion_model.initialize(backend, model_data_type)) {
            return false;
        }

-       if (!use_tiny_autoencoder && !first_stage_model.initialize(backend, model_data_type)) {
+       ggml_type vae_type = model_data_type;
+       if (version == VERSION_XL) {
+           vae_type = GGML_TYPE_F32;  // try to avoid nan; does not fully work yet...
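+           // (the NaNs are most likely f16 overflow inside the SDXL VAE; keeping
+           // its weights in f32 reduces, but does not eliminate, the problem)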
+       }
+
+       if (!use_tiny_autoencoder && !first_stage_model.initialize(backend, vae_type)) {
            return false;
        }
@@ -5219,8 +5448,8 @@ class StableDiffusionGGML {
        // prepare memory for the weights
        {
            // cond_stage_model(FrozenCLIPEmbedder)
-           cond_stage_model.text_model.alloc_params();
-           cond_stage_model.text_model.map_by_name(tensors, "cond_stage_model.transformer.text_model.");
+           cond_stage_model.alloc_params();
+           cond_stage_model.map_by_name(tensors, "cond_stage_model.");

            // diffusion_model(UNetModel)
            diffusion_model.alloc_params();
@@ -5234,9 +5463,10 @@ class StableDiffusionGGML {
        }

        struct ggml_init_params params;
-       params.mem_size = static_cast<size_t>(10 * 1024) * 1024;  // 10M
-       params.mem_buffer = NULL;
-       params.no_alloc = false;
+       params.mem_size = static_cast<size_t>(10 * 1024) * 1024;  // 10M
+       params.mem_buffer = NULL;
+       params.no_alloc = false;
+       // LOG_DEBUG("mem_size %u ", params.mem_size);
        struct ggml_context* ctx = ggml_init(params);  // for alphas_cumprod and is_using_v_parameterization check
        if (!ctx) {
            LOG_ERROR("ggml_init() failed");
@@ -5337,12 +5567,12 @@ class StableDiffusionGGML {
        LOG_DEBUG("model size = %.2fMB", total_size / 1024.0 / 1024.0);
        size_t total_params_size =
-           cond_stage_model.text_model.memory_buffer_size +
+           cond_stage_model.memory_buffer_size +
            diffusion_model.memory_buffer_size +
            first_stage_model.memory_buffer_size;
        LOG_INFO("total memory buffer size = %.2fMB (clip %.2fMB, unet %.2fMB, vae %.2fMB)",
                 total_params_size / 1024.0 / 1024.0,
-                cond_stage_model.text_model.memory_buffer_size / 1024.0 / 1024.0,
+                cond_stage_model.memory_buffer_size / 1024.0 / 1024.0,
                 diffusion_model.memory_buffer_size / 1024.0 / 1024.0,
                 first_stage_model.memory_buffer_size / 1024.0 / 1024.0);
        int64_t t1 = ggml_time_ms();
@@ -5493,16 +5723,21 @@ class StableDiffusionGGML {
        curr_lora_state = lora_state;
    }

-   ggml_tensor* get_learned_condition(ggml_context* work_ctx, const std::string& text) {
-       auto tokens_and_weights = cond_stage_model.tokenize(text,
-                                                           cond_stage_model.text_model.max_position_embeddings,
-                                                           true);
+   std::pair<ggml_tensor*, ggml_tensor*> get_learned_condition(ggml_context* work_ctx, const std::string& text, int width, int height, bool force_zero_embeddings = false) {
+       auto tokens_and_weights = cond_stage_model.tokenize(text, true);
        std::vector<int>& tokens = tokens_and_weights.first;
        std::vector<float>& weights = tokens_and_weights.second;
        int64_t t0 = ggml_time_ms();
-       cond_stage_model.text_model.begin(work_ctx, (int)tokens.size());
-       struct ggml_tensor* hidden_states = cond_stage_model.text_model.compute(n_threads, tokens);  // [N, n_token, hidden_size]
-       cond_stage_model.text_model.end();
+       cond_stage_model.begin(work_ctx, (int)tokens.size());
+       auto result_pair = cond_stage_model.compute(n_threads, tokens);  // [N, n_token, hidden_size]
+       struct ggml_tensor* hidden_states = result_pair.first;
+       struct ggml_tensor* pooled = result_pair.second;
+       // if (pooled != NULL) {
+       //     print_ggml_tensor(hidden_states);
+       //     print_ggml_tensor(pooled);
+       // }
+
+       cond_stage_model.end();
        int64_t t1 = ggml_time_ms();
        LOG_DEBUG("computing condition graph completed, taking %" PRId64 " ms", t1 - t0);
        ggml_tensor* result = ggml_dup_tensor(work_ctx, hidden_states);
@@ -5520,14 +5755,63 @@ class StableDiffusionGGML {
            float new_mean = ggml_tensor_mean(result);
            ggml_tensor_scale(result, (original_mean / new_mean));
        }
-       return result;  // [1, 77, 768]
+       if (force_zero_embeddings) {
+           float* vec = (float*)result->data;
+           for (int i = 0; i < ggml_nelements(result); i++) {
+               vec[i] = 0;
+           }
+       }
+
+       ggml_tensor* vec = NULL;
+       if (version == VERSION_XL) {
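+           // layout of the SDXL conditioning vector y (adm_in_channels = 2816):
+           //   [0,    1280)  pooled OpenCLIP-G prompt embedding
+           //   [1280, 1792)  256-d timestep embeddings of original_size (h, w)
+           //   [1792, 2304)  256-d timestep embeddings of crop_coords (top, left)
+           //   [2304, 2816)  256-d timestep embeddings of target_size (h, w)
+           // i.e. 1280 + 3 * 2 * 256 = 2816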
+           size_t out_dim = 256;
+           vec = ggml_new_tensor_1d(work_ctx, GGML_TYPE_F32, diffusion_model.adm_in_channels);
+           // [0:1280]
+           size_t offset = 0;
+           memcpy(vec->data, pooled->data, ggml_nbytes(pooled));
+           offset += ggml_nbytes(pooled);
+
+           struct ggml_tensor* timesteps = ggml_new_tensor_1d(work_ctx, GGML_TYPE_F32, 2);
+           // original_size_as_tuple
+           float orig_width = (float)width;
+           float orig_height = (float)height;
+           ggml_tensor_set_f32(timesteps, orig_height, 0);
+           ggml_tensor_set_f32(timesteps, orig_width, 1);
+           ggml_tensor* embed_view = ggml_view_2d(work_ctx, vec, out_dim, 2, ggml_type_size(GGML_TYPE_F32) * out_dim, offset);
+           offset += ggml_nbytes(embed_view);
+           set_timestep_embedding(timesteps, embed_view, out_dim);
+           // print_ggml_tensor(ggml_reshape_1d(work_ctx, embed_view, out_dim * 2));
+           // crop_coords_top_left
+           float crop_coord_top = 0.f;
+           float crop_coord_left = 0.f;
+           ggml_tensor_set_f32(timesteps, crop_coord_top, 0);
+           ggml_tensor_set_f32(timesteps, crop_coord_left, 1);
+           embed_view = ggml_view_2d(work_ctx, vec, out_dim, 2, ggml_type_size(GGML_TYPE_F32) * out_dim, offset);
+           offset += ggml_nbytes(embed_view);
+           set_timestep_embedding(timesteps, embed_view, out_dim);
+           // print_ggml_tensor(ggml_reshape_1d(work_ctx, embed_view, out_dim * 2));
+           // target_size_as_tuple
+           float target_width = (float)width;
+           float target_height = (float)height;
+           ggml_tensor_set_f32(timesteps, target_height, 0);
+           ggml_tensor_set_f32(timesteps, target_width, 1);
+           embed_view = ggml_view_2d(work_ctx, vec, out_dim, 2, ggml_type_size(GGML_TYPE_F32) * out_dim, offset);
+           offset += ggml_nbytes(embed_view);
+           set_timestep_embedding(timesteps, embed_view, out_dim);
+           // print_ggml_tensor(ggml_reshape_1d(work_ctx, embed_view, out_dim * 2));
+           GGML_ASSERT(offset == ggml_nbytes(vec));
+       }
+       // print_ggml_tensor(result);
+       return {result, vec};
    }

    ggml_tensor* sample(ggml_context* work_ctx,
                        ggml_tensor* x_t,
                        ggml_tensor* noise,
                        ggml_tensor* c,
+                       ggml_tensor* c_vector,
                        ggml_tensor* uc,
+                       ggml_tensor* uc_vector,
                        float cfg_scale,
                        SampleMethod method,
                        const std::vector<float>& sigmas) {
@@ -5540,7 +5824,7 @@ class StableDiffusionGGML {
        struct ggml_tensor* noised_input = ggml_dup_tensor(work_ctx, x_t);
        struct ggml_tensor* timesteps = ggml_new_tensor_1d(work_ctx, GGML_TYPE_F32, 1);  // [N, ]
        struct ggml_tensor* t_emb = new_timestep_embedding(work_ctx, NULL, timesteps, diffusion_model.model_channels);  // [N, model_channels]
-       diffusion_model.begin(noised_input, c, t_emb);
+       diffusion_model.begin(noised_input, c, t_emb, c_vector);

        bool has_unconditioned = cfg_scale != 1.0 && uc != NULL;
@@ -5590,12 +5874,12 @@ class StableDiffusionGGML {
            ggml_tensor_scale(noised_input, c_in);

            // cond
-           diffusion_model.compute(out_cond, n_threads, noised_input, NULL, c, t_emb);
+           diffusion_model.compute(out_cond, n_threads, noised_input, NULL, c, t_emb, c_vector);

            float* negative_data = NULL;
            if (has_unconditioned) {
                // uncond
-               diffusion_model.compute(out_uncond, n_threads, noised_input, NULL, uc, t_emb);
+               diffusion_model.compute(out_uncond, n_threads, noised_input, NULL, uc, t_emb, uc_vector);
                negative_data = (float*)out_uncond->data;
            }
            float* vec_denoised = (float*)denoised->data;
@@ -6181,10 +6465,10 @@ std::vector<uint8_t*> StableDiffusion::txt2img(std::string prompt,
                                               int64_t seed,
                                               int batch_count) {
    std::vector<uint8_t*> results;
-   if (width >= 1024 && height >= 1024) {  // 1024 x 1024 images
-       LOG_WARN("Image too large, try a smaller size.");
-       return results;
-   }
+   // if (width >= 1024 && height >= 1024) {  // 1024 x 1024 images
+   //     LOG_WARN("Image too large, try a smaller size.");
+   //     return results;
+   // }
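+   // the hard size guard is only commented out, not removed: SDXL targets
+   // 1024x1024, and the larger graphs are now covered by the UNET_GRAPH_SIZE-based
+   // buffers above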
    // extract and remove lora
    auto result_pair = extract_and_remove_lora(prompt);
    std::unordered_map<std::string, float> lora_f2m = result_pair.first;  // lora_name -> multiplier
@@ -6201,11 +6485,12 @@ std::vector<uint8_t*> StableDiffusion::txt2img(std::string prompt,
    int64_t t1 = ggml_time_ms();
    LOG_INFO("apply_loras completed, taking %.2fs", (t1 - t0) * 1.0f / 1000);

    struct ggml_init_params params;
-   params.mem_size = static_cast<size_t>(2 * 1024 * 1024);  // 2 MB
+   params.mem_size = static_cast<size_t>(10 * 1024 * 1024);  // 10 MB
    params.mem_size += width * height * 3 * sizeof(float);
    params.mem_size *= batch_count;
    params.mem_buffer = NULL;
    params.no_alloc = false;
+   // LOG_DEBUG("mem_size %u ", params.mem_size);

    struct ggml_context* work_ctx = ggml_init(params);
    if (!work_ctx) {
@@ -6221,17 +6506,26 @@ std::vector<uint8_t*> StableDiffusion::txt2img(std::string prompt,
        seed = rand();
    }

-   t0 = ggml_time_ms();
-   ggml_tensor* c = sd->get_learned_condition(work_ctx, prompt);
-   struct ggml_tensor* uc = NULL;
+   t0 = ggml_time_ms();
+   auto cond_pair = sd->get_learned_condition(work_ctx, prompt, width, height);
+   ggml_tensor* c = cond_pair.first;
+   ggml_tensor* c_vector = cond_pair.second;  // [adm_in_channels, ]
+   struct ggml_tensor* uc = NULL;
+   struct ggml_tensor* uc_vector = NULL;
    if (cfg_scale != 1.0) {
-       uc = sd->get_learned_condition(work_ctx, negative_prompt);
+       bool force_zero_embeddings = false;
+       if (sd->version == VERSION_XL && negative_prompt.size() == 0) {
+           force_zero_embeddings = true;
+       }
+       auto uncond_pair = sd->get_learned_condition(work_ctx, negative_prompt, width, height, force_zero_embeddings);
+       uc = uncond_pair.first;
+       uc_vector = uncond_pair.second;  // [adm_in_channels, ]
    }
    t1 = ggml_time_ms();
    LOG_INFO("get_learned_condition completed, taking %" PRId64 " ms", t1 - t0);

    if (sd->free_params_immediately) {
-       sd->cond_stage_model.text_model.destroy();
+       sd->cond_stage_model.destroy();
    }

    std::vector<struct ggml_tensor*> final_latents;  // collect latents to decode
@@ -6250,7 +6544,7 @@ std::vector<uint8_t*> StableDiffusion::txt2img(std::string prompt,

        std::vector<float> sigmas = sd->denoiser->schedule->get_sigmas(sample_steps);

-       struct ggml_tensor* x_0 = sd->sample(work_ctx, x_t, NULL, c, uc, cfg_scale, sample_method, sigmas);
+       struct ggml_tensor* x_0 = sd->sample(work_ctx, x_t, NULL, c, c_vector, uc, uc_vector, cfg_scale, sample_method, sigmas);
        // struct ggml_tensor* x_0 = load_tensor_from_file(ctx, "samples_ddim.bin");
        // print_ggml_tensor(x_0);
        int64_t sampling_end = ggml_time_ms();
@@ -6269,6 +6563,7 @@ std::vector<uint8_t*> StableDiffusion::txt2img(std::string prompt,
    for (size_t i = 0; i < final_latents.size(); i++) {
        t1 = ggml_time_ms();
        struct ggml_tensor* img = sd->decode_first_stage(work_ctx, final_latents[i] /* x_0 */);
+       // print_ggml_tensor(img);
        if (img != NULL) {
            decoded_images.push_back(img);
        }
@@ -6323,6 +6618,7 @@ std::vector<uint8_t*> StableDiffusion::img2img(const uint8_t* init_img_data,
    params.mem_size += width * height * 3 * sizeof(float) * 2;
    params.mem_buffer = NULL;
    params.no_alloc = false;
+   // LOG_DEBUG("mem_size %u ", params.mem_size);

    // draft context
    struct ggml_context* work_ctx = ggml_init(params);
@@ -6366,15 +6662,24 @@ std::vector<uint8_t*> StableDiffusion::img2img(const uint8_t* init_img_data,
    t1 = ggml_time_ms();
    LOG_INFO("encode_first_stage completed, taking %.2fs", (t1 - t0) * 1.0f / 1000);

-   ggml_tensor* c = sd->get_learned_condition(work_ctx, prompt);
-   struct ggml_tensor* uc = NULL;
+   auto cond_pair = sd->get_learned_condition(work_ctx, prompt, width, height);
+   ggml_tensor* c = cond_pair.first;
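+   // cond_pair.second carries the SDXL conditioning vector (NULL for SD1.x/SD2.x)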
+   ggml_tensor* c_vector = cond_pair.second;  // [adm_in_channels, ]
+   struct ggml_tensor* uc = NULL;
+   struct ggml_tensor* uc_vector = NULL;
    if (cfg_scale != 1.0) {
-       uc = sd->get_learned_condition(work_ctx, negative_prompt);
+       bool force_zero_embeddings = false;
+       if (sd->version == VERSION_XL && negative_prompt.size() == 0) {
+           force_zero_embeddings = true;
+       }
+       auto uncond_pair = sd->get_learned_condition(work_ctx, negative_prompt, width, height, force_zero_embeddings);
+       uc = uncond_pair.first;
+       uc_vector = uncond_pair.second;  // [adm_in_channels, ]
    }
    int64_t t2 = ggml_time_ms();
    LOG_INFO("get_learned_condition completed, taking %" PRId64 " ms", t2 - t1);

    if (sd->free_params_immediately) {
-       sd->cond_stage_model.text_model.destroy();
+       sd->cond_stage_model.destroy();
    }

    // SDXL
@@ -6386,7 +6691,7 @@ std::vector<uint8_t*> StableDiffusion::img2img(const uint8_t* init_img_data,
    ggml_tensor_set_f32_randn(noise, sd->rng);

    LOG_INFO("sampling using %s method", sampling_methods_str[sample_method]);
-   struct ggml_tensor* x_0 = sd->sample(work_ctx, init_latent, noise, c, uc, cfg_scale, sample_method, sigma_sched);
+   struct ggml_tensor* x_0 = sd->sample(work_ctx, init_latent, noise, c, c_vector, uc, uc_vector, cfg_scale, sample_method, sigma_sched);
    // struct ggml_tensor *x_0 = load_tensor_from_file(ctx, "samples_ddim.bin");
    // print_ggml_tensor(x_0);
    int64_t t3 = ggml_time_ms();
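+   // sigma_sched is presumably the tail of the full sigma schedule (built just
+   // above, outside this hunk), so img2img only runs the last
+   // strength * sample_steps denoising steps starting from the noised init latent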