diff --git a/.gitignore b/.gitignore index 31f89216..38fe570d 100644 --- a/.gitignore +++ b/.gitignore @@ -10,5 +10,4 @@ test/ *.gguf output*.png models* -!taesd-model.gguf *.log \ No newline at end of file diff --git a/CMakeLists.txt b/CMakeLists.txt index d7522d5a..8545ef6e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -28,7 +28,6 @@ option(SD_CUBLAS "sd: cuda backend" OFF) option(SD_HIPBLAS "sd: rocm backend" OFF) option(SD_METAL "sd: metal backend" OFF) option(SD_FLASH_ATTN "sd: use flash attention for x4 less memory usage" OFF) -option(SD_FAST_SOFTMAX "sd: x1.5 faster softmax, indeterministic (sometimes, same seed don't generate same image), cuda only" OFF) option(BUILD_SHARED_LIBS "sd: build shared libs" OFF) #option(SD_BUILD_SERVER "sd: build server example" ON) @@ -36,9 +35,6 @@ if(SD_CUBLAS) message("Use CUBLAS as backend stable-diffusion") set(GGML_CUBLAS ON) add_definitions(-DSD_USE_CUBLAS) - if(SD_FAST_SOFTMAX) - set(GGML_CUDA_FAST_SOFTMAX ON) - endif() endif() if(SD_METAL) diff --git a/README.md b/README.md index ff21c620..c1675aa1 100644 --- a/README.md +++ b/README.md @@ -31,6 +31,7 @@ Inference of [Stable Diffusion](https://github.com/CompVis/stable-diffusion) in - Faster and memory efficient latent decoding with [TAESD](https://github.com/madebyollin/taesd) - Upscale images generated with [ESRGAN](https://github.com/xinntao/Real-ESRGAN) - VAE tiling processing for reduce memory usage +- Control Net support with SD 1.5 - Sampling method - `Euler A` - `Euler` @@ -53,9 +54,7 @@ Inference of [Stable Diffusion](https://github.com/CompVis/stable-diffusion) in - [ ] More sampling methods - [ ] Make inference faster - The current implementation of ggml_conv_2d is slow and has high memory usage - - Implement Winograd Convolution 2D for 3x3 kernel filtering - [ ] Continuing to reduce memory usage (quantizing the weights of ggml_conv_2d) -- [ ] Implement Textual Inversion (embeddings) - [ ] Implement Inpainting support - [ ] k-quants support @@ -159,16 +158,20 @@ arguments: -m, --model [MODEL] path to model --vae [VAE] path to vae --taesd [TAESD_PATH] path to taesd. Using Tiny AutoEncoder for fast decoding (low quality) + --control-net [CONTROL_PATH] path to control net model + --embd-dir [EMBEDDING_PATH] path to embeddings. --upscale-model [ESRGAN_PATH] path to esrgan model. Upscale images after generate, just RealESRGAN_x4plus_anime_6B supported by now. --type [TYPE] weight type (f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0) If not specified, the default is the type of the weight file. 
--lora-model-dir [DIR] lora model directory -i, --init-img [IMAGE] path to the input image, required by img2img + --control-image [IMAGE] path to image condition, control net -o, --output OUTPUT path to write result image to (default: ./output.png) -p, --prompt [PROMPT] the prompt to render -n, --negative-prompt PROMPT the negative prompt (default: "") --cfg-scale SCALE unconditional guidance scale: (default: 7.0) --strength STRENGTH strength for noising/unnoising (default: 0.75) + --control-strength STRENGTH strength to apply Control Net (default: 0.9) 1.0 corresponds to full destruction of information in init image -H, --height H image height, in pixel space (default: 512) -W, --width W image width, in pixel space (default: 512) @@ -182,6 +185,7 @@ arguments: --clip-skip N ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer (default: -1) <= 0 represents unspecified, will be 1 for SD1.x, 2 for SD2.x --vae-tiling process vae in tiles to reduce memory usage + --control-net-cpu keep controlnet in cpu (for low vram) -v, --verbose print extra info ``` diff --git a/assets/control.png b/assets/control.png new file mode 100644 index 00000000..3ed95d09 Binary files /dev/null and b/assets/control.png differ diff --git a/assets/control_2.png b/assets/control_2.png new file mode 100644 index 00000000..9352dc0f Binary files /dev/null and b/assets/control_2.png differ diff --git a/assets/control_3.png b/assets/control_3.png new file mode 100644 index 00000000..4d114df0 Binary files /dev/null and b/assets/control_3.png differ diff --git a/clip.hpp b/clip.hpp index a456fffc..e0451099 100644 --- a/clip.hpp +++ b/clip.hpp @@ -2,6 +2,7 @@ #define __CLIP_HPP__ #include "ggml_extend.hpp" +#include "model.h" /*================================================== CLIPTokenizer ===================================================*/ @@ -67,6 +68,9 @@ std::vector> bytes_to_unicode() { } // Ref: https://github.com/openai/CLIP/blob/main/clip/simple_tokenizer.py + +typedef std::function&)> on_new_token_cb_t; + class CLIPTokenizer { private: SDVersion version = VERSION_1_x; @@ -234,8 +238,11 @@ class CLIPTokenizer { return result; } - std::vector tokenize(std::string text, size_t max_length = 0, bool padding = false) { - std::vector tokens = encode(text); + std::vector tokenize(std::string text, + on_new_token_cb_t on_new_token_cb, + size_t max_length = 0, + bool padding = false) { + std::vector tokens = encode(text, on_new_token_cb); tokens.insert(tokens.begin(), BOS_TOKEN_ID); if (max_length > 0) { if (tokens.size() > max_length - 1) { @@ -255,7 +262,7 @@ class CLIPTokenizer { return tokens; } - std::vector encode(std::string text) { + std::vector encode(std::string text, on_new_token_cb_t on_new_token_cb) { std::string original_text = text; std::vector bpe_tokens; text = whitespace_clean(text); @@ -268,6 +275,10 @@ class CLIPTokenizer { std::string str = text; std::vector token_strs; while (std::regex_search(str, matches, pat)) { + bool skip = on_new_token_cb(str, bpe_tokens); + if (skip) { + continue; + } for (auto& token : matches) { std::string token_str = token.str(); std::u32string utf32_token; @@ -444,13 +455,13 @@ struct ResidualAttentionBlock { struct ggml_tensor* ln2_b; // [hidden_size, ] size_t calculate_mem_size(ggml_type wtype) { - double mem_size = 0; - mem_size += 4 * hidden_size * hidden_size * ggml_type_sizef(wtype); // q_w/k_w/v_w/out_w - mem_size += 8 * hidden_size * ggml_type_sizef(GGML_TYPE_F32); // q_b/k_b/v_b/out_b/ln1_w/ln1_b/ln2_w/ln2_b - mem_size += 2 * hidden_size * 
intermediate_size * ggml_type_sizef(wtype); // fc1_w/fc2_w - mem_size += intermediate_size * ggml_type_sizef(GGML_TYPE_F32); // fc1_b - mem_size += hidden_size * ggml_type_sizef(GGML_TYPE_F32); // fc2_b - return static_cast(mem_size); + size_t mem_size = 0; + mem_size += 4 * ggml_row_size(wtype, hidden_size * hidden_size); // q_w/k_w/v_w/out_w + mem_size += 8 * ggml_row_size(GGML_TYPE_F32, hidden_size); // q_b/k_b/v_b/out_b/ln1_w/ln1_b/ln2_w/ln2_b + mem_size += 2 * ggml_row_size(wtype, hidden_size * intermediate_size); // fc1_w/fc2_w + mem_size += ggml_row_size(GGML_TYPE_F32, intermediate_size); // fc1_b + mem_size += ggml_row_size(GGML_TYPE_F32, hidden_size); // fc2_b + return mem_size; } void init_params(struct ggml_context* ctx, ggml_allocr* alloc, ggml_type wtype) { @@ -597,6 +608,7 @@ struct CLIPTextModel { struct ggml_tensor* position_ids; struct ggml_tensor* token_embed_weight; struct ggml_tensor* position_embed_weight; + struct ggml_tensor* token_embed_custom; // transformer std::vector resblocks; @@ -604,6 +616,9 @@ struct CLIPTextModel { struct ggml_tensor* final_ln_b; struct ggml_tensor* text_projection; + std::string embd_dir; + int32_t num_custom_embeddings = 0; + std::vector readed_embeddings; CLIPTextModel(CLIPVersion version = OPENAI_CLIP_VIT_L_14, int clip_skip = -1, @@ -642,18 +657,21 @@ struct CLIPTextModel { } size_t calculate_mem_size(ggml_type wtype) { - double mem_size = 0; - mem_size += hidden_size * max_position_embeddings * ggml_type_sizef(GGML_TYPE_I32); // position_ids - mem_size += hidden_size * vocab_size * ggml_type_sizef(wtype); // token_embed_weight - mem_size += hidden_size * max_position_embeddings * ggml_type_sizef(wtype); // position_embed_weight + size_t mem_size = 0; + mem_size += ggml_row_size(GGML_TYPE_I32, hidden_size * max_position_embeddings); // position_ids + mem_size += ggml_row_size(wtype, hidden_size * vocab_size); // token_embed_weight + mem_size += ggml_row_size(wtype, hidden_size * max_position_embeddings); // position_embed_weight + if(version == OPENAI_CLIP_VIT_L_14) { + mem_size += ggml_row_size(wtype, hidden_size * max_position_embeddings); // token_embed_custom + } for (int i = 0; i < num_hidden_layers; i++) { mem_size += resblocks[i].calculate_mem_size(wtype); } - mem_size += 2 * hidden_size * ggml_type_sizef(GGML_TYPE_F32); // final_ln_w/b + mem_size += 2 * ggml_row_size(GGML_TYPE_F32, hidden_size); // final_ln_w/b if (version == OPEN_CLIP_VIT_BIGG_14) { - mem_size += hidden_size * projection_dim * ggml_type_sizef(GGML_TYPE_F32); // text_projection + mem_size += ggml_row_size(GGML_TYPE_F32, hidden_size * projection_dim); // text_projection } - return static_cast(mem_size); + return mem_size; } void map_by_name(std::map& tensors, const std::string prefix) { @@ -670,14 +688,48 @@ struct CLIPTextModel { } } - struct ggml_tensor* forward(struct ggml_context* ctx0, struct ggml_tensor* input_ids, size_t max_token_idx = 0, bool return_pooled = false) { + bool load_embedding(std::string embd_name, std::string embd_path, std::vector &bpe_tokens) { + // the order matters + ModelLoader model_loader; + if(!model_loader.init_from_file(embd_path)) { + LOG_ERROR("embedding '%s' failed", embd_name.c_str()); + return false; + } + struct ggml_init_params params; + params.mem_size = 32 * 1024; // max for custom embeddings 32 KB + params.mem_buffer = NULL; + params.no_alloc = false; + struct ggml_context* embd_ctx = ggml_init(params); + struct ggml_tensor* embd = NULL; + auto on_load = [&](const TensorStorage& tensor_storage, ggml_tensor** dst_tensor) 
{ + if(tensor_storage.ne[0] != hidden_size) { + LOG_DEBUG("embedding wrong hidden size, got %i, expected %i", tensor_storage.ne[0], hidden_size); + return false; + } + embd = ggml_new_tensor_2d(embd_ctx, token_embed_weight->type, hidden_size, tensor_storage.n_dims > 1 ? tensor_storage.ne[1] : 1); + *dst_tensor = embd; + return true; + }; + model_loader.load_tensors(on_load, NULL); + ggml_backend_tensor_set(token_embed_custom, embd->data, num_custom_embeddings * hidden_size * ggml_type_size(token_embed_custom->type), ggml_nbytes(embd)); + readed_embeddings.push_back(embd_name); + for(int i = 0; i < embd->ne[1]; i++) { + bpe_tokens.push_back(vocab_size + num_custom_embeddings); + // LOG_DEBUG("new custom token: %i", vocab_size + num_custom_embeddings); + num_custom_embeddings++; + } + LOG_DEBUG("embedding '%s' applied, custom embeddings: %i", embd_name.c_str(), num_custom_embeddings); + return true; + } + + struct ggml_tensor* forward(struct ggml_context* ctx0, struct ggml_tensor* input_ids, struct ggml_tensor* tkn_embeddings, uint32_t max_token_idx = 0, bool return_pooled = false) { // input_ids: [N, n_token] GGML_ASSERT(input_ids->ne[0] <= position_ids->ne[0]); // token_embedding + position_embedding struct ggml_tensor* x; x = ggml_add(ctx0, - ggml_get_rows(ctx0, token_embed_weight, input_ids), + ggml_get_rows(ctx0, tkn_embeddings == NULL ? token_embed_weight : tkn_embeddings, input_ids), ggml_get_rows(ctx0, position_embed_weight, ggml_view_1d(ctx0, position_ids, input_ids->ne[0], 0))); // [N, n_token, hidden_size] @@ -723,6 +775,10 @@ struct CLIPTextModel { final_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hidden_size); + if(version == OPENAI_CLIP_VIT_L_14) { + token_embed_custom = ggml_new_tensor_2d(ctx, wtype, hidden_size, max_position_embeddings); + } + if (version == OPEN_CLIP_VIT_BIGG_14) { text_projection = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, projection_dim, hidden_size); } @@ -805,11 +861,11 @@ struct FrozenCLIPEmbedderWithCustomWords : public GGMLModule { } } - struct ggml_tensor* forward(struct ggml_context* ctx0, struct ggml_tensor* input_ids, struct ggml_tensor* input_ids2, size_t max_token_idx = 0, bool return_pooled = false) { + struct ggml_tensor* forward(struct ggml_context* ctx0, struct ggml_tensor* input_ids, struct ggml_tensor* input_ids2, struct ggml_tensor* embeddings, size_t max_token_idx = 0, bool return_pooled = false) { if (return_pooled) { - return text_model2.forward(ctx0, input_ids2, max_token_idx, return_pooled); + return text_model2.forward(ctx0, input_ids2, NULL, max_token_idx, return_pooled); } - auto hidden_states = text_model.forward(ctx0, input_ids); // [N, n_token, hidden_size] + auto hidden_states = text_model.forward(ctx0, input_ids, embeddings); // [N, n_token, hidden_size] // LOG_DEBUG("hidden_states: %d %d %d %d %d", hidden_states->n_dims, hidden_states->ne[0], hidden_states->ne[1], hidden_states->ne[2], hidden_states->ne[3]); if (version == VERSION_XL) { hidden_states = ggml_reshape_4d(ctx0, @@ -820,7 +876,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public GGMLModule { hidden_states->ne[3]); hidden_states = ggml_cont(ctx0, ggml_permute(ctx0, hidden_states, 2, 0, 1, 3)); - auto hidden_states2 = text_model2.forward(ctx0, input_ids2); // [N, n_token, hidden_size2] + auto hidden_states2 = text_model2.forward(ctx0, input_ids2, NULL); // [N, n_token, hidden_size2] hidden_states2 = ggml_reshape_4d(ctx0, hidden_states2, hidden_states2->ne[0], @@ -857,12 +913,36 @@ struct FrozenCLIPEmbedderWithCustomWords : public GGMLModule { LOG_DEBUG("parse '%s' 
to %s", text.c_str(), ss.str().c_str()); } + auto on_new_token_cb = [&] (std::string& str, std::vector &bpe_tokens) -> bool { + size_t word_end = str.find(","); + std::string embd_name = word_end == std::string::npos ? str : str.substr(0, word_end); + embd_name = trim(embd_name); + std::string embd_path = get_full_path(text_model.embd_dir, embd_name + ".pt"); + if(embd_path.size() == 0) { + embd_path = get_full_path(text_model.embd_dir, embd_name + ".ckpt"); + } + if(embd_path.size() == 0) { + embd_path = get_full_path(text_model.embd_dir, embd_name + ".safetensors"); + } + if(embd_path.size() > 0) { + if(text_model.load_embedding(embd_name, embd_path, bpe_tokens)) { + if(word_end != std::string::npos) { + str = str.substr(word_end); + } else { + str = ""; + } + return true; + } + } + return false; + }; + std::vector tokens; std::vector weights; for (const auto& item : parsed_attention) { const std::string& curr_text = item.first; float curr_weight = item.second; - std::vector curr_tokens = tokenizer.encode(curr_text); + std::vector curr_tokens = tokenizer.encode(curr_text, on_new_token_cb); tokens.insert(tokens.end(), curr_tokens.begin(), curr_tokens.end()); weights.insert(weights.end(), curr_tokens.size(), curr_weight); } @@ -951,7 +1031,26 @@ struct FrozenCLIPEmbedderWithCustomWords : public GGMLModule { } } - struct ggml_tensor* hidden_states = forward(ctx0, input_ids, input_ids2, max_token_idx, return_pooled); + struct ggml_tensor* embeddings = NULL; + + if(text_model.num_custom_embeddings > 0 && version != VERSION_XL) { + embeddings = ggml_new_tensor_2d(ctx0, wtype, text_model.hidden_size, text_model.vocab_size + text_model.num_custom_embeddings /* custom placeholder */); + ggml_allocr_alloc(allocr, embeddings); + if (!ggml_allocr_is_measure(allocr)) { + // really bad, there is memory inflexibility (this is for host<->device memory conflicts) + void* freeze_data = malloc(ggml_nbytes(text_model.token_embed_weight)); + ggml_backend_tensor_get_and_sync(backend, text_model.token_embed_weight, freeze_data, 0, ggml_nbytes(text_model.token_embed_weight)); + ggml_backend_tensor_set(embeddings, freeze_data, 0, ggml_nbytes(text_model.token_embed_weight)); + free(freeze_data); + // concatenate custom embeddings + void* custom_data = malloc(ggml_nbytes(text_model.token_embed_custom)); + ggml_backend_tensor_get_and_sync(backend, text_model.token_embed_custom, custom_data, 0, ggml_nbytes(text_model.token_embed_custom)); + ggml_backend_tensor_set(embeddings, custom_data, ggml_nbytes(text_model.token_embed_weight), text_model.num_custom_embeddings * text_model.hidden_size * ggml_type_size(wtype)); + free(custom_data); + } + } + + struct ggml_tensor* hidden_states = forward(ctx0, input_ids, input_ids2, embeddings, max_token_idx, return_pooled); ggml_build_forward_expand(gf, hidden_states); ggml_free(ctx0); diff --git a/common.hpp b/common.hpp index 458f8be4..a71e4d37 100644 --- a/common.hpp +++ b/common.hpp @@ -15,10 +15,10 @@ struct DownSample { bool vae_downsample = false; size_t calculate_mem_size(ggml_type wtype) { - double mem_size = 0; - mem_size += out_channels * channels * 3 * 3 * ggml_type_sizef(GGML_TYPE_F16); // op_w - mem_size += out_channels * ggml_type_sizef(GGML_TYPE_F32); // op_b - return static_cast(mem_size); + size_t mem_size = 0; + mem_size += ggml_row_size(GGML_TYPE_F16, out_channels * channels * 3 * 3); // op_w + mem_size += ggml_row_size(GGML_TYPE_F32, out_channels); // op_b + return mem_size; } void init_params(struct ggml_context* ctx, ggml_type wtype) { @@ -59,10 +59,10 @@ 
struct UpSample { struct ggml_tensor* conv_b; // [out_channels,] size_t calculate_mem_size(ggml_type wtype) { - double mem_size = 0; - mem_size += out_channels * channels * 3 * 3 * ggml_type_sizef(GGML_TYPE_F16); // op_w - mem_size += out_channels * ggml_type_sizef(GGML_TYPE_F32); // op_b - return static_cast(mem_size); + size_t mem_size = 0; + mem_size += ggml_row_size(GGML_TYPE_F16, out_channels * channels * 3 * 3); // op_w + mem_size += ggml_row_size(GGML_TYPE_F32, out_channels); // op_b + return mem_size; } void init_params(struct ggml_context* ctx, ggml_type wtype) { @@ -83,4 +83,461 @@ struct UpSample { } }; +struct ResBlock { + // network hparams + int channels; // model_channels * (1, 1, 1, 2, 2, 4, 4, 4) + int emb_channels; // time_embed_dim + int out_channels; // mult * model_channels + + // network params + // in_layers + struct ggml_tensor* in_layer_0_w; // [channels, ] + struct ggml_tensor* in_layer_0_b; // [channels, ] + // in_layer_1 is nn.SILU() + struct ggml_tensor* in_layer_2_w; // [out_channels, channels, 3, 3] + struct ggml_tensor* in_layer_2_b; // [out_channels, ] + + // emb_layers + // emb_layer_0 is nn.SILU() + struct ggml_tensor* emb_layer_1_w; // [out_channels, emb_channels] + struct ggml_tensor* emb_layer_1_b; // [out_channels, ] + + // out_layers + struct ggml_tensor* out_layer_0_w; // [out_channels, ] + struct ggml_tensor* out_layer_0_b; // [out_channels, ] + // out_layer_1 is nn.SILU() + // out_layer_2 is nn.Dropout(), p = 0 for inference + struct ggml_tensor* out_layer_3_w; // [out_channels, out_channels, 3, 3] + struct ggml_tensor* out_layer_3_b; // [out_channels, ] + + // skip connection, only if out_channels != channels + struct ggml_tensor* skip_w; // [out_channels, channels, 1, 1] + struct ggml_tensor* skip_b; // [out_channels, ] + + size_t calculate_mem_size(ggml_type wtype) { + size_t mem_size = 0; + mem_size += 2 * ggml_row_size(GGML_TYPE_F32, channels); // in_layer_0_w/b + mem_size += ggml_row_size(GGML_TYPE_F16, out_channels * channels * 3 * 3); // in_layer_2_w + mem_size += 5 * ggml_row_size(GGML_TYPE_F32, out_channels); // in_layer_2_b/emb_layer_1_b/out_layer_0_w/out_layer_0_b/out_layer_3_b + mem_size += ggml_row_size(wtype, out_channels * emb_channels); // emb_layer_1_w + mem_size += ggml_row_size(GGML_TYPE_F16, out_channels * out_channels * 3 * 3); // out_layer_3_w + + if (out_channels != channels) { + mem_size += ggml_row_size(GGML_TYPE_F16, out_channels * channels * 1 * 1); // skip_w + mem_size += ggml_row_size(GGML_TYPE_F32, out_channels); // skip_b + } + return mem_size; + } + + void init_params(struct ggml_context* ctx, ggml_type wtype) { + in_layer_0_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, channels); + in_layer_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, channels); + in_layer_2_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, channels, out_channels); + in_layer_2_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, out_channels); + + emb_layer_1_w = ggml_new_tensor_2d(ctx, wtype, emb_channels, out_channels); + emb_layer_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, out_channels); + + out_layer_0_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, out_channels); + out_layer_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, out_channels); + out_layer_3_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, out_channels, out_channels); + out_layer_3_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, out_channels); + + if (out_channels != channels) { + skip_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 1, 1, channels, out_channels); + skip_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 
out_channels); + } + } + + void map_by_name(std::map& tensors, const std::string prefix) { + tensors[prefix + "in_layers.0.weight"] = in_layer_0_w; + tensors[prefix + "in_layers.0.bias"] = in_layer_0_b; + tensors[prefix + "in_layers.2.weight"] = in_layer_2_w; + tensors[prefix + "in_layers.2.bias"] = in_layer_2_b; + + tensors[prefix + "emb_layers.1.weight"] = emb_layer_1_w; + tensors[prefix + "emb_layers.1.bias"] = emb_layer_1_b; + + tensors[prefix + "out_layers.0.weight"] = out_layer_0_w; + tensors[prefix + "out_layers.0.bias"] = out_layer_0_b; + tensors[prefix + "out_layers.3.weight"] = out_layer_3_w; + tensors[prefix + "out_layers.3.bias"] = out_layer_3_b; + + if (out_channels != channels) { + tensors[prefix + "skip_connection.weight"] = skip_w; + tensors[prefix + "skip_connection.bias"] = skip_b; + } + } + + struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, struct ggml_tensor* emb) { + // x: [N, channels, h, w] + // emb: [N, emb_channels] + + // in_layers + auto h = ggml_nn_group_norm(ctx, x, in_layer_0_w, in_layer_0_b); + h = ggml_silu_inplace(ctx, h); + h = ggml_nn_conv_2d(ctx, h, in_layer_2_w, in_layer_2_b, 1, 1, 1, 1); // [N, out_channels, h, w] + + // emb_layers + auto emb_out = ggml_silu(ctx, emb); + emb_out = ggml_nn_linear(ctx, emb_out, emb_layer_1_w, emb_layer_1_b); // [N, out_channels] + emb_out = ggml_reshape_4d(ctx, emb_out, 1, 1, emb_out->ne[0], emb_out->ne[1]); // [N, out_channels, 1, 1] + + // out_layers + h = ggml_add(ctx, h, emb_out); + h = ggml_nn_group_norm(ctx, h, out_layer_0_w, out_layer_0_b); + h = ggml_silu_inplace(ctx, h); + + // dropout, skip for inference + + h = ggml_nn_conv_2d(ctx, h, out_layer_3_w, out_layer_3_b, 1, 1, 1, 1); // [N, out_channels, h, w] + + // skip connection + if (out_channels != channels) { + x = ggml_nn_conv_2d(ctx, x, skip_w, skip_b); // [N, out_channels, h, w] + } + + h = ggml_add(ctx, h, x); + return h; // [N, out_channels, h, w] + } +}; + +struct SpatialTransformer { + int in_channels; // mult * model_channels + int n_head; // num_heads + int d_head; // in_channels // n_heads + int depth = 1; // 1 + int context_dim = 768; // hidden_size, 1024 for VERSION_2_x + + // group norm + struct ggml_tensor* norm_w; // [in_channels,] + struct ggml_tensor* norm_b; // [in_channels,] + + // proj_in + struct ggml_tensor* proj_in_w; // [in_channels, in_channels, 1, 1] + struct ggml_tensor* proj_in_b; // [in_channels,] + + // transformer + struct Transformer { + // layer norm 1 + struct ggml_tensor* norm1_w; // [in_channels, ] + struct ggml_tensor* norm1_b; // [in_channels, ] + + // attn1 + struct ggml_tensor* attn1_q_w; // [in_channels, in_channels] + struct ggml_tensor* attn1_k_w; // [in_channels, in_channels] + struct ggml_tensor* attn1_v_w; // [in_channels, in_channels] + + struct ggml_tensor* attn1_out_w; // [in_channels, in_channels] + struct ggml_tensor* attn1_out_b; // [in_channels, ] + + // layer norm 2 + struct ggml_tensor* norm2_w; // [in_channels, ] + struct ggml_tensor* norm2_b; // [in_channels, ] + + // attn2 + struct ggml_tensor* attn2_q_w; // [in_channels, in_channels] + struct ggml_tensor* attn2_k_w; // [in_channels, context_dim] + struct ggml_tensor* attn2_v_w; // [in_channels, context_dim] + + struct ggml_tensor* attn2_out_w; // [in_channels, in_channels] + struct ggml_tensor* attn2_out_b; // [in_channels, ] + + // layer norm 3 + struct ggml_tensor* norm3_w; // [in_channels, ] + struct ggml_tensor* norm3_b; // [in_channels, ] + + // ff + struct ggml_tensor* ff_0_proj_w; // [in_channels * 4 * 2, in_channels] 
+ struct ggml_tensor* ff_0_proj_b; // [in_channels * 4 * 2] + + struct ggml_tensor* ff_2_w; // [in_channels, in_channels * 4] + struct ggml_tensor* ff_2_b; // [in_channels,] + }; + + std::vector transformers; + + // proj_out + struct ggml_tensor* proj_out_w; // [in_channels, in_channels, 1, 1] + struct ggml_tensor* proj_out_b; // [in_channels,] + + SpatialTransformer(int depth = 1) + : depth(depth) { + transformers.resize(depth); + } + + int get_num_tensors() { + return depth * 20 + 7; + } + + size_t calculate_mem_size(ggml_type wtype) { + size_t mem_size = 0; + mem_size += 2 * ggml_row_size(GGML_TYPE_F32, in_channels); // norm_w/norm_b + mem_size += 2 * ggml_row_size(GGML_TYPE_F16, in_channels * in_channels * 1 * 1); // proj_in_w/proj_out_w + mem_size += 2 * ggml_row_size(GGML_TYPE_F32, in_channels); // proj_in_b/proj_out_b + + // transformer + for (auto& transformer : transformers) { + mem_size += 6 * ggml_row_size(GGML_TYPE_F32, in_channels); // norm1-3_w/b + mem_size += 6 * ggml_row_size(wtype, in_channels * in_channels); // attn1_q/k/v/out_w attn2_q/out_w + mem_size += 2 * ggml_row_size(wtype, in_channels * context_dim); // attn2_k/v_w + mem_size += ggml_row_size(wtype, in_channels * 4 * 2 * in_channels ); // ff_0_proj_w + mem_size += ggml_row_size(GGML_TYPE_F32, in_channels * 4 * 2); // ff_0_proj_b + mem_size += ggml_row_size(wtype, in_channels * 4 * in_channels); // ff_2_w + mem_size += ggml_row_size(GGML_TYPE_F32, in_channels); // ff_2_b + } + return mem_size; + } + + void init_params(struct ggml_context* ctx, ggml_allocr* alloc, ggml_type wtype) { + norm_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, in_channels); + norm_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, in_channels); + proj_in_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 1, 1, in_channels, in_channels); + proj_in_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, in_channels); + + proj_out_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 1, 1, in_channels, in_channels); + proj_out_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, in_channels); + + // transformer + for (auto& transformer : transformers) { + transformer.norm1_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, in_channels); + transformer.norm1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, in_channels); + + transformer.attn1_q_w = ggml_new_tensor_2d(ctx, wtype, in_channels, in_channels); + transformer.attn1_k_w = ggml_new_tensor_2d(ctx, wtype, in_channels, in_channels); + transformer.attn1_v_w = ggml_new_tensor_2d(ctx, wtype, in_channels, in_channels); + + transformer.attn1_out_w = ggml_new_tensor_2d(ctx, wtype, in_channels, in_channels); + transformer.attn1_out_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, in_channels); + + transformer.norm2_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, in_channels); + transformer.norm2_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, in_channels); + + transformer.attn2_q_w = ggml_new_tensor_2d(ctx, wtype, in_channels, in_channels); + transformer.attn2_k_w = ggml_new_tensor_2d(ctx, wtype, context_dim, in_channels); + transformer.attn2_v_w = ggml_new_tensor_2d(ctx, wtype, context_dim, in_channels); + + transformer.attn2_out_w = ggml_new_tensor_2d(ctx, wtype, in_channels, in_channels); + transformer.attn2_out_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, in_channels); + + transformer.norm3_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, in_channels); + transformer.norm3_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, in_channels); + + transformer.ff_0_proj_w = ggml_new_tensor_2d(ctx, wtype, in_channels, in_channels * 4 * 2); + transformer.ff_0_proj_b = ggml_new_tensor_1d(ctx, 
GGML_TYPE_F32, in_channels * 4 * 2); + + transformer.ff_2_w = ggml_new_tensor_2d(ctx, wtype, in_channels * 4, in_channels); + transformer.ff_2_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, in_channels); + } + } + + void map_by_name(std::map& tensors, const std::string prefix) { + tensors[prefix + "norm.weight"] = norm_w; + tensors[prefix + "norm.bias"] = norm_b; + tensors[prefix + "proj_in.weight"] = proj_in_w; + tensors[prefix + "proj_in.bias"] = proj_in_b; + + // transformer + for (int i = 0; i < transformers.size(); i++) { + auto& transformer = transformers[i]; + std::string transformer_prefix = prefix + "transformer_blocks." + std::to_string(i) + "."; + tensors[transformer_prefix + "attn1.to_q.weight"] = transformer.attn1_q_w; + tensors[transformer_prefix + "attn1.to_k.weight"] = transformer.attn1_k_w; + tensors[transformer_prefix + "attn1.to_v.weight"] = transformer.attn1_v_w; + + tensors[transformer_prefix + "attn1.to_out.0.weight"] = transformer.attn1_out_w; + tensors[transformer_prefix + "attn1.to_out.0.bias"] = transformer.attn1_out_b; + + tensors[transformer_prefix + "ff.net.0.proj.weight"] = transformer.ff_0_proj_w; + tensors[transformer_prefix + "ff.net.0.proj.bias"] = transformer.ff_0_proj_b; + tensors[transformer_prefix + "ff.net.2.weight"] = transformer.ff_2_w; + tensors[transformer_prefix + "ff.net.2.bias"] = transformer.ff_2_b; + + tensors[transformer_prefix + "attn2.to_q.weight"] = transformer.attn2_q_w; + tensors[transformer_prefix + "attn2.to_k.weight"] = transformer.attn2_k_w; + tensors[transformer_prefix + "attn2.to_v.weight"] = transformer.attn2_v_w; + + tensors[transformer_prefix + "attn2.to_out.0.weight"] = transformer.attn2_out_w; + tensors[transformer_prefix + "attn2.to_out.0.bias"] = transformer.attn2_out_b; + + tensors[transformer_prefix + "norm1.weight"] = transformer.norm1_w; + tensors[transformer_prefix + "norm1.bias"] = transformer.norm1_b; + tensors[transformer_prefix + "norm2.weight"] = transformer.norm2_w; + tensors[transformer_prefix + "norm2.bias"] = transformer.norm2_b; + tensors[transformer_prefix + "norm3.weight"] = transformer.norm3_w; + tensors[transformer_prefix + "norm3.bias"] = transformer.norm3_b; + } + + tensors[prefix + "proj_out.weight"] = proj_out_w; + tensors[prefix + "proj_out.bias"] = proj_out_b; + } + + struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, struct ggml_tensor* context) { + // x: [N, in_channels, h, w] + // context: [N, max_position, hidden_size(aka context_dim)] + auto x_in = x; + x = ggml_nn_group_norm(ctx, x, norm_w, norm_b); + // proj_in + x = ggml_nn_conv_2d(ctx, x, proj_in_w, proj_in_b); // [N, in_channels, h, w] + + // transformer + const int64_t n = x->ne[3]; + const int64_t c = x->ne[2]; + const int64_t h = x->ne[1]; + const int64_t w = x->ne[0]; + const int64_t max_position = context->ne[1]; + x = ggml_cont(ctx, ggml_permute(ctx, x, 1, 2, 0, 3)); // [N, h, w, in_channels] + + for (auto& transformer : transformers) { + auto r = x; + // layer norm 1 + x = ggml_reshape_2d(ctx, x, c, w * h * n); + x = ggml_nn_layer_norm(ctx, x, transformer.norm1_w, transformer.norm1_b); + + // self-attention + { + x = ggml_reshape_2d(ctx, x, c, h * w * n); // [N * h * w, in_channels] + struct ggml_tensor* q = ggml_mul_mat(ctx, transformer.attn1_q_w, x); // [N * h * w, in_channels] +#if !defined(SD_USE_FLASH_ATTENTION) || defined(SD_USE_CUBLAS) || defined(SD_USE_METAL) + q = ggml_scale_inplace(ctx, q, 1.0f / sqrt((float)d_head)); +#endif + q = ggml_reshape_4d(ctx, q, d_head, n_head, h * w, n); // [N, h * w, 
n_head, d_head] + q = ggml_cont(ctx, ggml_permute(ctx, q, 0, 2, 1, 3)); // [N, n_head, h * w, d_head] + q = ggml_reshape_3d(ctx, q, d_head, h * w, n_head * n); // [N * n_head, h * w, d_head] + + struct ggml_tensor* k = ggml_mul_mat(ctx, transformer.attn1_k_w, x); // [N * h * w, in_channels] + k = ggml_reshape_4d(ctx, k, d_head, n_head, h * w, n); // [N, h * w, n_head, d_head] + k = ggml_cont(ctx, ggml_permute(ctx, k, 0, 2, 1, 3)); // [N, n_head, h * w, d_head] + k = ggml_reshape_3d(ctx, k, d_head, h * w, n_head * n); // [N * n_head, h * w, d_head] + + struct ggml_tensor* v = ggml_mul_mat(ctx, transformer.attn1_v_w, x); // [N * h * w, in_channels] + v = ggml_reshape_4d(ctx, v, d_head, n_head, h * w, n); // [N, h * w, n_head, d_head] + v = ggml_cont(ctx, ggml_permute(ctx, v, 1, 2, 0, 3)); // [N, n_head, d_head, h * w] + v = ggml_reshape_3d(ctx, v, h * w, d_head, n_head * n); // [N * n_head, d_head, h * w] + +#if defined(SD_USE_FLASH_ATTENTION) && !defined(SD_USE_CUBLAS) && !defined(SD_USE_METAL) + struct ggml_tensor* kqv = ggml_flash_attn(ctx, q, k, v, false); // [N * n_head, h * w, d_head] +#else + struct ggml_tensor* kq = ggml_mul_mat(ctx, k, q); // [N * n_head, h * w, h * w] + // kq = ggml_diag_mask_inf_inplace(ctx, kq, 0); + kq = ggml_soft_max_inplace(ctx, kq); + + struct ggml_tensor* kqv = ggml_mul_mat(ctx, v, kq); // [N * n_head, h * w, d_head] +#endif + kqv = ggml_reshape_4d(ctx, kqv, d_head, h * w, n_head, n); + kqv = ggml_cont(ctx, ggml_permute(ctx, kqv, 0, 2, 1, 3)); // [N, h * w, n_head, d_head] + + // x = ggml_cpy(ctx, kqv, ggml_new_tensor_2d(ctx, GGML_TYPE_F32, d_head * n_head, h * w * n)); + x = ggml_reshape_2d(ctx, kqv, d_head * n_head, h * w * n); + + x = ggml_nn_linear(ctx, x, transformer.attn1_out_w, transformer.attn1_out_b); + + x = ggml_reshape_4d(ctx, x, c, w, h, n); + } + + x = ggml_add(ctx, x, r); + r = x; + + // layer norm 2 + x = ggml_nn_layer_norm(ctx, x, transformer.norm2_w, transformer.norm2_b); + + // cross-attention + { + x = ggml_reshape_2d(ctx, x, c, h * w * n); // [N * h * w, in_channels] + context = ggml_reshape_2d(ctx, context, context->ne[0], context->ne[1] * context->ne[2]); // [N * max_position, hidden_size] + struct ggml_tensor* q = ggml_mul_mat(ctx, transformer.attn2_q_w, x); // [N * h * w, in_channels] +#if !defined(SD_USE_FLASH_ATTENTION) || defined(SD_USE_CUBLAS) || defined(SD_USE_METAL) + q = ggml_scale_inplace(ctx, q, 1.0f / sqrt((float)d_head)); +#endif + q = ggml_reshape_4d(ctx, q, d_head, n_head, h * w, n); // [N, h * w, n_head, d_head] + q = ggml_cont(ctx, ggml_permute(ctx, q, 0, 2, 1, 3)); // [N, n_head, h * w, d_head] + q = ggml_reshape_3d(ctx, q, d_head, h * w, n_head * n); // [N * n_head, h * w, d_head] + + struct ggml_tensor* k = ggml_mul_mat(ctx, transformer.attn2_k_w, context); // [N * max_position, in_channels] + k = ggml_reshape_4d(ctx, k, d_head, n_head, max_position, n); // [N, max_position, n_head, d_head] + k = ggml_cont(ctx, ggml_permute(ctx, k, 0, 2, 1, 3)); // [N, n_head, max_position, d_head] + k = ggml_reshape_3d(ctx, k, d_head, max_position, n_head * n); // [N * n_head, max_position, d_head] + + struct ggml_tensor* v = ggml_mul_mat(ctx, transformer.attn2_v_w, context); // [N * max_position, in_channels] + v = ggml_reshape_4d(ctx, v, d_head, n_head, max_position, n); // [N, max_position, n_head, d_head] + v = ggml_cont(ctx, ggml_permute(ctx, v, 1, 2, 0, 3)); // [N, n_head, d_head, max_position] + v = ggml_reshape_3d(ctx, v, max_position, d_head, n_head * n); // [N * n_head, d_head, max_position] +#if 
defined(SD_USE_FLASH_ATTENTION) && !defined(SD_USE_CUBLAS) && !defined(SD_USE_METAL) + struct ggml_tensor* kqv = ggml_flash_attn(ctx, q, k, v, false); // [N * n_head, h * w, d_head] +#else + struct ggml_tensor* kq = ggml_mul_mat(ctx, k, q); // [N * n_head, h * w, max_position] + // kq = ggml_diag_mask_inf_inplace(ctx, kq, 0); + kq = ggml_soft_max_inplace(ctx, kq); + + struct ggml_tensor* kqv = ggml_mul_mat(ctx, v, kq); // [N * n_head, h * w, d_head] +#endif + kqv = ggml_reshape_4d(ctx, kqv, d_head, h * w, n_head, n); + kqv = ggml_cont(ctx, ggml_permute(ctx, kqv, 0, 2, 1, 3)); + + // x = ggml_cpy(ctx, kqv, ggml_new_tensor_2d(ctx, GGML_TYPE_F32, d_head * n_head, h * w * n)); // [N * h * w, in_channels] + x = ggml_reshape_2d(ctx, kqv, d_head * n_head, h * w * n); // [N * h * w, in_channels] + + x = ggml_nn_linear(ctx, x, transformer.attn2_out_w, transformer.attn2_out_b); + + x = ggml_reshape_4d(ctx, x, c, w, h, n); + } + + x = ggml_add(ctx, x, r); + r = x; + + // layer norm 3 + x = ggml_reshape_2d(ctx, x, c, h * w * n); // [N * h * w, in_channels] + x = ggml_nn_layer_norm(ctx, x, transformer.norm3_w, transformer.norm3_b); + + // ff + { + // GEGLU + auto x_w = ggml_view_2d(ctx, + transformer.ff_0_proj_w, + transformer.ff_0_proj_w->ne[0], + transformer.ff_0_proj_w->ne[1] / 2, + transformer.ff_0_proj_w->nb[1], + 0); // [in_channels * 4, in_channels] + auto x_b = ggml_view_1d(ctx, + transformer.ff_0_proj_b, + transformer.ff_0_proj_b->ne[0] / 2, + 0); // [in_channels * 4, in_channels] + auto gate_w = ggml_view_2d(ctx, + transformer.ff_0_proj_w, + transformer.ff_0_proj_w->ne[0], + transformer.ff_0_proj_w->ne[1] / 2, + transformer.ff_0_proj_w->nb[1], + transformer.ff_0_proj_w->nb[1] * transformer.ff_0_proj_w->ne[1] / 2); // [in_channels * 4, ] + auto gate_b = ggml_view_1d(ctx, + transformer.ff_0_proj_b, + transformer.ff_0_proj_b->ne[0] / 2, + transformer.ff_0_proj_b->nb[0] * transformer.ff_0_proj_b->ne[0] / 2); // [in_channels * 4, ] + x = ggml_reshape_2d(ctx, x, c, w * h * n); + auto x_in = x; + x = ggml_nn_linear(ctx, x_in, x_w, x_b); // [N * h * w, in_channels * 4] + auto gate = ggml_nn_linear(ctx, x_in, gate_w, gate_b); // [N * h * w, in_channels * 4] + + gate = ggml_gelu_inplace(ctx, gate); + + x = ggml_mul(ctx, x, gate); // [N * h * w, in_channels * 4] + // fc + x = ggml_nn_linear(ctx, x, transformer.ff_2_w, transformer.ff_2_b); // [N * h * w, in_channels] + } + + x = ggml_reshape_4d(ctx, x, c, w, h, n); // [N, h, w, in_channels] + + // residual + x = ggml_add(ctx, x, r); + } + + x = ggml_cont(ctx, ggml_permute(ctx, x, 2, 0, 1, 3)); // [N, in_channels, h, w] + + // proj_out + x = ggml_nn_conv_2d(ctx, x, proj_out_w, proj_out_b); // [N, in_channels, h, w] + + x = ggml_add(ctx, x, x_in); + return x; + } +}; + #endif // __COMMON_HPP__ \ No newline at end of file diff --git a/control.hpp b/control.hpp new file mode 100644 index 00000000..543998f4 --- /dev/null +++ b/control.hpp @@ -0,0 +1,695 @@ +#ifndef __CONTROL_HPP__ +#define __CONTROL_HPP__ + +#include "ggml_extend.hpp" +#include "common.hpp" +#include "model.h" + +#define CONTROL_NET_GRAPH_SIZE 1536 + +/* + =================================== ControlNet =================================== + Reference: https://github.com/comfyanonymous/ComfyUI/blob/master/comfy/cldm/cldm.py + +*/ + +struct CNHintBlock { + int hint_channels = 3; + int model_channels = 320; // SD 1.5 + int feat_channels[4] = { 16, 32, 96, 256 }; + int num_blocks = 3; + ggml_tensor* conv_first_w; // [feat_channels[0], hint_channels, 3, 3] + ggml_tensor* conv_first_b; // 
[feat_channels[0]] + + struct hint_block { + ggml_tensor* conv_0_w; // [feat_channels[idx], feat_channels[idx], 3, 3] + ggml_tensor* conv_0_b; // [feat_channels[idx]] + + ggml_tensor* conv_1_w; // [feat_channels[idx + 1], feat_channels[idx], 3, 3] + ggml_tensor* conv_1_b; // [feat_channels[idx + 1]] + }; + + hint_block blocks[3]; + ggml_tensor* conv_final_w; // [model_channels, feat_channels[3], 3, 3] + ggml_tensor* conv_final_b; // [model_channels] + + size_t calculate_mem_size() { + size_t mem_size = feat_channels[0] * hint_channels * 3 * 3 * ggml_type_size(GGML_TYPE_F16); // conv_first_w + mem_size += feat_channels[0] * ggml_type_size(GGML_TYPE_F32); // conv_first_b + for (int i = 0; i < num_blocks; i++) { + mem_size += feat_channels[i] * feat_channels[i] * 3 * 3 * ggml_type_size(GGML_TYPE_F16); // conv_0_w + mem_size += feat_channels[i] * ggml_type_size(GGML_TYPE_F32); // conv_0_b + mem_size += feat_channels[i + 1] * feat_channels[i] * 3 * 3 * ggml_type_size(GGML_TYPE_F16); // conv_1_w + mem_size += feat_channels[i + 1] * ggml_type_size(GGML_TYPE_F32); // conv_1_b + } + mem_size += model_channels * feat_channels[3] * 3 * 3 * ggml_type_size(GGML_TYPE_F16); // conv_final_w + mem_size += model_channels * ggml_type_size(GGML_TYPE_F32); // conv_final_b + return mem_size; + } + + void init_params(struct ggml_context* ctx) { + conv_first_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, hint_channels, feat_channels[0]); + conv_first_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, feat_channels[0]); + + for (int i = 0; i < num_blocks; i++) { + blocks[i].conv_0_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, feat_channels[i], feat_channels[i]); + blocks[i].conv_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, feat_channels[i]); + blocks[i].conv_1_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, feat_channels[i], feat_channels[i + 1]); + blocks[i].conv_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, feat_channels[i + 1]); + } + + conv_final_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, feat_channels[3], model_channels); + conv_final_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, model_channels); + } + + void map_by_name(std::map& tensors, const std::string prefix) { + tensors[prefix + "input_hint_block.0.weight"] = conv_first_w; + tensors[prefix + "input_hint_block.0.bias"] = conv_first_b; + int index = 2; + for (int i = 0; i < num_blocks; i++) { + tensors[prefix + "input_hint_block." + std::to_string(index) +".weight"] = blocks[i].conv_0_w; + tensors[prefix + "input_hint_block." + std::to_string(index) +".bias"] = blocks[i].conv_0_b; + index += 2; + tensors[prefix + "input_hint_block." + std::to_string(index) +".weight"] = blocks[i].conv_1_w; + tensors[prefix + "input_hint_block." 
+ std::to_string(index) +".bias"] = blocks[i].conv_1_b; + index += 2; + } + tensors[prefix + "input_hint_block.14.weight"] = conv_final_w; + tensors[prefix + "input_hint_block.14.bias"] = conv_final_b; + } + + struct ggml_tensor* forward(ggml_context* ctx, struct ggml_tensor* x) { + auto h = ggml_nn_conv_2d(ctx, x, conv_first_w, conv_first_b, 1, 1, 1, 1); + h = ggml_silu_inplace(ctx, h); + + auto body_h = h; + for(int i = 0; i < num_blocks; i++) { + // operations.conv_nd(dims, 16, 16, 3, padding=1) + body_h = ggml_nn_conv_2d(ctx, body_h, blocks[i].conv_0_w, blocks[i].conv_0_b, 1, 1, 1, 1); + body_h = ggml_silu_inplace(ctx, body_h); + // operations.conv_nd(dims, 16, 32, 3, padding=1, stride=2) + body_h = ggml_nn_conv_2d(ctx, body_h, blocks[i].conv_1_w, blocks[i].conv_1_b, 2, 2, 1, 1); + body_h = ggml_silu_inplace(ctx, body_h); + } + + h = ggml_nn_conv_2d(ctx, body_h, conv_final_w, conv_final_b, 1, 1, 1, 1); + h = ggml_silu_inplace(ctx, h); + return h; + } +}; + +struct CNZeroConv { + int channels; + ggml_tensor* conv_w; // [channels, channels, 1, 1] + ggml_tensor* conv_b; // [channels] + + void init_params(struct ggml_context* ctx) { + conv_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 1, 1, channels,channels); + conv_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, channels); + } +}; + +struct ControlNet : public GGMLModule { + int in_channels = 4; + int model_channels = 320; + int out_channels = 4; + int num_res_blocks = 2; + std::vector attention_resolutions = {4, 2, 1}; + std::vector channel_mult = {1, 2, 4, 4}; + std::vector transformer_depth = {1, 1, 1, 1}; + int time_embed_dim = 1280; // model_channels*4 + int num_heads = 8; + int num_head_channels = -1; // channels // num_heads + int context_dim = 768; + int middle_out_channel; + CNHintBlock input_hint_block; + CNZeroConv zero_convs[12]; + int num_zero_convs = 1; + + // network params + struct ggml_tensor* time_embed_0_w; // [time_embed_dim, model_channels] + struct ggml_tensor* time_embed_0_b; // [time_embed_dim, ] + // time_embed_1 is nn.SILU() + struct ggml_tensor* time_embed_2_w; // [time_embed_dim, time_embed_dim] + struct ggml_tensor* time_embed_2_b; // [time_embed_dim, ] + + struct ggml_tensor* input_block_0_w; // [model_channels, in_channels, 3, 3] + struct ggml_tensor* input_block_0_b; // [model_channels, ] + + // input_blocks + ResBlock input_res_blocks[4][2]; + SpatialTransformer input_transformers[3][2]; + DownSample input_down_samples[3]; + + // middle_block + ResBlock middle_block_0; + SpatialTransformer middle_block_1; + ResBlock middle_block_2; + + struct ggml_tensor* middle_block_out_w; // [middle_out_channel, middle_out_channel, 1, 1] + struct ggml_tensor* middle_block_out_b; // [middle_out_channel, ] + ggml_backend_buffer_t control_buffer = NULL; // keep control output tensors in backend memory + ggml_context* control_ctx = NULL; + std::vector controls; // (12 input block outputs, 1 middle block output) SD 1.5 + + ControlNet() { + name = "controlnet"; + // input_blocks + std::vector input_block_chans; + input_block_chans.push_back(model_channels); + int ch = model_channels; + zero_convs[0].channels = model_channels; + int ds = 1; + + int len_mults = channel_mult.size(); + for (int i = 0; i < len_mults; i++) { + int mult = channel_mult[i]; + for (int j = 0; j < num_res_blocks; j++) { + input_res_blocks[i][j].channels = ch; + input_res_blocks[i][j].emb_channels = time_embed_dim; + input_res_blocks[i][j].out_channels = mult * model_channels; + + ch = mult * model_channels; + if (std::find(attention_resolutions.begin(), 
attention_resolutions.end(), ds) != attention_resolutions.end()) { + int n_head = num_heads; + int d_head = ch / num_heads; + if (num_head_channels != -1) { + d_head = num_head_channels; + n_head = ch / d_head; + } + input_transformers[i][j] = SpatialTransformer(transformer_depth[i]); + input_transformers[i][j].in_channels = ch; + input_transformers[i][j].n_head = n_head; + input_transformers[i][j].d_head = d_head; + input_transformers[i][j].context_dim = context_dim; + } + input_block_chans.push_back(ch); + + zero_convs[num_zero_convs].channels = ch; + num_zero_convs++; + } + if (i != len_mults - 1) { + input_down_samples[i].channels = ch; + input_down_samples[i].out_channels = ch; + input_block_chans.push_back(ch); + + zero_convs[num_zero_convs].channels = ch; + num_zero_convs++; + ds *= 2; + } + } + GGML_ASSERT(num_zero_convs == 12); + + // middle blocks + middle_block_0.channels = ch; + middle_block_0.emb_channels = time_embed_dim; + middle_block_0.out_channels = ch; + + int n_head = num_heads; + int d_head = ch / num_heads; + if (num_head_channels != -1) { + d_head = num_head_channels; + n_head = ch / d_head; + } + middle_block_1 = SpatialTransformer(transformer_depth[transformer_depth.size() - 1]); + middle_block_1.in_channels = ch; + middle_block_1.n_head = n_head; + middle_block_1.d_head = d_head; + middle_block_1.context_dim = context_dim; + + middle_block_2.channels = ch; + middle_block_2.emb_channels = time_embed_dim; + middle_block_2.out_channels = ch; + middle_out_channel = ch; + } + + size_t calculate_mem_size() { + size_t mem_size = 0; + mem_size += input_hint_block.calculate_mem_size(); + mem_size += ggml_row_size(wtype, time_embed_dim * model_channels); // time_embed_0_w + mem_size += ggml_row_size(GGML_TYPE_F32, time_embed_dim); // time_embed_0_b + mem_size += ggml_row_size(wtype, time_embed_dim * time_embed_dim); // time_embed_2_w + mem_size += ggml_row_size(GGML_TYPE_F32,time_embed_dim); // time_embed_2_b + + mem_size += ggml_row_size(GGML_TYPE_F16, model_channels * in_channels * 3 * 3); // input_block_0_w + mem_size += ggml_row_size(GGML_TYPE_F32, model_channels); // input_block_0_b + + // input_blocks + int ds = 1; + int len_mults = channel_mult.size(); + for (int i = 0; i < len_mults; i++) { + for (int j = 0; j < num_res_blocks; j++) { + mem_size += input_res_blocks[i][j].calculate_mem_size(wtype); + if (std::find(attention_resolutions.begin(), attention_resolutions.end(), ds) != attention_resolutions.end()) { + mem_size += input_transformers[i][j].calculate_mem_size(wtype); + } + } + if (i != len_mults - 1) { + ds *= 2; + mem_size += input_down_samples[i].calculate_mem_size(wtype); + } + } + + for (int i = 0; i < num_zero_convs; i++) { + mem_size += ggml_row_size(GGML_TYPE_F16, zero_convs[i].channels * zero_convs[i].channels); + mem_size += ggml_row_size(GGML_TYPE_F32, zero_convs[i].channels); + } + + // middle_block + mem_size += middle_block_0.calculate_mem_size(wtype); + mem_size += middle_block_1.calculate_mem_size(wtype); + mem_size += middle_block_2.calculate_mem_size(wtype); + + mem_size += ggml_row_size(GGML_TYPE_F16, middle_out_channel * middle_out_channel); // middle_block_out_w + mem_size += ggml_row_size(GGML_TYPE_F32, middle_out_channel); // middle_block_out_b + + return mem_size; + } + + size_t get_num_tensors() { + // in + size_t num_tensors = 6; + + num_tensors += num_zero_convs * 2; + + // input blocks + int ds = 1; + int len_mults = channel_mult.size(); + for (int i = 0; i < len_mults; i++) { + for (int j = 0; j < num_res_blocks; j++) { + 
num_tensors += 12; + if (std::find(attention_resolutions.begin(), attention_resolutions.end(), ds) != attention_resolutions.end()) { + num_tensors += input_transformers[i][j].get_num_tensors(); + } + } + if (i != len_mults - 1) { + ds *= 2; + num_tensors += 2; + } + } + + // middle blocks + num_tensors += 13 * 2; + num_tensors += middle_block_1.get_num_tensors(); + return num_tensors; + } + + void init_params() { + ggml_allocr* alloc = ggml_allocr_new_from_buffer(params_buffer); + + input_hint_block.init_params(params_ctx); + + time_embed_0_w = ggml_new_tensor_2d(params_ctx, wtype, model_channels, time_embed_dim); + time_embed_0_b = ggml_new_tensor_1d(params_ctx, GGML_TYPE_F32, time_embed_dim); + time_embed_2_w = ggml_new_tensor_2d(params_ctx, wtype, time_embed_dim, time_embed_dim); + time_embed_2_b = ggml_new_tensor_1d(params_ctx, GGML_TYPE_F32, time_embed_dim); + + // input_blocks + input_block_0_w = ggml_new_tensor_4d(params_ctx, GGML_TYPE_F16, 3, 3, in_channels, model_channels); + input_block_0_b = ggml_new_tensor_1d(params_ctx, GGML_TYPE_F32, model_channels); + + int ds = 1; + int len_mults = channel_mult.size(); + for (int i = 0; i < len_mults; i++) { + for (int j = 0; j < num_res_blocks; j++) { + input_res_blocks[i][j].init_params(params_ctx, wtype); + if (std::find(attention_resolutions.begin(), attention_resolutions.end(), ds) != attention_resolutions.end()) { + input_transformers[i][j].init_params(params_ctx, alloc, wtype); + } + } + if (i != len_mults - 1) { + input_down_samples[i].init_params(params_ctx, wtype); + ds *= 2; + } + } + + for (int i = 0; i < num_zero_convs; i++) { + zero_convs[i].init_params(params_ctx); + } + + // middle_blocks + middle_block_0.init_params(params_ctx, wtype); + middle_block_1.init_params(params_ctx, alloc, wtype); + middle_block_2.init_params(params_ctx, wtype); + + // middle_block_out + middle_block_out_w = ggml_new_tensor_4d(params_ctx, GGML_TYPE_F16, 1, 1, middle_out_channel, middle_out_channel); + middle_block_out_b = ggml_new_tensor_1d(params_ctx, GGML_TYPE_F32, middle_out_channel); + + // alloc all tensors linked to this context + for (struct ggml_tensor* t = ggml_get_first_tensor(params_ctx); t != NULL; t = ggml_get_next_tensor(params_ctx, t)) { + if (t->data == NULL) { + ggml_allocr_alloc(alloc, t); + } + } + + ggml_allocr_free(alloc); + } + + bool load_from_file(const std::string& file_path, ggml_backend_t backend_, ggml_type wtype_) { + LOG_INFO("loading control net from '%s'", file_path.c_str()); + + std::map control_tensors; + + ModelLoader model_loader; + if (!model_loader.init_from_file(file_path)) { + LOG_ERROR("init control net model loader from file failed: '%s'", file_path.c_str()); + return false; + } + + if (!alloc_params_buffer(backend_, wtype_)) { + return false; + } + + // prepare memory for the weights + { + init_params(); + map_by_name(control_tensors, ""); + } + + std::set tensor_names_in_file; + + auto on_new_tensor_cb = [&](const TensorStorage& tensor_storage, ggml_tensor** dst_tensor) -> bool { + const std::string& name = tensor_storage.name; + tensor_names_in_file.insert(name); + + struct ggml_tensor* real; + if (control_tensors.find(name) != control_tensors.end()) { + real = control_tensors[name]; + } else { + LOG_ERROR("unknown tensor '%s' in model file", name.data()); + return true; + } + + if ( + real->ne[0] != tensor_storage.ne[0] || + real->ne[1] != tensor_storage.ne[1] || + real->ne[2] != tensor_storage.ne[2] || + real->ne[3] != tensor_storage.ne[3]) { + LOG_ERROR( + "tensor '%s' has wrong shape in model file: 
" + "got [%d, %d, %d, %d], expected [%d, %d, %d, %d]", + name.c_str(), + (int)tensor_storage.ne[0], (int)tensor_storage.ne[1], (int)tensor_storage.ne[2], (int)tensor_storage.ne[3], + (int)real->ne[0], (int)real->ne[1], (int)real->ne[2], (int)real->ne[3]); + return false; + } + + *dst_tensor = real; + + return true; + }; + + bool success = model_loader.load_tensors(on_new_tensor_cb, backend); + + bool some_tensor_not_init = false; + + for (auto pair : control_tensors) { + if (tensor_names_in_file.find(pair.first) == tensor_names_in_file.end()) { + LOG_ERROR("tensor '%s' not in model file", pair.first.c_str()); + some_tensor_not_init = true; + } + } + + if (some_tensor_not_init) { + return false; + } + + LOG_INFO("control net model loaded"); + return success; + } + + void map_by_name(std::map& tensors, const std::string prefix) { + input_hint_block.map_by_name(tensors, ""); + tensors[prefix + "time_embed.0.weight"] = time_embed_0_w; + tensors[prefix + "time_embed.0.bias"] = time_embed_0_b; + tensors[prefix + "time_embed.2.weight"] = time_embed_2_w; + tensors[prefix + "time_embed.2.bias"] = time_embed_2_b; + + // input_blocks + tensors[prefix + "input_blocks.0.0.weight"] = input_block_0_w; + tensors[prefix + "input_blocks.0.0.bias"] = input_block_0_b; + + int len_mults = channel_mult.size(); + int input_block_idx = 0; + int ds = 1; + for (int i = 0; i < len_mults; i++) { + for (int j = 0; j < num_res_blocks; j++) { + input_block_idx += 1; + input_res_blocks[i][j].map_by_name(tensors, prefix + "input_blocks." + std::to_string(input_block_idx) + ".0."); + if (std::find(attention_resolutions.begin(), attention_resolutions.end(), ds) != attention_resolutions.end()) { + input_transformers[i][j].map_by_name(tensors, prefix + "input_blocks." + std::to_string(input_block_idx) + ".1."); + } + } + if (i != len_mults - 1) { + input_block_idx += 1; + input_down_samples[i].map_by_name(tensors, prefix + "input_blocks." 
+ std::to_string(input_block_idx) + ".0."); + ds *= 2; + } + } + + for (int i = 0; i < num_zero_convs; i++) { + tensors[prefix + "zero_convs."+ std::to_string(i) + ".0.weight"] = zero_convs[i].conv_w; + tensors[prefix + "zero_convs."+ std::to_string(i) + ".0.bias"] = zero_convs[i].conv_b; + } + + // middle_blocks + middle_block_0.map_by_name(tensors, prefix + "middle_block.0."); + middle_block_1.map_by_name(tensors, prefix + "middle_block.1."); + middle_block_2.map_by_name(tensors, prefix + "middle_block.2."); + + tensors[prefix + "middle_block_out.0.weight"] = middle_block_out_w; + tensors[prefix + "middle_block_out.0.bias"] = middle_block_out_b; + } + + struct ggml_cgraph* build_graph_hint(struct ggml_tensor* hint) { + // since we are using ggml-alloc, this buffer only needs enough space to hold the ggml_tensor and ggml_cgraph structs, but not the tensor data + static size_t buf_size = ggml_tensor_overhead() * GGML_DEFAULT_GRAPH_SIZE + ggml_graph_overhead(); + static std::vector buf(buf_size); + + struct ggml_init_params params = { + /*.mem_size =*/buf_size, + /*.mem_buffer =*/buf.data(), + /*.no_alloc =*/true, // the tensors will be allocated later by ggml_allocr_alloc_graph() + }; + + struct ggml_context* ctx0 = ggml_init(params); + struct ggml_cgraph* gf = ggml_new_graph(ctx0); + // temporal tensors for transfer tensors from cpu to gpu if needed + struct ggml_tensor* hint_t = NULL; + // it's performing a compute, check if backend isn't cpu + if (!ggml_backend_is_cpu(backend)) { + // pass input tensors to gpu memory + hint_t = ggml_dup_tensor(ctx0, hint); + ggml_allocr_alloc(compute_allocr, hint_t); + // pass data to device backend + if (!ggml_allocr_is_measure(compute_allocr)) { + ggml_backend_tensor_set(hint_t, hint->data, 0, ggml_nbytes(hint)); + } + } else { + // if it's cpu backend just pass the same tensors + hint_t = hint; + } + struct ggml_tensor* out = input_hint_block.forward(ctx0, hint_t); + ggml_build_forward_expand(gf, out); + ggml_free(ctx0); + return gf; + } + + void process_hint(struct ggml_tensor* output, int n_threads, struct ggml_tensor* hint) { + // compute buffer size + auto get_graph = [&]() -> struct ggml_cgraph* { + return build_graph_hint(hint); + }; + GGMLModule::alloc_compute_buffer(get_graph); + // perform computation + GGMLModule::compute(get_graph, n_threads, output); + GGMLModule::free_compute_buffer(); + } + + void forward(struct ggml_cgraph* gf, + struct ggml_context* ctx0, + struct ggml_tensor* x, + struct ggml_tensor* hint, + struct ggml_tensor* timesteps, + struct ggml_tensor* context, + struct ggml_tensor* t_emb = NULL) { + // x: [N, in_channels, h, w] + // timesteps: [N, ] + // t_emb: [N, model_channels] + // context: [N, max_position, hidden_size]([N, 77, 768]) + if (t_emb == NULL && timesteps != NULL) { + t_emb = new_timestep_embedding(ctx0, compute_allocr, timesteps, model_channels); // [N, model_channels] + } + + // time_embed = nn.Sequential + auto emb = ggml_nn_linear(ctx0, t_emb, time_embed_0_w, time_embed_0_b); + emb = ggml_silu_inplace(ctx0, emb); + emb = ggml_nn_linear(ctx0, emb, time_embed_2_w, time_embed_2_b); // [N, time_embed_dim] + + // input_blocks + int zero_conv_offset = 0; + + // input block 0 + struct ggml_tensor* h = ggml_nn_conv_2d(ctx0, x, input_block_0_w, input_block_0_b, 1, 1, 1, 1); // [N, model_channels, h, w] + h = ggml_add(ctx0, h, hint); + + auto h_c = ggml_nn_conv_2d(ctx0, h, zero_convs[zero_conv_offset].conv_w, zero_convs[zero_conv_offset].conv_b); + ggml_build_forward_expand(gf, ggml_cpy(ctx0, h_c, 
controls[zero_conv_offset])); + zero_conv_offset++; + + // input block 1-11 + int len_mults = channel_mult.size(); + int ds = 1; + for (int i = 0; i < len_mults; i++) { + int mult = channel_mult[i]; + for (int j = 0; j < num_res_blocks; j++) { + h = input_res_blocks[i][j].forward(ctx0, h, emb); // [N, mult*model_channels, h, w] + if (std::find(attention_resolutions.begin(), attention_resolutions.end(), ds) != attention_resolutions.end()) { + h = input_transformers[i][j].forward(ctx0, h, context); // [N, mult*model_channels, h, w] + } + h_c = ggml_nn_conv_2d(ctx0, h, zero_convs[zero_conv_offset].conv_w, zero_convs[zero_conv_offset].conv_b); + ggml_build_forward_expand(gf, ggml_cpy(ctx0, h_c, controls[zero_conv_offset])); + zero_conv_offset++; + } + if (i != len_mults - 1) { + ds *= 2; + h = input_down_samples[i].forward(ctx0, h); // [N, mult*model_channels, h/(2^(i+1)), w/(2^(i+1))] + h_c = ggml_nn_conv_2d(ctx0, h, zero_convs[zero_conv_offset].conv_w, zero_convs[zero_conv_offset].conv_b); + ggml_build_forward_expand(gf, ggml_cpy(ctx0, h_c, controls[zero_conv_offset])); + zero_conv_offset++; + } + } + // [N, 4*model_channels, h/8, w/8] + + // middle_block + h = middle_block_0.forward(ctx0, h, emb); // [N, 4*model_channels, h/8, w/8] + h = middle_block_1.forward(ctx0, h, context); // [N, 4*model_channels, h/8, w/8] + h = middle_block_2.forward(ctx0, h, emb); // [N, 4*model_channels, h/8, w/8] + + h_c = ggml_nn_conv_2d(ctx0, h, middle_block_out_w, middle_block_out_b); + ggml_build_forward_expand(gf, ggml_cpy(ctx0, h_c, controls[zero_conv_offset])); + } + + struct ggml_cgraph* build_graph(struct ggml_tensor* x, + struct ggml_tensor* hint, + struct ggml_tensor* timesteps, + struct ggml_tensor* context, + struct ggml_tensor* t_emb = NULL) { + // since we are using ggml-alloc, this buffer only needs enough space to hold the ggml_tensor and ggml_cgraph structs, but not the tensor data + static size_t buf_size = ggml_tensor_overhead() * CONTROL_NET_GRAPH_SIZE + ggml_graph_overhead(); + static std::vector buf(buf_size); + + struct ggml_init_params params = { + /*.mem_size =*/buf_size, + /*.mem_buffer =*/buf.data(), + /*.no_alloc =*/true, // the tensors will be allocated later by ggml_allocr_alloc_graph() + }; + // LOG_DEBUG("mem_size %u ", params.mem_size); + + struct ggml_context* ctx0 = ggml_init(params); + + struct ggml_cgraph* gf = ggml_new_graph_custom(ctx0, CONTROL_NET_GRAPH_SIZE, false); + + // temporal tensors for transfer tensors from cpu to gpu if needed + struct ggml_tensor* x_t = NULL; + struct ggml_tensor* hint_t = NULL; + struct ggml_tensor* timesteps_t = NULL; + struct ggml_tensor* context_t = NULL; + struct ggml_tensor* t_emb_t = NULL; + + // it's performing a compute, check if backend isn't cpu + if (!ggml_backend_is_cpu(backend)) { + // pass input tensors to gpu memory + x_t = ggml_dup_tensor(ctx0, x); + context_t = ggml_dup_tensor(ctx0, context); + hint_t = ggml_dup_tensor(ctx0, hint); + ggml_allocr_alloc(compute_allocr, x_t); + if (timesteps != NULL) { + timesteps_t = ggml_dup_tensor(ctx0, timesteps); + ggml_allocr_alloc(compute_allocr, timesteps_t); + } + ggml_allocr_alloc(compute_allocr, context_t); + ggml_allocr_alloc(compute_allocr, hint_t); + if (t_emb != NULL) { + t_emb_t = ggml_dup_tensor(ctx0, t_emb); + ggml_allocr_alloc(compute_allocr, t_emb_t); + } + // pass data to device backend + if (!ggml_allocr_is_measure(compute_allocr)) { + ggml_backend_tensor_set(x_t, x->data, 0, ggml_nbytes(x)); + ggml_backend_tensor_set(context_t, context->data, 0, ggml_nbytes(context)); + 
ggml_backend_tensor_set(hint_t, hint->data, 0, ggml_nbytes(hint)); + if (timesteps_t != NULL) { + ggml_backend_tensor_set(timesteps_t, timesteps->data, 0, ggml_nbytes(timesteps)); + } + if (t_emb_t != NULL) { + ggml_backend_tensor_set(t_emb_t, t_emb->data, 0, ggml_nbytes(t_emb)); + } + } + } else { + // if it's cpu backend just pass the same tensors + x_t = x; + timesteps_t = timesteps; + context_t = context; + t_emb_t = t_emb; + hint_t = hint; + } + + forward(gf, ctx0, x_t, hint_t, timesteps_t, context_t, t_emb_t); + + ggml_free(ctx0); + + return gf; + } + + void alloc_compute_buffer(struct ggml_tensor* x, + struct ggml_tensor* hint, + struct ggml_tensor* context, + struct ggml_tensor* t_emb = NULL) { + { + struct ggml_init_params params; + params.mem_size = static_cast(14 * ggml_tensor_overhead()) + 256; + params.mem_buffer = NULL; + params.no_alloc = true; + control_ctx = ggml_init(params); + size_t control_buffer_size = 0; + int w = x->ne[0], h = x->ne[1], steps = 0; + for(int i = 0; i < (num_zero_convs + 1); i++) { + bool last = i == num_zero_convs; + int c = last ? middle_out_channel : zero_convs[i].channels; + if(!last && steps == 3) { + w /= 2; h /= 2; steps = 0; + } + controls.push_back(ggml_new_tensor_4d(control_ctx, GGML_TYPE_F32, w, h, c, 1)); + control_buffer_size += ggml_nbytes(controls[i]); + steps++; + } + control_buffer = ggml_backend_alloc_ctx_tensors(control_ctx, backend); + } + auto get_graph = [&]() -> struct ggml_cgraph* { + return build_graph(x, hint, NULL, context, t_emb); + }; + GGMLModule::alloc_compute_buffer(get_graph); + } + + void compute(int n_threads, + struct ggml_tensor* x, + struct ggml_tensor* hint, + struct ggml_tensor* context, + struct ggml_tensor* t_emb = NULL) { + auto get_graph = [&]() -> struct ggml_cgraph* { + return build_graph(x, hint, NULL, context, t_emb); + }; + GGMLModule::compute(get_graph, n_threads, NULL); + } + + void free_compute_buffer() { + GGMLModule::free_compute_buffer(); + ggml_free(control_ctx); + ggml_backend_buffer_free(control_buffer); + control_buffer = NULL; + } +}; + +#endif // __CONTROL_HPP__ \ No newline at end of file diff --git a/examples/cli/main.cpp b/examples/cli/main.cpp index af2c337d..b8b4a46b 100644 --- a/examples/cli/main.cpp +++ b/examples/cli/main.cpp @@ -7,6 +7,7 @@ #include #include "stable-diffusion.h" +#include "preprocessing.hpp" #define STB_IMAGE_IMPLEMENTATION #include "stb_image.h" @@ -60,10 +61,13 @@ struct SDParams { std::string vae_path; std::string taesd_path; std::string esrgan_path; + std::string controlnet_path; + std::string embeddings_path; sd_type_t wtype = SD_TYPE_COUNT; std::string lora_model_dir; std::string output_path = "output.png"; std::string input_path; + std::string control_image_path; std::string prompt; std::string negative_prompt; @@ -77,24 +81,15 @@ struct SDParams { schedule_t schedule = DEFAULT; int sample_steps = 20; float strength = 0.75f; + float control_strength = 0.9f; rng_type_t rng_type = CUDA_RNG; int64_t seed = 42; bool verbose = false; bool vae_tiling = false; + bool control_net_cpu = false; + bool canny_preprocess = false; }; -static std::string sd_basename(const std::string& path) { - size_t pos = path.find_last_of('/'); - if (pos != std::string::npos) { - return path.substr(pos + 1); - } - pos = path.find_last_of('\\'); - if (pos != std::string::npos) { - return path.substr(pos + 1); - } - return path; -} - void print_params(SDParams params) { printf("Option: \n"); printf(" n_threads: %d\n", params.n_threads); @@ -104,8 +99,13 @@ void print_params(SDParams 
params) { printf(" vae_path: %s\n", params.vae_path.c_str()); printf(" taesd_path: %s\n", params.taesd_path.c_str()); printf(" esrgan_path: %s\n", params.esrgan_path.c_str()); + printf(" controlnet_path: %s\n", params.controlnet_path.c_str()); + printf(" embeddings_path: %s\n", params.embeddings_path.c_str()); printf(" output_path: %s\n", params.output_path.c_str()); printf(" init_img: %s\n", params.input_path.c_str()); + printf(" control_image: %s\n", params.control_image_path.c_str()); + printf(" controlnet cpu: %s\n", params.control_net_cpu ? "true" : "false"); + printf(" strength(control): %.2f\n", params.control_strength); printf(" prompt: %s\n", params.prompt.c_str()); printf(" negative_prompt: %s\n", params.negative_prompt.c_str()); printf(" cfg_scale: %.2f\n", params.cfg_scale); @@ -133,16 +133,20 @@ void print_usage(int argc, const char* argv[]) { printf(" -m, --model [MODEL] path to model\n"); printf(" --vae [VAE] path to vae\n"); printf(" --taesd [TAESD_PATH] path to taesd. Using Tiny AutoEncoder for fast decoding (low quality)\n"); + printf(" --control-net [CONTROL_PATH] path to control net model\n"); + printf(" --embd-dir [EMBEDDING_PATH] path to embeddings.\n"); printf(" --upscale-model [ESRGAN_PATH] path to esrgan model. Upscale images after generate, just RealESRGAN_x4plus_anime_6B supported by now.\n"); printf(" --type [TYPE] weight type (f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0)\n"); printf(" If not specified, the default is the type of the weight file.\n"); printf(" --lora-model-dir [DIR] lora model directory\n"); printf(" -i, --init-img [IMAGE] path to the input image, required by img2img\n"); + printf(" --control-image [IMAGE] path to image condition, control net\n"); printf(" -o, --output OUTPUT path to write result image to (default: ./output.png)\n"); printf(" -p, --prompt [PROMPT] the prompt to render\n"); printf(" -n, --negative-prompt PROMPT the negative prompt (default: \"\")\n"); printf(" --cfg-scale SCALE unconditional guidance scale: (default: 7.0)\n"); printf(" --strength STRENGTH strength for noising/unnoising (default: 0.75)\n"); + printf(" --control-strength STRENGTH strength to apply Control Net (default: 0.9)\n"); printf(" 1.0 corresponds to full destruction of information in init image\n"); printf(" -H, --height H image height, in pixel space (default: 512)\n"); printf(" -W, --width W image width, in pixel space (default: 512)\n"); @@ -156,6 +160,8 @@ void print_usage(int argc, const char* argv[]) { printf(" --clip-skip N ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer (default: -1)\n"); printf(" <= 0 represents unspecified, will be 1 for SD1.x, 2 for SD2.x\n"); printf(" --vae-tiling process vae in tiles to reduce memory usage\n"); + printf(" --control-net-cpu keep controlnet in cpu (for low vram)\n"); + printf(" --canny apply canny preprocessor (edge detection)\n"); printf(" -v, --verbose print extra info\n"); } @@ -207,13 +213,25 @@ void parse_args(int argc, const char** argv, SDParams& params) { break; } params.taesd_path = argv[i]; + } else if (arg == "--control-net") { + if (++i >= argc) { + invalid_arg = true; + break; + } + params.controlnet_path = argv[i]; } else if (arg == "--upscale-model") { if (++i >= argc) { invalid_arg = true; break; } params.esrgan_path = argv[i]; - } else if (arg == "--type") { + } else if (arg == "--embd-dir") { + if (++i >= argc) { + invalid_arg = true; + break; + } + params.embeddings_path = argv[i]; + } else if (arg == "--type") { if (++i >= argc) { invalid_arg = true; break; @@ -250,6 +268,12 
@@ void parse_args(int argc, const char** argv, SDParams& params) { break; } params.input_path = argv[i]; + } else if (arg == "--control-image") { + if (++i >= argc) { + invalid_arg = true; + break; + } + params.control_image_path = argv[i]; } else if (arg == "-o" || arg == "--output") { if (++i >= argc) { invalid_arg = true; @@ -280,6 +304,12 @@ void parse_args(int argc, const char** argv, SDParams& params) { break; } params.strength = std::stof(argv[i]); + } else if (arg == "--control-strength") { + if (++i >= argc) { + invalid_arg = true; + break; + } + params.control_strength = std::stof(argv[i]); } else if (arg == "-H" || arg == "--height") { if (++i >= argc) { invalid_arg = true; @@ -306,6 +336,10 @@ void parse_args(int argc, const char** argv, SDParams& params) { params.clip_skip = std::stoi(argv[i]); } else if (arg == "--vae-tiling") { params.vae_tiling = true; + } else if (arg == "--control-net-cpu") { + params.control_net_cpu = true; + } else if (arg == "--canny") { + params.canny_preprocess = true; } else if (arg == "-b" || arg == "--batch-count") { if (++i >= argc) { invalid_arg = true; @@ -536,14 +570,17 @@ int main(int argc, const char* argv[]) { sd_ctx_t* sd_ctx = new_sd_ctx(params.model_path.c_str(), params.vae_path.c_str(), params.taesd_path.c_str(), + params.controlnet_path.c_str(), params.lora_model_dir.c_str(), + params.embeddings_path.c_str(), vae_decode_only, params.vae_tiling, true, params.n_threads, params.wtype, params.rng_type, - params.schedule); + params.schedule, + params.control_net_cpu); if (sd_ctx == NULL) { printf("new_sd_ctx_t failed\n"); @@ -552,6 +589,23 @@ int main(int argc, const char* argv[]) { sd_image_t* results; if (params.mode == TXT2IMG) { + sd_image_t* control_image = NULL; + if(params.controlnet_path.size() > 0 && params.control_image_path.size() > 0) { + int c = 0; + input_image_buffer = stbi_load(params.control_image_path.c_str(), ¶ms.width, ¶ms.height, &c, 3); + if(input_image_buffer == NULL) { + fprintf(stderr, "load image from '%s' failed\n", params.control_image_path.c_str()); + return 1; + } + control_image = new sd_image_t{(uint32_t)params.width, + (uint32_t)params.height, + 3, + input_image_buffer}; + if(params.canny_preprocess) { // apply preprocessor + LOG_INFO("Applying canny preprocessor"); + control_image->data = preprocess_canny(control_image->data, control_image->width, control_image->height); + } + } results = txt2img(sd_ctx, params.prompt.c_str(), params.negative_prompt.c_str(), @@ -562,7 +616,9 @@ int main(int argc, const char* argv[]) { params.sample_method, params.sample_steps, params.seed, - params.batch_count); + params.batch_count, + control_image, + params.control_strength); } else { sd_image_t input_image = {(uint32_t)params.width, (uint32_t)params.height, diff --git a/ggml_extend.hpp b/ggml_extend.hpp index b48c949e..60ab430c 100644 --- a/ggml_extend.hpp +++ b/ggml_extend.hpp @@ -219,7 +219,7 @@ __STATIC_INLINE__ void sd_image_to_tensor(const uint8_t* image_data, for (int iy = 0; iy < height; iy++) { for (int ix = 0; ix < width; ix++) { for (int k = 0; k < channels; k++) { - float value = *(image_data + iy * width * channels + ix * channels + k); + int value = *(image_data + iy * width * channels + ix * channels + k); ggml_tensor_set_f32(output, value / 255.0f, ix, iy, k); } } @@ -462,8 +462,12 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_group_norm(struct ggml_context* ct __STATIC_INLINE__ void ggml_backend_tensor_get_and_sync(ggml_backend_t backend, const struct ggml_tensor* tensor, void* data, size_t offset, 
size_t size) { #ifdef SD_USE_CUBLAS - ggml_backend_tensor_get_async(backend, tensor, data, offset, size); - ggml_backend_synchronize(backend); + if(!ggml_backend_is_cpu(backend)) { + ggml_backend_tensor_get_async(backend, tensor, data, offset, size); + ggml_backend_synchronize(backend); + } else { + ggml_backend_tensor_get(tensor, data, offset, size); + } #else ggml_backend_tensor_get(tensor, data, offset, size); #endif @@ -544,7 +548,7 @@ struct GGMLModule { bool alloc_params_buffer(ggml_backend_t backend_, ggml_type wtype_ = GGML_TYPE_F32) { backend = backend_; wtype = wtype_; - params_buffer_size = 10 * 1024 * 1024; // 10 MB, for padding + params_buffer_size = 4 * 1024 * 1024; // 10 MB, for padding params_buffer_size += calculate_mem_size(); size_t num_tensors = get_num_tensors(); diff --git a/model.cpp b/model.cpp index 387a9cf5..ab3463d7 100644 --- a/model.cpp +++ b/model.cpp @@ -89,7 +89,6 @@ const char* unused_tensors[] = { "model_ema.decay", "model_ema.num_updates", "model_ema.diffusion_model", - "control_model", "embedding_manager", "denoiser.sigmas", }; @@ -376,6 +375,11 @@ std::string convert_tensor_name(const std::string& name) { new_name = convert_open_clip_to_hf_clip(name); } else if (starts_with(name, "first_stage_model.decoder")) { new_name = convert_vae_decoder_name(name); + } else if (starts_with(name, "control_model.")) { // for controlnet pth models + size_t pos = name.find('.'); + if (pos != std::string::npos) { + new_name = name.substr(pos + 1); + } } else if (starts_with(name, "lora_")) { // for lora size_t pos = name.find('.'); if (pos != std::string::npos) { @@ -1329,11 +1333,7 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend size_t nbytes_to_read = tensor_storage.nbytes_to_read(); - if (dst_tensor->buffer == NULL || ggml_backend_is_cpu(backend) -#ifdef SD_USE_METAL - || ggml_backend_is_metal(backend) -#endif - ) { + if (dst_tensor->buffer == NULL || ggml_backend_buffer_is_host(dst_tensor->buffer)) { // for the CPU and Metal backend, we can copy directly into the tensor if (tensor_storage.type == dst_tensor->type) { GGML_ASSERT(ggml_nbytes(dst_tensor) == tensor_storage.nbytes()); diff --git a/model.h b/model.h index 4b692a30..13665a7e 100644 --- a/model.h +++ b/model.h @@ -93,7 +93,6 @@ struct TensorStorage { }; typedef std::function on_new_tensor_cb_t; -typedef std::function on_new_token_cb_t; class ModelLoader { protected: diff --git a/preprocessing.hpp b/preprocessing.hpp new file mode 100644 index 00000000..d5bbd564 --- /dev/null +++ b/preprocessing.hpp @@ -0,0 +1,229 @@ +#ifndef __PREPROCESSING_HPP__ +#define __PREPROCESSING_HPP__ + +#include "ggml_extend.hpp" +#define M_PI_ 3.14159265358979323846 + +void convolve(struct ggml_tensor* input, struct ggml_tensor* output, struct ggml_tensor* kernel, int padding) { + struct ggml_init_params params; + params.mem_size = 20 * 1024 * 1024; // 10 + params.mem_buffer = NULL; + params.no_alloc = false; + struct ggml_context* ctx0 = ggml_init(params); + struct ggml_tensor* kernel_fp16 = ggml_new_tensor_4d(ctx0, GGML_TYPE_F16, kernel->ne[0], kernel->ne[1], 1, 1); + ggml_fp32_to_fp16_row((float*)kernel->data, (ggml_fp16_t*) kernel_fp16->data, ggml_nelements(kernel)); + ggml_tensor* h = ggml_conv_2d(ctx0, kernel_fp16, input, 1, 1, padding, padding, 1, 1); + ggml_cgraph* gf = ggml_new_graph(ctx0); + ggml_build_forward_expand(gf, ggml_cpy(ctx0, h, output)); + ggml_graph_compute_with_ctx(ctx0, gf, 1); + ggml_free(ctx0); +} + +void gaussian_kernel(struct ggml_tensor* kernel) { + int ks_mid = 
kernel->ne[0] / 2; + float sigma = 1.4f; + float normal = 1.f / (2.0f * M_PI_ * powf(sigma, 2.0f)); + for(int y = 0; y < kernel->ne[0]; y++) { + float gx = -ks_mid + y; + for(int x = 0; x < kernel->ne[1]; x++) { + float gy = -ks_mid + x; + float k_ = expf(-((gx*gx + gy*gy) / (2.0f * powf(sigma, 2.0f)))) * normal; + ggml_tensor_set_f32(kernel, k_, x, y); + } + } +} + +void grayscale(struct ggml_tensor* rgb_img, struct ggml_tensor* grayscale) { + for (int iy = 0; iy < rgb_img->ne[1]; iy++) { + for (int ix = 0; ix < rgb_img->ne[0]; ix++) { + float r = ggml_tensor_get_f32(rgb_img, ix, iy); + float g = ggml_tensor_get_f32(rgb_img, ix, iy, 1); + float b = ggml_tensor_get_f32(rgb_img, ix, iy, 2); + float gray = 0.2989f * r + 0.5870f * g + 0.1140f * b; + ggml_tensor_set_f32(grayscale, gray, ix, iy); + } + } +} + +void prop_hypot(struct ggml_tensor* x, struct ggml_tensor* y, struct ggml_tensor* h) { + int n_elements = ggml_nelements(h); + float* dx = (float*)x->data; + float* dy = (float*)y->data; + float* dh = (float*)h->data; + for (int i = 0; i < n_elements; i++) { + dh[i] = sqrtf(dx[i] * dx[i] + dy[i] * dy[i]); + } +} + +void prop_arctan2(struct ggml_tensor* x, struct ggml_tensor* y, struct ggml_tensor* h) { + int n_elements = ggml_nelements(h); + float* dx = (float*)x->data; + float* dy = (float*)y->data; + float* dh = (float*)h->data; + for (int i = 0; i < n_elements; i++) { + dh[i] = atan2f(dy[i], dx[i]); + } +} + +void normalize_tensor(struct ggml_tensor* g) { + int n_elements = ggml_nelements(g); + float* dg = (float*)g->data; + float max = -INFINITY; + for (int i = 0; i < n_elements; i++) { + max = dg[i] > max ? dg[i] : max; + } + max = 1.0f / max; + for (int i = 0; i < n_elements; i++) { + dg[i] *= max; + } +} + +void non_max_supression(struct ggml_tensor* result, struct ggml_tensor* G, struct ggml_tensor* D) { + for (int iy = 1; iy < result->ne[1] - 1; iy++) { + for (int ix = 1; ix < result->ne[0] - 1; ix++) { + float angle = ggml_tensor_get_f32(D, ix, iy) * 180.0f / M_PI_; + angle = angle < 0.0f ? angle += 180.0f : angle; + float q = 1.0f; + float r = 1.0f; + + // angle 0 + if((0 <= angle && angle < 22.5f) || (157.5f <= angle && angle <= 180)){ + q = ggml_tensor_get_f32(G, ix, iy + 1); + r = ggml_tensor_get_f32(G, ix, iy - 1); + } + // angle 45 + else if (22.5f <= angle && angle < 67.5f) { + q = ggml_tensor_get_f32(G, ix + 1, iy - 1); + r = ggml_tensor_get_f32(G, ix - 1, iy + 1); + } + // angle 90 + else if (67.5f <= angle && angle < 112.5) { + q = ggml_tensor_get_f32(G, ix + 1, iy); + r = ggml_tensor_get_f32(G, ix - 1, iy); + } + // angle 135 + else if (112.5 <= angle && angle < 157.5f) { + q = ggml_tensor_get_f32(G, ix - 1, iy - 1); + r = ggml_tensor_get_f32(G, ix + 1, iy + 1); + } + + float cur = ggml_tensor_get_f32(G, ix, iy); + if ((cur >= q) && (cur >= r)) { + ggml_tensor_set_f32(result, cur, ix, iy); + } else { + ggml_tensor_set_f32(result, 0.0f, ix, iy); + } + } + } +} + +void threshold_hystersis(struct ggml_tensor* img, float highThreshold, float lowThreshold, float weak, float strong) { + int n_elements = ggml_nelements(img); + float* imd = (float*)img->data; + float max = -INFINITY; + for (int i = 0; i < n_elements; i++) { + max = imd[i] > max ?
imd[i] : max; + } + float ht = max * highThreshold; + float lt = ht * lowThreshold; + for (int i = 0; i < n_elements; i++) { + float img_v = imd[i]; + if(img_v >= ht) { // strong pixel + imd[i] = strong; + } else if(img_v <= ht && img_v >= lt) { // strong pixel + imd[i] = weak; + } + } + + for (int iy = 0; iy < img->ne[1]; iy++) { + for (int ix = 0; ix < img->ne[0]; ix++) { + if(ix >= 3 && ix <= img->ne[0] - 3 && iy >= 3 && iy <= img->ne[1] - 3) { + ggml_tensor_set_f32(img, ggml_tensor_get_f32(img, ix, iy), ix, iy); + } else { + ggml_tensor_set_f32(img, 0.0f, ix, iy); + } + } + } + + // hysteresis + for (int iy = 1; iy < img->ne[1] - 1; iy++) { + for (int ix = 1; ix < img->ne[0] - 1; ix++) { + float imd_v = ggml_tensor_get_f32(img, ix, iy); + if(imd_v == weak) { + if(ggml_tensor_get_f32(img, ix + 1, iy - 1) == strong || ggml_tensor_get_f32(img, ix + 1, iy) == strong || + ggml_tensor_get_f32(img, ix, iy - 1) == strong || ggml_tensor_get_f32(img, ix, iy + 1) == strong || + ggml_tensor_get_f32(img, ix - 1, iy - 1) == strong || ggml_tensor_get_f32(img, ix - 1, iy) == strong) { + ggml_tensor_set_f32(img, strong, ix, iy); + } else { + ggml_tensor_set_f32(img, 0.0f, ix, iy); + } + } + } + } +} + +uint8_t* preprocess_canny(uint8_t* img, int width, int height, float highThreshold = 0.08f, float lowThreshold = 0.08f, float weak = 0.8f, float strong = 1.0f, bool inverse = false) { + struct ggml_init_params params; + params.mem_size = static_cast(10 * 1024 * 1024); // 10 + params.mem_buffer = NULL; + params.no_alloc = false; + struct ggml_context* work_ctx = ggml_init(params); + + if (!work_ctx) { + LOG_ERROR("ggml_init() failed"); + return NULL; + } + + float kX[9] = { + -1, 0, 1, + -2, 0, 2, + -1, 0, 1 + }; + + float kY[9] = { + 1, 2, 1, + 0, 0, 0, + -1, -2, -1 + }; + + // generate kernel + int kernel_size = 5; + struct ggml_tensor* gkernel = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, kernel_size, kernel_size, 1, 1); + struct ggml_tensor* sf_kx = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, 3, 3, 1, 1); + memcpy(sf_kx->data, kX, ggml_nbytes(sf_kx)); + struct ggml_tensor* sf_ky = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, 3, 3, 1, 1); + memcpy(sf_ky->data, kY, ggml_nbytes(sf_ky)); + gaussian_kernel(gkernel); + struct ggml_tensor* image = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 3, 1); + struct ggml_tensor* image_gray = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 1, 1); + struct ggml_tensor* iX = ggml_dup_tensor(work_ctx, image_gray); + struct ggml_tensor* iY = ggml_dup_tensor(work_ctx, image_gray); + struct ggml_tensor* G = ggml_dup_tensor(work_ctx, image_gray); + struct ggml_tensor* tetha = ggml_dup_tensor(work_ctx, image_gray); + sd_image_to_tensor(img, image); + grayscale(image, image_gray); + convolve(image_gray, image_gray, gkernel, 2); + convolve(image_gray, iX, sf_kx, 1); + convolve(image_gray, iY, sf_ky, 1); + prop_hypot(iX, iY, G); + normalize_tensor(G); + prop_arctan2(iX, iY, tetha); + non_max_supression(image_gray, G, tetha); + threshold_hystersis(image_gray, highThreshold, lowThreshold, weak, strong); + // to RGB channels + for (int iy = 0; iy < height; iy++) { + for (int ix = 0; ix < width; ix++) { + float gray = ggml_tensor_get_f32(image_gray, ix, iy); + gray = inverse ? 
1.0f - gray : gray; + ggml_tensor_set_f32(image, gray, ix, iy); + ggml_tensor_set_f32(image, gray, ix, iy, 1); + ggml_tensor_set_f32(image, gray, ix, iy, 2); + } + } + free(img); + uint8_t* output = sd_tensor_to_image(image); + ggml_free(work_ctx); + return output; +} + +#endif // __PREPROCESSING_HPP__ \ No newline at end of file diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index 10e24585..ee67dd43 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -8,6 +8,7 @@ #include "clip.hpp" #include "denoiser.hpp" +#include "control.hpp" #include "esrgan.hpp" #include "lora.hpp" #include "tae.hpp" @@ -80,6 +81,8 @@ class StableDiffusionGGML { TinyAutoEncoder tae_first_stage; std::string taesd_path; + ControlNet control_net; + StableDiffusionGGML() = default; StableDiffusionGGML(int n_threads, @@ -106,13 +109,14 @@ class StableDiffusionGGML { bool load_from_file(const std::string& model_path, const std::string& vae_path, + const std::string control_net_path, + const std::string embeddings_path, const std::string& taesd_path, - bool vae_tiling, + bool vae_tiling_, ggml_type wtype, - schedule_t schedule) { - this->use_tiny_autoencoder = taesd_path.size() > 0; - this->taesd_path = taesd_path; - this->vae_tiling = vae_tiling; + schedule_t schedule, + bool control_net_cpu) { + use_tiny_autoencoder = taesd_path.size() > 0; #ifdef SD_USE_CUBLAS LOG_DEBUG("Using CUDA backend"); backend = ggml_backend_cuda_init(0); @@ -137,6 +141,8 @@ class StableDiffusionGGML { LOG_INFO("loading model from '%s'", model_path.c_str()); ModelLoader model_loader; + vae_tiling = vae_tiling_; + if (!model_loader.init_from_file(model_path)) { LOG_ERROR("init model loader from file failed: '%s'", model_path.c_str()); return false; @@ -186,6 +192,8 @@ class StableDiffusionGGML { return false; } + cond_stage_model.text_model.embd_dir = embeddings_path; + ggml_type vae_type = model_data_type; if (version == VERSION_XL) { vae_type = GGML_TYPE_F32; // avoid nan, not work... 
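The preprocessing helper added above is self-contained, so the Canny pass can also be exercised outside the CLI. Below is a minimal sketch, assuming it is built inside this repo with the same stb_image / stb_image_write headers used by examples/cli; the translation unit and the file names are hypothetical, and only preprocess_canny() and its default thresholds come from this patch.

    // canny_demo.cpp -- standalone driver for the new Canny preprocessor (sketch, hypothetical file)
    #define STB_IMAGE_IMPLEMENTATION
    #include "stb_image.h"
    #define STB_IMAGE_WRITE_IMPLEMENTATION
    #include "stb_image_write.h"
    #include "preprocessing.hpp"

    int main() {
        int width = 0, height = 0, c = 0;
        // force 3 channels, matching what examples/cli passes to preprocess_canny
        uint8_t* img = stbi_load("control_input.png", &width, &height, &c, 3);  // placeholder file name
        if (img == NULL) {
            return 1;
        }
        // default thresholds mirror the CLI --canny path; the input buffer is freed inside preprocess_canny
        uint8_t* edges = preprocess_canny(img, width, height);
        stbi_write_png("control_canny.png", width, height, 3, edges, width * 3);  // placeholder file name
        free(edges);
        return 0;
    }

The resulting edge map is the same image that the CLI hands to txt2img as the ControlNet condition when --canny is given.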
@@ -308,8 +316,23 @@ class StableDiffusionGGML { denoiser->schedule->sigmas[i] = std::sqrt((1 - denoiser->schedule->alphas_cumprod[i]) / denoiser->schedule->alphas_cumprod[i]); denoiser->schedule->log_sigmas[i] = std::log(denoiser->schedule->sigmas[i]); } + LOG_DEBUG("finished loaded file"); ggml_free(ctx); + + if(control_net_path.size() > 0) { + ggml_backend_t cn_backend = NULL; + if(control_net_cpu && !ggml_backend_is_cpu(backend)) { + LOG_DEBUG("ControlNet: Using CPU backend"); + cn_backend = ggml_backend_cpu_init(); + } else { + cn_backend = backend; + } + if(!control_net.load_from_file(control_net_path, cn_backend, GGML_TYPE_F16 /* just f16 controlnet models */)) { + return false; + } + } + if (use_tiny_autoencoder) { return tae_first_stage.load_from_file(taesd_path, backend); } @@ -329,8 +352,9 @@ class StableDiffusionGGML { ggml_set_f32(timesteps, 999); set_timestep_embedding(timesteps, t_emb, diffusion_model.model_channels); struct ggml_tensor* out = ggml_dup_tensor(work_ctx, x_t); - diffusion_model.alloc_compute_buffer(x_t, c, t_emb); - diffusion_model.compute(out, n_threads, x_t, NULL, c, t_emb); + std::vector controls; + diffusion_model.alloc_compute_buffer(x_t, c, controls, t_emb); + diffusion_model.compute(out, n_threads, x_t, NULL, c, controls, 1.0f, t_emb); diffusion_model.free_compute_buffer(); double result = 0.f; @@ -511,9 +535,11 @@ class StableDiffusionGGML { ggml_tensor* c_vector, ggml_tensor* uc, ggml_tensor* uc_vector, + ggml_tensor* control_hint, float cfg_scale, sample_method_t method, - const std::vector& sigmas) { + const std::vector& sigmas, + float control_strength) { size_t steps = sigmas.size() - 1; // x_t = load_tensor_from_file(work_ctx, "./rand0.bin"); // print_ggml_tensor(x_t); @@ -523,7 +549,14 @@ class StableDiffusionGGML { struct ggml_tensor* noised_input = ggml_dup_tensor(work_ctx, x_t); struct ggml_tensor* timesteps = ggml_new_tensor_1d(work_ctx, GGML_TYPE_F32, 1); // [N, ] struct ggml_tensor* t_emb = new_timestep_embedding(work_ctx, NULL, timesteps, diffusion_model.model_channels); // [N, model_channels] - diffusion_model.alloc_compute_buffer(noised_input, c, t_emb, c_vector); + struct ggml_tensor* guided_hint = NULL; + if(control_hint != NULL) { + guided_hint = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, noised_input->ne[0], noised_input->ne[1], diffusion_model.model_channels, 1); + control_net.process_hint(guided_hint, n_threads, control_hint); + control_net.alloc_compute_buffer(noised_input, guided_hint, c, t_emb); + } + + diffusion_model.alloc_compute_buffer(noised_input, c, control_net.controls, t_emb, c_vector); bool has_unconditioned = cfg_scale != 1.0 && uc != NULL; @@ -573,12 +606,19 @@ class StableDiffusionGGML { ggml_tensor_scale(noised_input, c_in); // cond - diffusion_model.compute(out_cond, n_threads, noised_input, NULL, c, t_emb, c_vector); + if(control_hint != NULL) { + control_net.compute(n_threads, noised_input, guided_hint, c, t_emb); + } + diffusion_model.compute(out_cond, n_threads, noised_input, NULL, c, control_net.controls, control_strength, t_emb, c_vector); float* negative_data = NULL; if (has_unconditioned) { // uncond - diffusion_model.compute(out_uncond, n_threads, noised_input, NULL, uc, t_emb, uc_vector); + if(control_hint != NULL) { + control_net.compute(n_threads, noised_input, guided_hint, uc, t_emb); + } + + diffusion_model.compute(out_uncond, n_threads, noised_input, NULL, uc, control_net.controls, control_strength, t_emb, uc_vector); negative_data = (float*)out_uncond->data; } float* vec_denoised = 
(float*)denoised->data; @@ -987,6 +1027,7 @@ class StableDiffusionGGML { LOG_ERROR("Attempting to sample with nonexisting sample method %i", method); abort(); } + control_net.free_compute_buffer(); diffusion_model.free_compute_buffer(); return x; } @@ -1098,14 +1139,17 @@ struct sd_ctx_t { sd_ctx_t* new_sd_ctx(const char* model_path_c_str, const char* vae_path_c_str, const char* taesd_path_c_str, + const char* control_net_path_c_str, const char* lora_model_dir_c_str, + const char* embed_dir_c_str, bool vae_decode_only, bool vae_tiling, bool free_params_immediately, int n_threads, enum sd_type_t wtype, enum rng_type_t rng_type, - enum schedule_t s) { + enum schedule_t s, + bool keep_control_net_cpu) { sd_ctx_t* sd_ctx = (sd_ctx_t*)malloc(sizeof(sd_ctx_t)); if (sd_ctx == NULL) { return NULL; @@ -1113,6 +1157,8 @@ sd_ctx_t* new_sd_ctx(const char* model_path_c_str, std::string model_path(model_path_c_str); std::string vae_path(vae_path_c_str); std::string taesd_path(taesd_path_c_str); + std::string control_net_path(control_net_path_c_str); + std::string embd_path(embed_dir_c_str); std::string lora_model_dir(lora_model_dir_c_str); sd_ctx->sd = new StableDiffusionGGML(n_threads, @@ -1126,10 +1172,13 @@ sd_ctx_t* new_sd_ctx(const char* model_path_c_str, if (!sd_ctx->sd->load_from_file(model_path, vae_path, + control_net_path, + embd_path, taesd_path, vae_tiling, (ggml_type)wtype, - s)) { + s, + keep_control_net_cpu)) { delete sd_ctx->sd; sd_ctx->sd = NULL; free(sd_ctx); @@ -1156,7 +1205,9 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx, enum sample_method_t sample_method, int sample_steps, int64_t seed, - int batch_count) { + int batch_count, + const sd_image_t* control_cond, + float control_strength) { LOG_DEBUG("txt2img %dx%d", width, height); if (sd_ctx == NULL) { return NULL; @@ -1224,6 +1275,12 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx, sd_ctx->sd->cond_stage_model.free_params_buffer(); } + struct ggml_tensor* image_hint = NULL; + if(control_cond != NULL) { + image_hint = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 3, 1); + sd_image_to_tensor(control_cond->data, image_hint); + } + std::vector final_latents; // collect latents to decode int C = 4; int W = width / 8; @@ -1240,7 +1297,7 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx, std::vector sigmas = sd_ctx->sd->denoiser->schedule->get_sigmas(sample_steps); - struct ggml_tensor* x_0 = sd_ctx->sd->sample(work_ctx, x_t, NULL, c, c_vector, uc, uc_vector, cfg_scale, sample_method, sigmas); + struct ggml_tensor* x_0 = sd_ctx->sd->sample(work_ctx, x_t, NULL, c, c_vector, uc, uc_vector, image_hint, cfg_scale, sample_method, sigmas, control_strength); // struct ggml_tensor* x_0 = load_tensor_from_file(ctx, "samples_ddim.bin"); // print_ggml_tensor(x_0); int64_t sampling_end = ggml_time_ms(); @@ -1393,8 +1450,8 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx, ggml_tensor_set_f32_randn(noise, sd_ctx->sd->rng); LOG_INFO("sampling using %s method", sampling_methods_str[sample_method]); - struct ggml_tensor* x_0 = sd_ctx->sd->sample(work_ctx, init_latent, noise, c, c_vector, uc, uc_vector, - cfg_scale, sample_method, sigma_sched); + struct ggml_tensor* x_0 = sd_ctx->sd->sample(work_ctx, init_latent, noise, c, c_vector, uc, + uc_vector, NULL, cfg_scale, sample_method, sigma_sched, 1.0f); // struct ggml_tensor *x_0 = load_tensor_from_file(ctx, "samples_ddim.bin"); // print_ggml_tensor(x_0); int64_t t3 = ggml_time_ms(); diff --git a/stable-diffusion.h b/stable-diffusion.h index a18ee4a3..c719a99b 100644 --- a/stable-diffusion.h +++ b/stable-diffusion.h @@ -105,14 
+105,17 @@ typedef struct sd_ctx_t sd_ctx_t; SD_API sd_ctx_t* new_sd_ctx(const char* model_path, const char* vae_path, const char* taesd_path, + const char* control_net_path_c_str, const char* lora_model_dir, + const char* embed_dir_c_str, bool vae_decode_only, bool vae_tiling, bool free_params_immediately, int n_threads, enum sd_type_t wtype, enum rng_type_t rng_type, - enum schedule_t s); + enum schedule_t s, + bool keep_control_net_cpu); SD_API void free_sd_ctx(sd_ctx_t* sd_ctx); @@ -126,7 +129,9 @@ SD_API sd_image_t* txt2img(sd_ctx_t* sd_ctx, enum sample_method_t sample_method, int sample_steps, int64_t seed, - int batch_count); + int batch_count, + const sd_image_t* control_cond, + float control_strength); SD_API sd_image_t* img2img(sd_ctx_t* sd_ctx, sd_image_t init_image, diff --git a/unet.hpp b/unet.hpp index 6b6e7439..2c9e7c92 100644 --- a/unet.hpp +++ b/unet.hpp @@ -3,468 +3,12 @@ #include "common.hpp" #include "ggml_extend.hpp" +#include "model.h" /*==================================================== UnetModel =====================================================*/ #define UNET_GRAPH_SIZE 10240 -struct ResBlock { - // network hparams - int channels; // model_channels * (1, 1, 1, 2, 2, 4, 4, 4) - int emb_channels; // time_embed_dim - int out_channels; // mult * model_channels - - // network params - // in_layers - struct ggml_tensor* in_layer_0_w; // [channels, ] - struct ggml_tensor* in_layer_0_b; // [channels, ] - // in_layer_1 is nn.SILU() - struct ggml_tensor* in_layer_2_w; // [out_channels, channels, 3, 3] - struct ggml_tensor* in_layer_2_b; // [out_channels, ] - - // emb_layers - // emb_layer_0 is nn.SILU() - struct ggml_tensor* emb_layer_1_w; // [out_channels, emb_channels] - struct ggml_tensor* emb_layer_1_b; // [out_channels, ] - - // out_layers - struct ggml_tensor* out_layer_0_w; // [out_channels, ] - struct ggml_tensor* out_layer_0_b; // [out_channels, ] - // out_layer_1 is nn.SILU() - // out_layer_2 is nn.Dropout(), p = 0 for inference - struct ggml_tensor* out_layer_3_w; // [out_channels, out_channels, 3, 3] - struct ggml_tensor* out_layer_3_b; // [out_channels, ] - - // skip connection, only if out_channels != channels - struct ggml_tensor* skip_w; // [out_channels, channels, 1, 1] - struct ggml_tensor* skip_b; // [out_channels, ] - - size_t calculate_mem_size(ggml_type wtype) { - double mem_size = 0; - mem_size += 2 * channels * ggml_type_sizef(GGML_TYPE_F32); // in_layer_0_w/b - mem_size += out_channels * channels * 3 * 3 * ggml_type_sizef(GGML_TYPE_F16); // in_layer_2_w - mem_size += 5 * out_channels * ggml_type_sizef(GGML_TYPE_F32); // in_layer_2_b/emb_layer_1_b/out_layer_0_w/out_layer_0_b/out_layer_3_b - mem_size += out_channels * emb_channels * ggml_type_sizef(wtype); // emb_layer_1_w - mem_size += out_channels * out_channels * 3 * 3 * ggml_type_sizef(GGML_TYPE_F16); // out_layer_3_w - - if (out_channels != channels) { - mem_size += out_channels * channels * 1 * 1 * ggml_type_sizef(GGML_TYPE_F16); // skip_w - mem_size += out_channels * ggml_type_sizef(GGML_TYPE_F32); // skip_b - } - return static_cast(mem_size); - } - - void init_params(struct ggml_context* ctx, ggml_type wtype) { - in_layer_0_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, channels); - in_layer_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, channels); - in_layer_2_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, channels, out_channels); - in_layer_2_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, out_channels); - - emb_layer_1_w = ggml_new_tensor_2d(ctx, wtype, emb_channels, out_channels); - 
emb_layer_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, out_channels); - - out_layer_0_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, out_channels); - out_layer_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, out_channels); - out_layer_3_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 3, 3, out_channels, out_channels); - out_layer_3_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, out_channels); - - if (out_channels != channels) { - skip_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 1, 1, channels, out_channels); - skip_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, out_channels); - } - } - - void map_by_name(std::map& tensors, const std::string prefix) { - tensors[prefix + "in_layers.0.weight"] = in_layer_0_w; - tensors[prefix + "in_layers.0.bias"] = in_layer_0_b; - tensors[prefix + "in_layers.2.weight"] = in_layer_2_w; - tensors[prefix + "in_layers.2.bias"] = in_layer_2_b; - - tensors[prefix + "emb_layers.1.weight"] = emb_layer_1_w; - tensors[prefix + "emb_layers.1.bias"] = emb_layer_1_b; - - tensors[prefix + "out_layers.0.weight"] = out_layer_0_w; - tensors[prefix + "out_layers.0.bias"] = out_layer_0_b; - tensors[prefix + "out_layers.3.weight"] = out_layer_3_w; - tensors[prefix + "out_layers.3.bias"] = out_layer_3_b; - - if (out_channels != channels) { - tensors[prefix + "skip_connection.weight"] = skip_w; - tensors[prefix + "skip_connection.bias"] = skip_b; - } - } - - struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, struct ggml_tensor* emb) { - // x: [N, channels, h, w] - // emb: [N, emb_channels] - - // in_layers - auto h = ggml_nn_group_norm(ctx, x, in_layer_0_w, in_layer_0_b); - h = ggml_silu_inplace(ctx, h); - h = ggml_nn_conv_2d(ctx, h, in_layer_2_w, in_layer_2_b, 1, 1, 1, 1); // [N, out_channels, h, w] - - // emb_layers - auto emb_out = ggml_silu(ctx, emb); - emb_out = ggml_nn_linear(ctx, emb_out, emb_layer_1_w, emb_layer_1_b); // [N, out_channels] - emb_out = ggml_reshape_4d(ctx, emb_out, 1, 1, emb_out->ne[0], emb_out->ne[1]); // [N, out_channels, 1, 1] - - // out_layers - h = ggml_add(ctx, h, emb_out); - h = ggml_nn_group_norm(ctx, h, out_layer_0_w, out_layer_0_b); - h = ggml_silu_inplace(ctx, h); - - // dropout, skip for inference - - h = ggml_nn_conv_2d(ctx, h, out_layer_3_w, out_layer_3_b, 1, 1, 1, 1); // [N, out_channels, h, w] - - // skip connection - if (out_channels != channels) { - x = ggml_nn_conv_2d(ctx, x, skip_w, skip_b); // [N, out_channels, h, w] - } - - h = ggml_add(ctx, h, x); - return h; // [N, out_channels, h, w] - } -}; - -struct SpatialTransformer { - int in_channels; // mult * model_channels - int n_head; // num_heads - int d_head; // in_channels // n_heads - int depth = 1; // 1 - int context_dim = 768; // hidden_size, 1024 for VERSION_2_x - - // group norm - struct ggml_tensor* norm_w; // [in_channels,] - struct ggml_tensor* norm_b; // [in_channels,] - - // proj_in - struct ggml_tensor* proj_in_w; // [in_channels, in_channels, 1, 1] - struct ggml_tensor* proj_in_b; // [in_channels,] - - // transformer - struct Transformer { - // layer norm 1 - struct ggml_tensor* norm1_w; // [in_channels, ] - struct ggml_tensor* norm1_b; // [in_channels, ] - - // attn1 - struct ggml_tensor* attn1_q_w; // [in_channels, in_channels] - struct ggml_tensor* attn1_k_w; // [in_channels, in_channels] - struct ggml_tensor* attn1_v_w; // [in_channels, in_channels] - - struct ggml_tensor* attn1_out_w; // [in_channels, in_channels] - struct ggml_tensor* attn1_out_b; // [in_channels, ] - - // layer norm 2 - struct ggml_tensor* norm2_w; // [in_channels, ] - struct ggml_tensor* 
norm2_b; // [in_channels, ] - - // attn2 - struct ggml_tensor* attn2_q_w; // [in_channels, in_channels] - struct ggml_tensor* attn2_k_w; // [in_channels, context_dim] - struct ggml_tensor* attn2_v_w; // [in_channels, context_dim] - - struct ggml_tensor* attn2_out_w; // [in_channels, in_channels] - struct ggml_tensor* attn2_out_b; // [in_channels, ] - - // layer norm 3 - struct ggml_tensor* norm3_w; // [in_channels, ] - struct ggml_tensor* norm3_b; // [in_channels, ] - - // ff - struct ggml_tensor* ff_0_proj_w; // [in_channels * 4 * 2, in_channels] - struct ggml_tensor* ff_0_proj_b; // [in_channels * 4 * 2] - - struct ggml_tensor* ff_2_w; // [in_channels, in_channels * 4] - struct ggml_tensor* ff_2_b; // [in_channels,] - }; - - std::vector transformers; - - // proj_out - struct ggml_tensor* proj_out_w; // [in_channels, in_channels, 1, 1] - struct ggml_tensor* proj_out_b; // [in_channels,] - - SpatialTransformer(int depth = 1) - : depth(depth) { - transformers.resize(depth); - } - - int get_num_tensors() { - return depth * 20 + 7; - } - - size_t calculate_mem_size(ggml_type wtype) { - double mem_size = 0; - mem_size += 2 * in_channels * ggml_type_sizef(GGML_TYPE_F32); // norm_w/norm_b - mem_size += 2 * in_channels * in_channels * 1 * 1 * ggml_type_sizef(GGML_TYPE_F16); // proj_in_w/proj_out_w - mem_size += 2 * in_channels * ggml_type_sizef(GGML_TYPE_F32); // proj_in_b/proj_out_b - - // transformer - for (auto& transformer : transformers) { - mem_size += 6 * in_channels * ggml_type_sizef(GGML_TYPE_F32); // norm1-3_w/b - mem_size += 6 * in_channels * in_channels * ggml_type_sizef(wtype); // attn1_q/k/v/out_w attn2_q/out_w - mem_size += 2 * in_channels * context_dim * ggml_type_sizef(wtype); // attn2_k/v_w - mem_size += in_channels * 4 * 2 * in_channels * ggml_type_sizef(wtype); // ff_0_proj_w - mem_size += in_channels * 4 * 2 * ggml_type_sizef(GGML_TYPE_F32); // ff_0_proj_b - mem_size += in_channels * 4 * in_channels * ggml_type_sizef(wtype); // ff_2_w - mem_size += in_channels * ggml_type_sizef(GGML_TYPE_F32); // ff_2_b - } - return static_cast(mem_size); - } - - void init_params(struct ggml_context* ctx, ggml_allocr* alloc, ggml_type wtype) { - norm_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, in_channels); - norm_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, in_channels); - proj_in_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 1, 1, in_channels, in_channels); - proj_in_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, in_channels); - - proj_out_w = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 1, 1, in_channels, in_channels); - proj_out_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, in_channels); - - // transformer - for (auto& transformer : transformers) { - transformer.norm1_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, in_channels); - transformer.norm1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, in_channels); - - transformer.attn1_q_w = ggml_new_tensor_2d(ctx, wtype, in_channels, in_channels); - transformer.attn1_k_w = ggml_new_tensor_2d(ctx, wtype, in_channels, in_channels); - transformer.attn1_v_w = ggml_new_tensor_2d(ctx, wtype, in_channels, in_channels); - - transformer.attn1_out_w = ggml_new_tensor_2d(ctx, wtype, in_channels, in_channels); - transformer.attn1_out_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, in_channels); - - transformer.norm2_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, in_channels); - transformer.norm2_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, in_channels); - - transformer.attn2_q_w = ggml_new_tensor_2d(ctx, wtype, in_channels, in_channels); - transformer.attn2_k_w = ggml_new_tensor_2d(ctx, wtype, 
context_dim, in_channels); - transformer.attn2_v_w = ggml_new_tensor_2d(ctx, wtype, context_dim, in_channels); - - transformer.attn2_out_w = ggml_new_tensor_2d(ctx, wtype, in_channels, in_channels); - transformer.attn2_out_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, in_channels); - - transformer.norm3_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, in_channels); - transformer.norm3_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, in_channels); - - transformer.ff_0_proj_w = ggml_new_tensor_2d(ctx, wtype, in_channels, in_channels * 4 * 2); - transformer.ff_0_proj_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, in_channels * 4 * 2); - - transformer.ff_2_w = ggml_new_tensor_2d(ctx, wtype, in_channels * 4, in_channels); - transformer.ff_2_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, in_channels); - } - } - - void map_by_name(std::map& tensors, const std::string prefix) { - tensors[prefix + "norm.weight"] = norm_w; - tensors[prefix + "norm.bias"] = norm_b; - tensors[prefix + "proj_in.weight"] = proj_in_w; - tensors[prefix + "proj_in.bias"] = proj_in_b; - - // transformer - for (int i = 0; i < transformers.size(); i++) { - auto& transformer = transformers[i]; - std::string transformer_prefix = prefix + "transformer_blocks." + std::to_string(i) + "."; - tensors[transformer_prefix + "attn1.to_q.weight"] = transformer.attn1_q_w; - tensors[transformer_prefix + "attn1.to_k.weight"] = transformer.attn1_k_w; - tensors[transformer_prefix + "attn1.to_v.weight"] = transformer.attn1_v_w; - - tensors[transformer_prefix + "attn1.to_out.0.weight"] = transformer.attn1_out_w; - tensors[transformer_prefix + "attn1.to_out.0.bias"] = transformer.attn1_out_b; - - tensors[transformer_prefix + "ff.net.0.proj.weight"] = transformer.ff_0_proj_w; - tensors[transformer_prefix + "ff.net.0.proj.bias"] = transformer.ff_0_proj_b; - tensors[transformer_prefix + "ff.net.2.weight"] = transformer.ff_2_w; - tensors[transformer_prefix + "ff.net.2.bias"] = transformer.ff_2_b; - - tensors[transformer_prefix + "attn2.to_q.weight"] = transformer.attn2_q_w; - tensors[transformer_prefix + "attn2.to_k.weight"] = transformer.attn2_k_w; - tensors[transformer_prefix + "attn2.to_v.weight"] = transformer.attn2_v_w; - - tensors[transformer_prefix + "attn2.to_out.0.weight"] = transformer.attn2_out_w; - tensors[transformer_prefix + "attn2.to_out.0.bias"] = transformer.attn2_out_b; - - tensors[transformer_prefix + "norm1.weight"] = transformer.norm1_w; - tensors[transformer_prefix + "norm1.bias"] = transformer.norm1_b; - tensors[transformer_prefix + "norm2.weight"] = transformer.norm2_w; - tensors[transformer_prefix + "norm2.bias"] = transformer.norm2_b; - tensors[transformer_prefix + "norm3.weight"] = transformer.norm3_w; - tensors[transformer_prefix + "norm3.bias"] = transformer.norm3_b; - } - - tensors[prefix + "proj_out.weight"] = proj_out_w; - tensors[prefix + "proj_out.bias"] = proj_out_b; - } - - struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, struct ggml_tensor* context) { - // x: [N, in_channels, h, w] - // context: [N, max_position, hidden_size(aka context_dim)] - auto x_in = x; - x = ggml_nn_group_norm(ctx, x, norm_w, norm_b); - // proj_in - x = ggml_nn_conv_2d(ctx, x, proj_in_w, proj_in_b); // [N, in_channels, h, w] - - // transformer - const int64_t n = x->ne[3]; - const int64_t c = x->ne[2]; - const int64_t h = x->ne[1]; - const int64_t w = x->ne[0]; - const int64_t max_position = context->ne[1]; - x = ggml_cont(ctx, ggml_permute(ctx, x, 1, 2, 0, 3)); // [N, h, w, in_channels] - - for (auto& transformer : transformers) { - 
auto r = x; - // layer norm 1 - x = ggml_reshape_2d(ctx, x, c, w * h * n); - x = ggml_nn_layer_norm(ctx, x, transformer.norm1_w, transformer.norm1_b); - - // self-attention - { - x = ggml_reshape_2d(ctx, x, c, h * w * n); // [N * h * w, in_channels] - struct ggml_tensor* q = ggml_mul_mat(ctx, transformer.attn1_q_w, x); // [N * h * w, in_channels] -#if !defined(SD_USE_FLASH_ATTENTION) || defined(SD_USE_CUBLAS) || defined(SD_USE_METAL) - q = ggml_scale_inplace(ctx, q, 1.0f / sqrt((float)d_head)); -#endif - q = ggml_reshape_4d(ctx, q, d_head, n_head, h * w, n); // [N, h * w, n_head, d_head] - q = ggml_cont(ctx, ggml_permute(ctx, q, 0, 2, 1, 3)); // [N, n_head, h * w, d_head] - q = ggml_reshape_3d(ctx, q, d_head, h * w, n_head * n); // [N * n_head, h * w, d_head] - - struct ggml_tensor* k = ggml_mul_mat(ctx, transformer.attn1_k_w, x); // [N * h * w, in_channels] - k = ggml_reshape_4d(ctx, k, d_head, n_head, h * w, n); // [N, h * w, n_head, d_head] - k = ggml_cont(ctx, ggml_permute(ctx, k, 0, 2, 1, 3)); // [N, n_head, h * w, d_head] - k = ggml_reshape_3d(ctx, k, d_head, h * w, n_head * n); // [N * n_head, h * w, d_head] - - struct ggml_tensor* v = ggml_mul_mat(ctx, transformer.attn1_v_w, x); // [N * h * w, in_channels] - v = ggml_reshape_4d(ctx, v, d_head, n_head, h * w, n); // [N, h * w, n_head, d_head] - v = ggml_cont(ctx, ggml_permute(ctx, v, 1, 2, 0, 3)); // [N, n_head, d_head, h * w] - v = ggml_reshape_3d(ctx, v, h * w, d_head, n_head * n); // [N * n_head, d_head, h * w] - -#if defined(SD_USE_FLASH_ATTENTION) && !defined(SD_USE_CUBLAS) && !defined(SD_USE_METAL) - struct ggml_tensor* kqv = ggml_flash_attn(ctx, q, k, v, false); // [N * n_head, h * w, d_head] -#else - struct ggml_tensor* kq = ggml_mul_mat(ctx, k, q); // [N * n_head, h * w, h * w] - // kq = ggml_diag_mask_inf_inplace(ctx, kq, 0); - kq = ggml_soft_max_inplace(ctx, kq); - - struct ggml_tensor* kqv = ggml_mul_mat(ctx, v, kq); // [N * n_head, h * w, d_head] -#endif - kqv = ggml_reshape_4d(ctx, kqv, d_head, h * w, n_head, n); - kqv = ggml_cont(ctx, ggml_permute(ctx, kqv, 0, 2, 1, 3)); // [N, h * w, n_head, d_head] - - // x = ggml_cpy(ctx, kqv, ggml_new_tensor_2d(ctx, GGML_TYPE_F32, d_head * n_head, h * w * n)); - x = ggml_reshape_2d(ctx, kqv, d_head * n_head, h * w * n); - - x = ggml_nn_linear(ctx, x, transformer.attn1_out_w, transformer.attn1_out_b); - - x = ggml_reshape_4d(ctx, x, c, w, h, n); - } - - x = ggml_add(ctx, x, r); - r = x; - - // layer norm 2 - x = ggml_nn_layer_norm(ctx, x, transformer.norm2_w, transformer.norm2_b); - - // cross-attention - { - x = ggml_reshape_2d(ctx, x, c, h * w * n); // [N * h * w, in_channels] - context = ggml_reshape_2d(ctx, context, context->ne[0], context->ne[1] * context->ne[2]); // [N * max_position, hidden_size] - struct ggml_tensor* q = ggml_mul_mat(ctx, transformer.attn2_q_w, x); // [N * h * w, in_channels] -#if !defined(SD_USE_FLASH_ATTENTION) || defined(SD_USE_CUBLAS) || defined(SD_USE_METAL) - q = ggml_scale_inplace(ctx, q, 1.0f / sqrt((float)d_head)); -#endif - q = ggml_reshape_4d(ctx, q, d_head, n_head, h * w, n); // [N, h * w, n_head, d_head] - q = ggml_cont(ctx, ggml_permute(ctx, q, 0, 2, 1, 3)); // [N, n_head, h * w, d_head] - q = ggml_reshape_3d(ctx, q, d_head, h * w, n_head * n); // [N * n_head, h * w, d_head] - - struct ggml_tensor* k = ggml_mul_mat(ctx, transformer.attn2_k_w, context); // [N * max_position, in_channels] - k = ggml_reshape_4d(ctx, k, d_head, n_head, max_position, n); // [N, max_position, n_head, d_head] - k = ggml_cont(ctx, ggml_permute(ctx, k, 0, 2, 1, 3)); // 
[N, n_head, max_position, d_head] - k = ggml_reshape_3d(ctx, k, d_head, max_position, n_head * n); // [N * n_head, max_position, d_head] - - struct ggml_tensor* v = ggml_mul_mat(ctx, transformer.attn2_v_w, context); // [N * max_position, in_channels] - v = ggml_reshape_4d(ctx, v, d_head, n_head, max_position, n); // [N, max_position, n_head, d_head] - v = ggml_cont(ctx, ggml_permute(ctx, v, 1, 2, 0, 3)); // [N, n_head, d_head, max_position] - v = ggml_reshape_3d(ctx, v, max_position, d_head, n_head * n); // [N * n_head, d_head, max_position] -#if defined(SD_USE_FLASH_ATTENTION) && !defined(SD_USE_CUBLAS) && !defined(SD_USE_METAL) - struct ggml_tensor* kqv = ggml_flash_attn(ctx, q, k, v, false); // [N * n_head, h * w, d_head] -#else - struct ggml_tensor* kq = ggml_mul_mat(ctx, k, q); // [N * n_head, h * w, max_position] - // kq = ggml_diag_mask_inf_inplace(ctx, kq, 0); - kq = ggml_soft_max_inplace(ctx, kq); - - struct ggml_tensor* kqv = ggml_mul_mat(ctx, v, kq); // [N * n_head, h * w, d_head] -#endif - kqv = ggml_reshape_4d(ctx, kqv, d_head, h * w, n_head, n); - kqv = ggml_cont(ctx, ggml_permute(ctx, kqv, 0, 2, 1, 3)); - - // x = ggml_cpy(ctx, kqv, ggml_new_tensor_2d(ctx, GGML_TYPE_F32, d_head * n_head, h * w * n)); // [N * h * w, in_channels] - x = ggml_reshape_2d(ctx, kqv, d_head * n_head, h * w * n); // [N * h * w, in_channels] - - x = ggml_nn_linear(ctx, x, transformer.attn2_out_w, transformer.attn2_out_b); - - x = ggml_reshape_4d(ctx, x, c, w, h, n); - } - - x = ggml_add(ctx, x, r); - r = x; - - // layer norm 3 - x = ggml_reshape_2d(ctx, x, c, h * w * n); // [N * h * w, in_channels] - x = ggml_nn_layer_norm(ctx, x, transformer.norm3_w, transformer.norm3_b); - - // ff - { - // GEGLU - auto x_w = ggml_view_2d(ctx, - transformer.ff_0_proj_w, - transformer.ff_0_proj_w->ne[0], - transformer.ff_0_proj_w->ne[1] / 2, - transformer.ff_0_proj_w->nb[1], - 0); // [in_channels * 4, in_channels] - auto x_b = ggml_view_1d(ctx, - transformer.ff_0_proj_b, - transformer.ff_0_proj_b->ne[0] / 2, - 0); // [in_channels * 4, in_channels] - auto gate_w = ggml_view_2d(ctx, - transformer.ff_0_proj_w, - transformer.ff_0_proj_w->ne[0], - transformer.ff_0_proj_w->ne[1] / 2, - transformer.ff_0_proj_w->nb[1], - transformer.ff_0_proj_w->nb[1] * transformer.ff_0_proj_w->ne[1] / 2); // [in_channels * 4, ] - auto gate_b = ggml_view_1d(ctx, - transformer.ff_0_proj_b, - transformer.ff_0_proj_b->ne[0] / 2, - transformer.ff_0_proj_b->nb[0] * transformer.ff_0_proj_b->ne[0] / 2); // [in_channels * 4, ] - x = ggml_reshape_2d(ctx, x, c, w * h * n); - auto x_in = x; - x = ggml_nn_linear(ctx, x_in, x_w, x_b); // [N * h * w, in_channels * 4] - auto gate = ggml_nn_linear(ctx, x_in, gate_w, gate_b); // [N * h * w, in_channels * 4] - - gate = ggml_gelu_inplace(ctx, gate); - - x = ggml_mul(ctx, x, gate); // [N * h * w, in_channels * 4] - // fc - x = ggml_nn_linear(ctx, x, transformer.ff_2_w, transformer.ff_2_b); // [N * h * w, in_channels] - } - - x = ggml_reshape_4d(ctx, x, c, w, h, n); // [N, h, w, in_channels] - - // residual - x = ggml_add(ctx, x, r); - } - - x = ggml_cont(ctx, ggml_permute(ctx, x, 2, 0, 1, 3)); // [N, in_channels, h, w] - - // proj_out - x = ggml_nn_conv_2d(ctx, x, proj_out_w, proj_out_b); // [N, in_channels, h, w] - - x = ggml_add(ctx, x, x_in); - return x; - } -}; - // ldm.modules.diffusionmodules.openaimodel.UNetModel struct UNetModel : public GGMLModule { SDVersion version = VERSION_1_x; @@ -636,21 +180,21 @@ struct UNetModel : public GGMLModule { } size_t calculate_mem_size() { - double mem_size = 0; - 
mem_size += time_embed_dim * model_channels * ggml_type_sizef(wtype); // time_embed_0_w - mem_size += time_embed_dim * ggml_type_sizef(GGML_TYPE_F32); // time_embed_0_b - mem_size += time_embed_dim * time_embed_dim * ggml_type_sizef(wtype); // time_embed_2_w - mem_size += time_embed_dim * ggml_type_sizef(GGML_TYPE_F32); // time_embed_2_b + size_t mem_size = 0; + mem_size += ggml_row_size(wtype, time_embed_dim * model_channels); // time_embed_0_w + mem_size += ggml_row_size(GGML_TYPE_F32, time_embed_dim); // time_embed_0_b + mem_size += ggml_row_size(wtype, time_embed_dim * time_embed_dim); // time_embed_2_w + mem_size += ggml_row_size(GGML_TYPE_F32, time_embed_dim); // time_embed_2_b if (version == VERSION_XL) { - mem_size += time_embed_dim * adm_in_channels * ggml_type_sizef(wtype); // label_embed_0_w - mem_size += time_embed_dim * ggml_type_sizef(GGML_TYPE_F32); // label_embed_0_b - mem_size += time_embed_dim * time_embed_dim * ggml_type_sizef(wtype); // label_embed_2_w - mem_size += time_embed_dim * ggml_type_sizef(GGML_TYPE_F32); // label_embed_2_b + mem_size += ggml_row_size(wtype, time_embed_dim * adm_in_channels); // label_embed_0_w + mem_size += ggml_row_size(GGML_TYPE_F32, time_embed_dim); // label_embed_0_b + mem_size += ggml_row_size(wtype, time_embed_dim * time_embed_dim); // label_embed_2_w + mem_size += ggml_row_size(GGML_TYPE_F32, time_embed_dim); // label_embed_2_b } - mem_size += model_channels * in_channels * 3 * 3 * ggml_type_sizef(GGML_TYPE_F16); // input_block_0_w - mem_size += model_channels * ggml_type_sizef(GGML_TYPE_F32); // input_block_0_b + mem_size += ggml_row_size(GGML_TYPE_F16, model_channels * in_channels * 3 * 3); // input_block_0_w + mem_size += ggml_row_size(GGML_TYPE_F32, model_channels); // input_block_0_b // input_blocks int ds = 1; @@ -691,11 +235,11 @@ struct UNetModel : public GGMLModule { } // out - mem_size += 2 * model_channels * ggml_type_sizef(GGML_TYPE_F32); // out_0_w/b - mem_size += out_channels * model_channels * 3 * 3 * ggml_type_sizef(GGML_TYPE_F16); // out_2_w - mem_size += out_channels * ggml_type_sizef(GGML_TYPE_F32); // out_2_b + mem_size += 2 * ggml_row_size(GGML_TYPE_F32, model_channels); // out_0_w/b + mem_size += ggml_row_size(GGML_TYPE_F16, out_channels * model_channels * 3 * 3); // out_2_w + mem_size += ggml_row_size(GGML_TYPE_F32, out_channels); // out_2_b - return static_cast(mem_size); + return mem_size; } size_t get_num_tensors() { @@ -892,6 +436,8 @@ struct UNetModel : public GGMLModule { struct ggml_tensor* x, struct ggml_tensor* timesteps, struct ggml_tensor* context, + std::vector control, + float control_net_strength, struct ggml_tensor* t_emb = NULL, struct ggml_tensor* y = NULL) { // x: [N, in_channels, h, w] @@ -949,12 +495,24 @@ struct UNetModel : public GGMLModule { h = middle_block_1.forward(ctx0, h, context); // [N, 4*model_channels, h/8, w/8] h = middle_block_2.forward(ctx0, h, emb); // [N, 4*model_channels, h/8, w/8] + if(control.size() > 0) { + auto cs = ggml_scale_inplace(ctx0, control[control.size() - 1], control_net_strength); + h = ggml_add(ctx0, h, cs); // middle control + } + + int control_offset = control.size() - 2; // output_blocks for (int i = (int)len_mults - 1; i >= 0; i--) { for (int j = 0; j < num_res_blocks + 1; j++) { auto h_skip = hs.back(); hs.pop_back(); + if(control.size() > 0) { + auto cs = ggml_scale_inplace(ctx0, control[control_offset], control_net_strength); + h_skip = ggml_add(ctx0, h_skip, cs); // control net condition + control_offset--; + } + h = ggml_concat(ctx0, h, h_skip); h = 
output_res_blocks[i][j].forward(ctx0, h, emb); @@ -983,8 +541,10 @@ struct UNetModel : public GGMLModule { struct ggml_cgraph* build_graph(struct ggml_tensor* x, struct ggml_tensor* timesteps, struct ggml_tensor* context, + std::vector control, struct ggml_tensor* t_emb = NULL, - struct ggml_tensor* y = NULL) { + struct ggml_tensor* y = NULL, + float control_net_strength = 1.0) { // since we are using ggml-alloc, this buffer only needs enough space to hold the ggml_tensor and ggml_cgraph structs, but not the tensor data static size_t buf_size = ggml_tensor_overhead() * UNET_GRAPH_SIZE + ggml_graph_overhead(); static std::vector buf(buf_size); @@ -1006,6 +566,7 @@ struct UNetModel : public GGMLModule { struct ggml_tensor* context_t = NULL; struct ggml_tensor* t_emb_t = NULL; struct ggml_tensor* y_t = NULL; + std::vector control_t; // it's performing a compute, check if backend isn't cpu if (!ggml_backend_is_cpu(backend)) { @@ -1049,7 +610,22 @@ struct UNetModel : public GGMLModule { y_t = y; } - struct ggml_tensor* out = forward(ctx0, x_t, timesteps_t, context_t, t_emb_t, y_t); + // offload all controls tensors to gpu + if(control.size() > 0 && !ggml_backend_is_cpu(backend) && control[0]->backend != GGML_BACKEND_GPU) { + for(int i = 0; i < control.size(); i++) { + ggml_tensor* cntl_t = ggml_dup_tensor(ctx0, control[i]); + control_t.push_back(cntl_t); + ggml_allocr_alloc(compute_allocr, cntl_t); + if(!ggml_allocr_is_measure(compute_allocr)) { + ggml_backend_tensor_copy(control[i], control_t[i]); + ggml_backend_synchronize(backend); + } + } + } else { + control_t = control; + } + + struct ggml_tensor* out = forward(ctx0, x_t, timesteps_t, context_t, control_t, control_net_strength, t_emb_t, y_t); ggml_build_forward_expand(gf, out); ggml_free(ctx0); @@ -1059,10 +635,12 @@ struct UNetModel : public GGMLModule { void alloc_compute_buffer(struct ggml_tensor* x, struct ggml_tensor* context, + std::vector control, struct ggml_tensor* t_emb = NULL, - struct ggml_tensor* y = NULL) { + struct ggml_tensor* y = NULL, + float control_net_strength = 1.0) { auto get_graph = [&]() -> struct ggml_cgraph* { - return build_graph(x, NULL, context, t_emb, y); + return build_graph(x, NULL, context, control, t_emb, y, control_net_strength); }; GGMLModule::alloc_compute_buffer(get_graph); } @@ -1072,10 +650,12 @@ struct UNetModel : public GGMLModule { struct ggml_tensor* x, struct ggml_tensor* timesteps, struct ggml_tensor* context, + std::vector control, + float control_net_strength, struct ggml_tensor* t_emb = NULL, struct ggml_tensor* y = NULL) { auto get_graph = [&]() -> struct ggml_cgraph* { - return build_graph(x, timesteps, context, t_emb, y); + return build_graph(x, timesteps, context, control, t_emb, y, control_net_strength); }; GGMLModule::compute(get_graph, n_threads, work_latent); diff --git a/util.cpp b/util.cpp index 4057d13d..4445f6c5 100644 --- a/util.cpp +++ b/util.cpp @@ -1,5 +1,5 @@ #include "util.h" - +#include #include #include #include @@ -72,6 +72,20 @@ bool is_directory(const std::string& path) { return (attributes != INVALID_FILE_ATTRIBUTES && (attributes & FILE_ATTRIBUTE_DIRECTORY)); } +std::string get_full_path(const std::string& dir, const std::string& filename) { + std::string full_path = dir + "\\" + filename; + + WIN32_FIND_DATA find_file_data; + HANDLE hFind = FindFirstFile(full_path.c_str(), &find_file_data); + + if (hFind != INVALID_HANDLE_VALUE) { + FindClose(hFind); + return full_path; + } else { + return ""; + } +} + #else // Unix #include #include @@ -86,6 +100,25 @@ bool 
diff --git a/util.cpp b/util.cpp
index 4057d13d..4445f6c5 100644
--- a/util.cpp
+++ b/util.cpp
@@ -1,5 +1,5 @@
 #include "util.h"
-
+#include 
 #include 
 #include 
 #include 
@@ -72,6 +72,20 @@ bool is_directory(const std::string& path) {
     return (attributes != INVALID_FILE_ATTRIBUTES && (attributes & FILE_ATTRIBUTE_DIRECTORY));
 }

+std::string get_full_path(const std::string& dir, const std::string& filename) {
+    std::string full_path = dir + "\\" + filename;
+
+    WIN32_FIND_DATA find_file_data;
+    HANDLE hFind = FindFirstFile(full_path.c_str(), &find_file_data);
+
+    if (hFind != INVALID_HANDLE_VALUE) {
+        FindClose(hFind);
+        return full_path;
+    } else {
+        return "";
+    }
+}
+
 #else  // Unix
 #include 
 #include 
@@ -86,6 +100,25 @@ bool is_directory(const std::string& path) {
     return (stat(path.c_str(), &buffer) == 0 && S_ISDIR(buffer.st_mode));
 }

+std::string get_full_path(const std::string& dir, const std::string& filename) {
+    DIR* dp = opendir(dir.c_str());
+
+    if (dp != nullptr) {
+        struct dirent* entry;
+
+        while ((entry = readdir(dp)) != nullptr) {
+            if (strcasecmp(entry->d_name, filename.c_str()) == 0) {
+                closedir(dp);
+                return dir + "/" + entry->d_name;
+            }
+        }
+
+        closedir(dp);
+    }
+
+    return "";
+}
+
 #endif

 // get_num_physical_cores is copy from
@@ -192,6 +225,24 @@ void pretty_progress(int step, int steps, float time) {
     }
 }

+std::string ltrim(const std::string& s) {
+    auto it = std::find_if(s.begin(), s.end(), [](int ch) {
+        return !std::isspace(ch);
+    });
+    return std::string(it, s.end());
+}
+
+std::string rtrim(const std::string& s) {
+    auto it = std::find_if(s.rbegin(), s.rend(), [](int ch) {
+        return !std::isspace(ch);
+    });
+    return std::string(s.begin(), it.base());
+}
+
+std::string trim(const std::string& s) {
+    return rtrim(ltrim(s));
+}
+
 static sd_log_cb_t sd_log_cb = NULL;
 void* sd_log_cb_data = NULL;

diff --git a/util.h b/util.h
index 3a611655..c1b035f1 100644
--- a/util.h
+++ b/util.h
@@ -15,6 +15,7 @@ void replace_all_chars(std::string& str, char target, char replacement);

 bool file_exists(const std::string& filename);
 bool is_directory(const std::string& path);
+std::string get_full_path(const std::string& dir, const std::string& filename);

 std::u32string utf8_to_utf32(const std::string& utf8_str);
 std::string utf32_to_utf8(const std::u32string& utf32_str);
@@ -28,6 +29,8 @@ void pretty_progress(int step, int steps, float time);

 void log_printf(sd_log_level_t level, const char* file, int line, const char* format, ...);

+std::string trim(const std::string& s);
+
 #define LOG_DEBUG(format, ...) log_printf(SD_LOG_DEBUG, __FILE__, __LINE__, format, ##__VA_ARGS__)
 #define LOG_INFO(format, ...) log_printf(SD_LOG_INFO, __FILE__, __LINE__, format, ##__VA_ARGS__)
 #define LOG_WARN(format, ...) log_printf(SD_LOG_WARN, __FILE__, __LINE__, format, ##__VA_ARGS__)
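get_full_path() gives the loader a case-tolerant lookup of a file inside a directory (on Unix via strcasecmp over the directory entries), and trim() cleans up tokens parsed out of prompts; both back the new embeddings support. A hypothetical caller, just to show the intended use; the embd_dir variable and the file extensions tried are assumptions, not taken from the patch.

// Resolve an embedding referenced in a prompt, tolerating case differences in the filename.
std::string name      = trim("  my-style ");                     // -> "my-style"
std::string embd_path = get_full_path(embd_dir, name + ".pt");   // "" when not found
if (embd_path.empty()) {
    embd_path = get_full_path(embd_dir, name + ".safetensors");
}
if (!embd_path.empty()) {
    LOG_INFO("loading embedding '%s' from '%s'", name.c_str(), embd_path.c_str());
}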
diff --git a/vae.hpp b/vae.hpp
index 8a47a8ef..38af5408 100644
--- a/vae.hpp
+++ b/vae.hpp
@@ -32,14 +32,14 @@ struct ResnetBlock {
     size_t calculate_mem_size(ggml_type wtype) {
         double mem_size = 0;
-        mem_size += 2 * in_channels * ggml_type_sizef(GGML_TYPE_F32);                      // norm1_w/b
-        mem_size += out_channels * in_channels * 3 * 3 * ggml_type_sizef(GGML_TYPE_F16);   // conv1_w
-        mem_size += 4 * out_channels * ggml_type_sizef(GGML_TYPE_F32);                     // conv1_b/norm2_w/norm2_b/conv2_b
-        mem_size += out_channels * out_channels * 3 * 3 * ggml_type_sizef(GGML_TYPE_F16);  // conv2_w
+        mem_size += 2 * ggml_row_size(GGML_TYPE_F32, in_channels);                         // norm1_w/b
+        mem_size += ggml_row_size(GGML_TYPE_F16, out_channels * in_channels * 3 * 3);      // conv1_w
+        mem_size += 4 * ggml_row_size(GGML_TYPE_F32, out_channels);                        // conv1_b/norm2_w/norm2_b/conv2_b
+        mem_size += ggml_row_size(GGML_TYPE_F16, out_channels * out_channels * 3 * 3);     // conv2_w

         if (out_channels != in_channels) {
-            mem_size += out_channels * in_channels * 1 * 1 * ggml_type_sizef(GGML_TYPE_F16);  // nin_shortcut_w
-            mem_size += out_channels * ggml_type_sizef(GGML_TYPE_F32);                        // nin_shortcut_b
+            mem_size += ggml_row_size(GGML_TYPE_F16, out_channels * in_channels * 1 * 1);     // nin_shortcut_w
+            mem_size += ggml_row_size(GGML_TYPE_F32, out_channels);                           // nin_shortcut_b
         }
         return static_cast<size_t>(mem_size);
     }
@@ -120,8 +120,8 @@ struct AttnBlock {
     size_t calculate_mem_size(ggml_type wtype) {
         double mem_size = 0;
-        mem_size += 6 * in_channels * ggml_type_sizef(GGML_TYPE_F32);                        // norm_w/norm_b/q_b/k_v/v_b/proj_out_b
-        mem_size += 4 * in_channels * in_channels * 1 * 1 * ggml_type_sizef(GGML_TYPE_F16);  // q_w/k_w/v_w/proj_out_w // object overhead
+        mem_size += 6 * ggml_row_size(GGML_TYPE_F32, in_channels);                           // norm_w/norm_b/q_b/k_v/v_b/proj_out_b
+        mem_size += 4 * ggml_row_size(GGML_TYPE_F16, in_channels * in_channels * 1 * 1);     // q_w/k_w/v_w/proj_out_w // object overhead
         return static_cast<size_t>(mem_size);
     }
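The same ggml_type_sizef() to ggml_row_size() conversion recurs in every calculate_mem_size() above and below. ggml_type_sizef(t) is a floating-point bytes-per-element value (type size divided by block size), so the old code multiplied by it, accumulated in a double, and cast back to size_t; ggml_row_size(t, ne) returns the byte count of ne elements as an exact size_t, which is what the allocator wants. A small sketch of the difference, assuming the standard ggml Q4_0 layout of 32 weights per 18-byte block (ne must be a multiple of the block size for ggml_row_size):

// Both expressions describe the same Q4_0 buffer, but only one stays in integer arithmetic.
int64_t ne     = 320 * 1280;                             // illustrative element count
size_t  exact  = ggml_row_size(GGML_TYPE_Q4_0, ne);      // 18 * ne / 32 = 230400 bytes, size_t
double  approx = ne * ggml_type_sizef(GGML_TYPE_Q4_0);   // ne * (18.0 / 32), still floating point
// `exact` needs no cast; `approx` must be static_cast back to size_t before use.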
@@ -269,17 +269,17 @@ struct Encoder {
     }

     size_t calculate_mem_size(ggml_type wtype) {
-        double mem_size = 0;
+        size_t mem_size = 0;

         int len_mults = sizeof(ch_mult) / sizeof(int);
         int block_in = ch * ch_mult[len_mults - 1];

-        mem_size += ch * in_channels * 3 * 3 * ggml_type_sizef(GGML_TYPE_F16);  // conv_in_w
-        mem_size += ch * ggml_type_sizef(GGML_TYPE_F32);                        // conv_in_b
+        mem_size += ggml_row_size(GGML_TYPE_F16, ch * in_channels * 3 * 3);     // conv_in_w
+        mem_size += ggml_row_size(GGML_TYPE_F32, ch);                           // conv_in_b

-        mem_size += 2 * block_in * ggml_type_sizef(GGML_TYPE_F32);              // norm_out_w/b
+        mem_size += 2 * ggml_row_size(GGML_TYPE_F32, block_in);                 // norm_out_w/b

-        mem_size += z_channels * 2 * block_in * 3 * 3 * ggml_type_sizef(GGML_TYPE_F16);  // conv_out_w
-        mem_size += z_channels * 2 * ggml_type_sizef(GGML_TYPE_F32);                     // conv_out_b
+        mem_size += ggml_row_size(GGML_TYPE_F16, z_channels * 2 * block_in * 3 * 3);     // conv_out_w
+        mem_size += ggml_row_size(GGML_TYPE_F32, z_channels * 2);                        // conv_out_b

         mem_size += mid.block_1.calculate_mem_size(wtype);
         mem_size += mid.attn_1.calculate_mem_size(wtype);
@@ -294,7 +294,7 @@ struct Encoder {
             }
         }

-        return static_cast<size_t>(mem_size);
+        return mem_size;
     }

     void init_params(struct ggml_context* ctx, ggml_allocr* alloc, ggml_type wtype) {
@@ -436,13 +436,13 @@ struct Decoder {
         int len_mults = sizeof(ch_mult) / sizeof(int);
         int block_in = ch * ch_mult[len_mults - 1];

-        mem_size += block_in * z_channels * 3 * 3 * ggml_type_sizef(GGML_TYPE_F16);  // conv_in_w
-        mem_size += block_in * ggml_type_sizef(GGML_TYPE_F32);                       // conv_in_b
+        mem_size += ggml_row_size(GGML_TYPE_F16, block_in * z_channels * 3 * 3);     // conv_in_w
+        mem_size += ggml_row_size(GGML_TYPE_F32, block_in);                          // conv_in_b

-        mem_size += 2 * (ch * ch_mult[0]) * ggml_type_sizef(GGML_TYPE_F32);          // norm_out_w/b
+        mem_size += 2 * ggml_row_size(GGML_TYPE_F32, (ch * ch_mult[0]));             // norm_out_w/b

-        mem_size += (ch * ch_mult[0]) * out_ch * 3 * 3 * ggml_type_sizef(GGML_TYPE_F16);  // conv_out_w
-        mem_size += out_ch * ggml_type_sizef(GGML_TYPE_F32);                              // conv_out_b
+        mem_size += ggml_row_size(GGML_TYPE_F16, (ch * ch_mult[0]) * out_ch * 3 * 3);     // conv_out_w
+        mem_size += ggml_row_size(GGML_TYPE_F32, out_ch);                                 // conv_out_b

         mem_size += mid.block_1.calculate_mem_size(wtype);
         mem_size += mid.attn_1.calculate_mem_size(wtype);
@@ -606,19 +606,19 @@ struct AutoEncoderKL : public GGMLModule {
     }

     size_t calculate_mem_size() {
-        double mem_size = 0;
+        size_t mem_size = 0;

         if (!decode_only) {
-            mem_size += 2 * embed_dim * 2 * dd_config.z_channels * 1 * 1 * ggml_type_sizef(GGML_TYPE_F16);  // quant_conv_w
-            mem_size += 2 * embed_dim * ggml_type_sizef(GGML_TYPE_F32);                                     // quant_conv_b
+            mem_size += ggml_row_size(GGML_TYPE_F16, 2 * embed_dim * 2 * dd_config.z_channels * 1 * 1);     // quant_conv_w
+            mem_size += ggml_row_size(GGML_TYPE_F32, 2 * embed_dim);                                        // quant_conv_b
             mem_size += encoder.calculate_mem_size(wtype);
         }

-        mem_size += dd_config.z_channels * embed_dim * 1 * 1 * ggml_type_sizef(GGML_TYPE_F16);  // post_quant_conv_w
-        mem_size += dd_config.z_channels * ggml_type_sizef(GGML_TYPE_F32);                      // post_quant_conv_b
+        mem_size += ggml_row_size(GGML_TYPE_F16, dd_config.z_channels * embed_dim * 1 * 1);     // post_quant_conv_w
+        mem_size += ggml_row_size(GGML_TYPE_F32, dd_config.z_channels);                         // post_quant_conv_b

         mem_size += decoder.calculate_mem_size(wtype);
-        return static_cast<size_t>(mem_size);
+        return mem_size;
     }

     size_t get_num_tensors() {