From 69efe3ce2b269bcc12ce276bf7087d6c6d272908 Mon Sep 17 00:00:00 2001 From: leejet Date: Sat, 9 Dec 2023 17:35:10 +0800 Subject: [PATCH] chore: make code cleaner --- model.cpp | 3 +- model.h | 2 +- stable-diffusion.cpp | 388 ++++++++++++++++--------------------------- 3 files changed, 145 insertions(+), 248 deletions(-) diff --git a/model.cpp b/model.cpp index 3adbec9f..cfa58908 100644 --- a/model.cpp +++ b/model.cpp @@ -1102,6 +1102,7 @@ bool ModelLoader::parse_data_pkl(uint8_t* buffer, reader.tensor_storage.file_index = file_index; reader.tensor_storage.name = prefix + reader.tensor_storage.name; tensor_storages.push_back(reader.tensor_storage); + // LOG_DEBUG("%s", reader.tensor_storage.name.c_str()); // reset reader = PickleTensorReader(); } @@ -1139,7 +1140,7 @@ bool ModelLoader::init_from_ckpt_file(const std::string& file_path, const std::s size_t pkl_size; zip_entry_read(zip, &pkl_data, &pkl_size); - LOG_DEBUG("%lld", pkl_size); + // LOG_DEBUG("%lld", pkl_size); parse_data_pkl((uint8_t*)pkl_data, pkl_size, zip, dir, file_index, prefix); diff --git a/model.h b/model.h index 6f27cdbf..d3f09d91 100644 --- a/model.h +++ b/model.h @@ -7,8 +7,8 @@ #include #include -#include "ggml/ggml.h" #include "ggml/ggml-backend.h" +#include "ggml/ggml.h" #include "json.hpp" #include "zip.h" diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index 8c66f550..5bd6990b 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -398,6 +398,64 @@ struct ggml_tensor* ggml_group_norm_32(struct ggml_context* ctx, return ggml_group_norm(ctx, a, 32); } +struct ggml_tensor* ggml_nn_linear(struct ggml_context* ctx, + struct ggml_tensor* x, + struct ggml_tensor* w, + struct ggml_tensor* b) { + x = ggml_mul_mat(ctx, w, x); + x = ggml_add(ctx, x, b); + return x; +} + +// w: [OC,IC, KH, KW] +// x: [N, IC, IH, IW] +// b: [OC,] +// result: [N, OC, OH, OW] +struct ggml_tensor* ggml_nn_conv_2d(struct ggml_context* ctx, + struct ggml_tensor* x, + struct ggml_tensor* w, + struct ggml_tensor* b, + int s0 = 1, + int s1 = 1, + int p0 = 0, + int p1 = 0, + int d0 = 1, + int d1 = 1) { + x = ggml_conv_2d(ctx, w, x, s0, s1, p0, p1, d0, d1); + if (b != NULL) { + b = ggml_reshape_4d(ctx, b, 1, 1, b->ne[0], 1); + x = ggml_add(ctx, x, b); + } + return x; +} + +struct ggml_tensor* ggml_nn_layer_norm(struct ggml_context* ctx, + struct ggml_tensor* x, + struct ggml_tensor* w, + struct ggml_tensor* b, + float eps = EPS) { + x = ggml_norm(ctx, x, eps); + x = ggml_mul(ctx, x, w); + x = ggml_add(ctx, x, b); + return x; +} + +struct ggml_tensor* ggml_nn_group_norm(struct ggml_context* ctx, + struct ggml_tensor* x, + struct ggml_tensor* w, + struct ggml_tensor* b, + int num_groups = 32) { + if (x->n_dims == 4) { + w = ggml_reshape_4d(ctx, w, 1, 1, w->ne[0], 1); + b = ggml_reshape_4d(ctx, b, 1, 1, b->ne[0], 1); + } + + x = ggml_group_norm(ctx, x, num_groups); + x = ggml_mul(ctx, x, w); + x = ggml_add(ctx, x, b); + return x; +} + std::pair, std::string> extract_and_remove_lora(std::string text) { std::regex re("]+)>"); std::smatch matches; @@ -749,30 +807,21 @@ struct ResidualAttentionBlock { struct ggml_tensor* r = x; // layer norm 1 - { - x = ggml_norm(ctx, x, EPS); - x = ggml_add(ctx, - ggml_mul(ctx, x, ln1_w), - ln1_b); - } + x = ggml_nn_layer_norm(ctx, x, ln1_w, ln1_b); // self-attention { - struct ggml_tensor* q = ggml_add(ctx, - ggml_mul_mat(ctx, q_w, x), - q_b); + struct ggml_tensor* q = ggml_nn_linear(ctx, x, q_w, q_b); q = ggml_scale_inplace(ctx, q, attn_scale); q = ggml_reshape_4d(ctx, q, d_model, n_head, n_token, N); // [N, n_token, n_head, d_model] q = ggml_cont(ctx, ggml_permute(ctx, q, 0, 2, 1, 3)); // [N, n_head, n_token, d_model] q = ggml_reshape_3d(ctx, q, d_model, n_token, n_head * N); // [N * n_head, n_token, d_model] - struct ggml_tensor* k = ggml_add(ctx, - ggml_mul_mat(ctx, k_w, x), k_b); + struct ggml_tensor* k = ggml_nn_linear(ctx, x, k_w, k_b); k = ggml_reshape_4d(ctx, k, d_model, n_head, n_token, N); // [N, n_token, n_head, d_model] k = ggml_cont(ctx, ggml_permute(ctx, k, 0, 2, 1, 3)); // [N, n_head, n_token, d_model] k = ggml_reshape_3d(ctx, k, d_model, n_token, n_head); // [N * n_head, n_token, d_model] - struct ggml_tensor* v = ggml_add(ctx, - ggml_mul_mat(ctx, v_w, x), v_b); + struct ggml_tensor* v = ggml_nn_linear(ctx, x, v_w, v_b); v = ggml_reshape_4d(ctx, v, d_model, n_head, n_token, N); // [N, n_token, n_head, d_model] v = ggml_cont(ctx, ggml_permute(ctx, v, 1, 2, 0, 3)); // [N, n_head, d_model, n_token] v = ggml_reshape_3d(ctx, v, n_token, d_model, n_head * N); // [N * n_head, d_model, n_token] @@ -790,24 +839,17 @@ struct ResidualAttentionBlock { } // attention output - x = ggml_mul_mat(ctx, out_w, x); - x = ggml_add(ctx, x, out_b); + x = ggml_nn_linear(ctx, x, out_w, out_b); // residual x = ggml_add(ctx, x, r); r = x; // layer norm 2 - { - x = ggml_norm(ctx, x, EPS); - - x = ggml_add(ctx, ggml_mul(ctx, x, ln2_w), - ln2_b); - } + x = ggml_nn_layer_norm(ctx, x, ln2_w, ln2_b); // mlp - x = ggml_mul_mat(ctx, fc1_w, x); - x = ggml_add(ctx, x, fc1_b); + x = ggml_nn_linear(ctx, x, fc1_w, fc1_b); if (hidden_size == 1024) { // SD 2.x x = ggml_gelu_inplace(ctx, x); @@ -815,8 +857,7 @@ struct ResidualAttentionBlock { x = ggml_gelu_quick_inplace(ctx, x); } - x = ggml_mul_mat(ctx, fc2_w, x); - x = ggml_add(ctx, x, fc2_b); + x = ggml_nn_linear(ctx, x, fc2_w, fc2_b); // residual 2 x = ggml_add(ctx, x, r); @@ -1004,12 +1045,7 @@ struct CLIPTextModel { } // final layer norm - { - x = ggml_norm(ctx0, x, EPS); - - x = ggml_add(ctx0, ggml_mul(ctx0, x, final_ln_w), - final_ln_b); - } + x = ggml_nn_layer_norm(ctx0, x, final_ln_w, final_ln_b); return x; // [N, n_token, hidden_size] } @@ -1263,48 +1299,29 @@ struct ResBlock { // emb: [N, emb_channels] // in_layers - // group norm 32 - auto h = ggml_group_norm_32(ctx, x); - h = ggml_add(ctx, - ggml_mul(ctx, - h, - ggml_reshape_4d(ctx, in_layer_0_w, 1, 1, in_layer_0_w->ne[0], 1)), - ggml_reshape_4d(ctx, in_layer_0_b, 1, 1, in_layer_0_b->ne[0], 1)); - // silu - h = ggml_silu_inplace(ctx, h); - // conv2d - h = ggml_conv_2d(ctx, in_layer_2_w, h, 1, 1, 1, 1, 1, 1); - h = ggml_add(ctx, - h, - ggml_reshape_4d(ctx, in_layer_2_b, 1, 1, in_layer_2_b->ne[0], 1)); // [N, out_channels, h, w] + auto h = ggml_nn_group_norm(ctx, x, in_layer_0_w, in_layer_0_b); + h = ggml_silu_inplace(ctx, h); + h = ggml_nn_conv_2d(ctx, h, in_layer_2_w, in_layer_2_b, 1, 1, 1, 1); // [N, out_channels, h, w] // emb_layers auto emb_out = ggml_silu(ctx, emb); - emb_out = ggml_mul_mat(ctx, emb_layer_1_w, emb_out); - emb_out = ggml_add(ctx, emb_out, emb_layer_1_b); // [N, out_channels] + emb_out = ggml_nn_linear(ctx, emb_out, emb_layer_1_w, emb_layer_1_b); // [N, out_channels] emb_out = ggml_reshape_4d(ctx, emb_out, 1, 1, emb_out->ne[0], emb_out->ne[1]); // [N, out_channels, 1, 1] // out_layers h = ggml_add(ctx, h, emb_out); - // group norm 32 - h = ggml_group_norm_inplace(ctx, h, 32); - h = ggml_add(ctx, - ggml_mul(ctx, h, ggml_reshape_4d(ctx, out_layer_0_w, 1, 1, out_layer_0_w->ne[0], 1)), - ggml_reshape_4d(ctx, out_layer_0_b, 1, 1, out_layer_0_b->ne[0], 1)); - // silu + h = ggml_nn_group_norm(ctx, h, out_layer_0_w, out_layer_0_b); h = ggml_silu_inplace(ctx, h); + // dropout, skip for inference - // conv2d - h = ggml_conv_2d(ctx, out_layer_3_w, h, 1, 1, 1, 1, 1, 1); - h = ggml_add(ctx, - h, ggml_reshape_4d(ctx, out_layer_3_b, 1, 1, out_layer_3_b->ne[0], 1)); // [N, out_channels, h, w + + h = ggml_nn_conv_2d(ctx, h, out_layer_3_w, out_layer_3_b, 1, 1, 1, 1); // [N, out_channels, h, w] // skip connection if (out_channels != channels) { - x = ggml_conv_2d(ctx, skip_w, x, 1, 1, 0, 0, 1, 1); - x = ggml_add(ctx, - x, ggml_reshape_4d(ctx, skip_b, 1, 1, skip_b->ne[0], 1)); // [N, out_channels, h, w] + x = ggml_nn_conv_2d(ctx, x, skip_w, skip_b); // [N, out_channels, h, w] } + h = ggml_add(ctx, h, x); return h; // [N, out_channels, h, w] } @@ -1479,15 +1496,9 @@ struct SpatialTransformer { // x: [N, in_channels, h, w] // context: [N, max_position, hidden_size(aka context_dim)] auto x_in = x; - // group norm 32 - x = ggml_group_norm_32(ctx, x); - x = ggml_add(ctx, - ggml_mul(ctx, x, ggml_reshape_4d(ctx, norm_w, 1, 1, norm_w->ne[0], 1)), - ggml_reshape_4d(ctx, norm_b, 1, 1, norm_b->ne[0], 1)); + x = ggml_nn_group_norm(ctx, x, norm_w, norm_b); // proj_in - x = ggml_conv_2d(ctx, proj_in_w, x, 1, 1, 0, 0, 1, 1); - x = ggml_add(ctx, - x, ggml_reshape_4d(ctx, proj_in_b, 1, 1, proj_in_b->ne[0], 1)); // [N, in_channels, h, w] + x = ggml_nn_conv_2d(ctx, x, proj_in_w, proj_in_b); // [N, in_channels, h, w] // transformer const int64_t n = x->ne[3]; @@ -1500,13 +1511,8 @@ struct SpatialTransformer { { auto r = x; // layer norm 1 - { - x = ggml_reshape_2d(ctx, x, c, w * h * n); - x = ggml_norm(ctx, x, EPS); - x = ggml_add(ctx, - ggml_mul(ctx, x, transformer.norm1_w), - transformer.norm1_b); - } + x = ggml_reshape_2d(ctx, x, c, w * h * n); + x = ggml_nn_layer_norm(ctx, x, transformer.norm1_w, transformer.norm1_b); // self-attention { @@ -1544,7 +1550,7 @@ struct SpatialTransformer { // x = ggml_cpy(ctx, kqv, ggml_new_tensor_2d(ctx, GGML_TYPE_F32, d_head * n_head, h * w * n)); x = ggml_reshape_2d(ctx, kqv, d_head * n_head, h * w * n); - x = ggml_add(ctx, ggml_mul_mat(ctx, transformer.attn1_out_w, x), transformer.attn1_out_b); + x = ggml_nn_linear(ctx, x, transformer.attn1_out_w, transformer.attn1_out_b); x = ggml_reshape_4d(ctx, x, c, w, h, n); } @@ -1553,11 +1559,7 @@ struct SpatialTransformer { r = x; // layer norm 2 - { - x = ggml_norm(ctx, x, EPS); - x = ggml_add(ctx, - ggml_mul(ctx, x, transformer.norm2_w), transformer.norm2_b); - } + x = ggml_nn_layer_norm(ctx, x, transformer.norm2_w, transformer.norm2_b); // cross-attention { @@ -1595,7 +1597,7 @@ struct SpatialTransformer { // x = ggml_cpy(ctx, kqv, ggml_new_tensor_2d(ctx, GGML_TYPE_F32, d_head * n_head, h * w * n)); // [N * h * w, in_channels] x = ggml_reshape_2d(ctx, kqv, d_head * n_head, h * w * n); // [N * h * w, in_channels] - x = ggml_add(ctx, ggml_mul_mat(ctx, transformer.attn2_out_w, x), transformer.attn2_out_b); + x = ggml_nn_linear(ctx, x, transformer.attn2_out_w, transformer.attn2_out_b); x = ggml_reshape_4d(ctx, x, c, w, h, n); } @@ -1604,13 +1606,8 @@ struct SpatialTransformer { r = x; // layer norm 3 - { - x = ggml_reshape_2d(ctx, x, c, h * w * n); // [N * h * w, in_channels] - x = ggml_norm(ctx, x, EPS); - x = ggml_add(ctx, - ggml_mul(ctx, x, transformer.norm3_w), - transformer.norm3_b); - } + x = ggml_reshape_2d(ctx, x, c, h * w * n); // [N * h * w, in_channels] + x = ggml_nn_layer_norm(ctx, x, transformer.norm3_w, transformer.norm3_b); // ff { @@ -1637,17 +1634,14 @@ struct SpatialTransformer { transformer.ff_0_proj_b->nb[0] * transformer.ff_0_proj_b->ne[0] / 2); // [in_channels * 4, ] x = ggml_reshape_2d(ctx, x, c, w * h * n); auto x_in = x; - x = ggml_mul_mat(ctx, x_w, x_in); // [N * h * w, in_channels * 4] - x = ggml_add(ctx, x, x_b); - auto gate = ggml_mul_mat(ctx, gate_w, x_in); // [N * h * w, in_channels * 4] - gate = ggml_add(ctx, gate, gate_b); + x = ggml_nn_linear(ctx, x_in, x_w, x_b); // [N * h * w, in_channels * 4] + auto gate = ggml_nn_linear(ctx, x_in, gate_w, gate_b); // [N * h * w, in_channels * 4] gate = ggml_gelu_inplace(ctx, gate); x = ggml_mul(ctx, x, gate); // [N * h * w, in_channels * 4] // fc - x = ggml_mul_mat(ctx, transformer.ff_2_w, x); // [N * h * w, in_channels] - x = ggml_add(ctx, x, transformer.ff_2_b); + x = ggml_nn_linear(ctx, x, transformer.ff_2_w, transformer.ff_2_b); // [N * h * w, in_channels] } x = ggml_reshape_4d(ctx, x, c, w, h, n); // [N, h, w, in_channels] @@ -1655,12 +1649,11 @@ struct SpatialTransformer { // residual x = ggml_add(ctx, x, r); } - x = ggml_cont(ctx, ggml_permute(ctx, x, 2, 0, 1, 3)); // // [N, in_channels, h, w] + x = ggml_cont(ctx, ggml_permute(ctx, x, 2, 0, 1, 3)); // [N, in_channels, h, w] // proj_out - x = ggml_conv_2d(ctx, proj_out_w, x, 1, 1, 0, 0, 1, 1); - x = ggml_add(ctx, - x, ggml_reshape_4d(ctx, proj_out_b, 1, 1, proj_out_b->ne[0], 1)); // [N, in_channels, h, w] + x = ggml_nn_conv_2d(ctx, x, proj_out_w, proj_out_b); // [N, in_channels, h, w] + x = ggml_add(ctx, x, x_in); return x; } @@ -1701,17 +1694,14 @@ struct DownSample { struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) { // x: [N, channels, h, w] - struct ggml_tensor* c = nullptr; + struct ggml_tensor* c = NULL; if (vae_downsample) { c = ggml_pad(ctx, x, 1, 1, 0, 0); - c = ggml_conv_2d(ctx, op_w, c, 2, 2, 0, 0, 1, 1); + c = ggml_nn_conv_2d(ctx, c, op_w, op_b, 2, 2, 0, 0); } else { - c = ggml_conv_2d(ctx, op_w, x, 2, 2, 1, 1, 1, 1); + c = ggml_nn_conv_2d(ctx, x, op_w, op_b, 2, 2, 1, 1); } - c = ggml_add(ctx, - c, - ggml_reshape_4d(ctx, op_b, 1, 1, op_b->ne[0], 1)); // [N, out_channels, h/2, w/2] - return c; + return c; // [N, out_channels, h/2, w/2] } }; @@ -1743,11 +1733,8 @@ struct UpSample { struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) { // x: [N, channels, h, w] - x = ggml_upscale(ctx, x, 2); // [N, channels, h*2, w*2] - x = ggml_conv_2d(ctx, conv_w, x, 1, 1, 1, 1, 1, 1); - x = ggml_add(ctx, - x, - ggml_reshape_4d(ctx, conv_b, 1, 1, conv_b->ne[0], 1)); // [N, out_channels, h*2, w*2] + x = ggml_upscale(ctx, x, 2); // [N, channels, h*2, w*2] + x = ggml_nn_conv_2d(ctx, x, conv_w, conv_b, 1, 1, 1, 1); // [N, out_channels, h*2, w*2] return x; } }; @@ -2212,14 +2199,10 @@ struct UNetModel { } // time_embed = nn.Sequential + auto emb = ggml_nn_linear(ctx0, t_emb, time_embed_0_w, time_embed_0_b); + emb = ggml_silu_inplace(ctx0, emb); // Linear - auto emb = ggml_mul_mat(ctx0, time_embed_0_w, t_emb); - emb = ggml_add(ctx0, emb, time_embed_0_b); - // nn.SiLU() - emb = ggml_silu_inplace(ctx0, emb); - // Linear - emb = ggml_mul_mat(ctx0, time_embed_2_w, emb); - emb = ggml_add(ctx0, emb, time_embed_2_b); // [N, time_embed_dim] + emb = ggml_nn_linear(ctx0, emb, time_embed_2_w, time_embed_2_b); // [N, time_embed_dim] // SDXL // label_emd = nn.Sequential @@ -2227,13 +2210,9 @@ struct UNetModel { // param y: an [N] Tensor of labels, if class-conditional. (clip g) // if(y != NULL) { - // auto y_emb = ggml_mul_mat(ctx, label_embed_0_w, y); - // y_emb = ggml_add(ctx, y_emb, label_embed_0_b); - // // nn.SiLU() + // auto y_emb = ggml_nn_linear(ctx, y, label_embed_0_w, label_embed_0_b); // y_emb = ggml_silu_inplace(ctx, y_emb); - // // Linear - // y_emb = ggml_mul_mat(ctx, label_embed_2_w, y_emb); - // y_emb = ggml_add(ctx, y_emb, label_embed_2_b); + // y_emb = ggml_nn_linear(ctx, y_emb, label_embed_2_w, label_embed_2_b); // emb = ggml_add(ctx, emb, y_emb); // } @@ -2241,11 +2220,8 @@ struct UNetModel { std::vector hs; // input block 0 - struct ggml_tensor* h = ggml_conv_2d(ctx0, input_block_0_w, x, 1, 1, 1, 1, 1, 1); // [N, model_channels, h, w] + struct ggml_tensor* h = ggml_nn_conv_2d(ctx0, x, input_block_0_w, input_block_0_b, 1, 1, 1, 1); // [N, model_channels, h, w] - h = ggml_add(ctx0, - h, - ggml_reshape_4d(ctx0, input_block_0_b, 1, 1, input_block_0_b->ne[0], 1)); // [N, model_channels, h, w] ggml_set_name(h, "bench-start"); hs.push_back(h); // input block 1-11 @@ -2295,18 +2271,11 @@ struct UNetModel { } // out - // group norm 32 - h = ggml_group_norm_32(ctx0, h); - h = ggml_add(ctx0, - ggml_mul(ctx0, h, ggml_reshape_4d(ctx0, out_0_w, 1, 1, out_0_w->ne[0], 1)), - ggml_reshape_4d(ctx0, out_0_b, 1, 1, out_0_b->ne[0], 1)); - // silu + h = ggml_nn_group_norm(ctx0, h, out_0_w, out_0_b); h = ggml_silu_inplace(ctx0, h); // conv2d - h = ggml_conv_2d(ctx0, out_2_w, h, 1, 1, 1, 1, 1, 1); - h = ggml_add(ctx0, - h, ggml_reshape_4d(ctx0, out_2_b, 1, 1, out_2_b->ne[0], 1)); // [N, out_channels, h, w] + h = ggml_nn_conv_2d(ctx0, h, out_2_w, out_2_b, 1, 1, 1, 1); // [N, out_channels, h, w] ggml_set_name(h, "bench-end"); return h; } @@ -2503,38 +2472,19 @@ struct ResnetBlock { struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* z) { // z: [N, in_channels, h, w] - // group norm 32 - auto h = ggml_group_norm_32(ctx, z); - h = ggml_mul(ctx, - h, ggml_reshape_4d(ctx, norm1_w, 1, 1, norm1_w->ne[0], 1)); - h = ggml_add(ctx, - h, ggml_reshape_4d(ctx, norm1_b, 1, 1, norm1_b->ne[0], 1)); - // silu - h = ggml_silu_inplace(ctx, h); - // conv2d - h = ggml_conv_2d(ctx, conv1_w, h, 1, 1, 1, 1, 1, 1); - h = ggml_add(ctx, - h, ggml_reshape_4d(ctx, conv1_b, 1, 1, conv1_b->ne[0], 1)); // [N, out_channels, h, w] - - // group norm 32 - h = ggml_group_norm_32(ctx, h); - h = ggml_add(ctx, - ggml_mul(ctx, h, ggml_reshape_4d(ctx, norm2_w, 1, 1, norm2_w->ne[0], 1)), - ggml_reshape_4d(ctx, norm2_b, 1, 1, norm2_b->ne[0], 1)); - // silu - h = ggml_silu_inplace(ctx, h); + auto h = ggml_nn_group_norm(ctx, z, norm1_w, norm1_b); + h = ggml_silu_inplace(ctx, h); + h = ggml_nn_conv_2d(ctx, h, conv1_w, conv1_b, 1, 1, 1, 1); // [N, out_channels, h, w] + h = ggml_nn_group_norm(ctx, h, norm2_w, norm2_b); + h = ggml_silu_inplace(ctx, h); // dropout, skip for inference - // conv2d - h = ggml_conv_2d(ctx, conv2_w, h, 1, 1, 1, 1, 1, 1); - h = ggml_add(ctx, - h, ggml_reshape_4d(ctx, conv2_b, 1, 1, conv2_b->ne[0], 1)); // [N, out_channels, h, w + h = ggml_nn_conv_2d(ctx, h, conv2_w, conv2_b, 1, 1, 1, 1); // [N, out_channels, h, w] // skip connection if (out_channels != in_channels) { - z = ggml_conv_2d(ctx, nin_shortcut_w, z, 1, 1, 0, 0, 1, 1); - z = ggml_add(ctx, - z, ggml_reshape_4d(ctx, nin_shortcut_b, 1, 1, nin_shortcut_b->ne[0], 1)); // [N, out_channels, h, w] + z = ggml_nn_conv_2d(ctx, z, nin_shortcut_w, nin_shortcut_b); // [N, out_channels, h, w] } + h = ggml_add(ctx, h, z); return h; // [N, out_channels, h, w] } @@ -2604,30 +2554,16 @@ struct AttnBlock { struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) { // x: [N, in_channels, h, w] - // group norm 32 - auto h_ = ggml_group_norm_32(ctx, x); - h_ = ggml_add(ctx, - ggml_mul(ctx, h_, ggml_reshape_4d(ctx, norm_w, 1, 1, norm_w->ne[0], 1)), - ggml_reshape_4d(ctx, norm_b, 1, 1, norm_b->ne[0], 1)); + auto h_ = ggml_nn_group_norm(ctx, x, norm_w, norm_b); const int64_t n = h_->ne[3]; const int64_t c = h_->ne[2]; const int64_t h = h_->ne[1]; const int64_t w = h_->ne[0]; - // q - auto q = ggml_conv_2d(ctx, q_w, h_, 1, 1, 0, 0, 1, 1); - q = ggml_add(ctx, - q, ggml_reshape_4d(ctx, q_b, 1, 1, q_b->ne[0], 1)); // [N, in_channels, h, w] - // k - auto k = ggml_conv_2d(ctx, k_w, h_, 1, 1, 0, 0, 1, 1); - k = ggml_add(ctx, - k, ggml_reshape_4d(ctx, k_b, 1, 1, k_b->ne[0], 1)); // [N, in_channels, h, w] - - // v - auto v = ggml_conv_2d(ctx, v_w, h_, 1, 1, 0, 0, 1, 1); - v = ggml_add(ctx, - v, ggml_reshape_4d(ctx, v_b, 1, 1, v_b->ne[0], 1)); // [N, in_channels, h, w] + auto q = ggml_nn_conv_2d(ctx, h_, q_w, q_b); // [N, in_channels, h, w] + auto k = ggml_nn_conv_2d(ctx, h_, k_w, k_b); // [N, in_channels, h, w] + auto v = ggml_nn_conv_2d(ctx, h_, v_w, v_b); // [N, in_channels, h, w] q = ggml_cont(ctx, ggml_permute(ctx, q, 1, 2, 0, 3)); // [N, h, w, in_channels] q = ggml_reshape_3d(ctx, q, c, h * w, n); // [N, h * w, in_channels] @@ -2645,9 +2581,8 @@ struct AttnBlock { h_ = ggml_reshape_4d(ctx, h_, w, h, c, n); // [N, in_channels, h, w] // proj_out - h_ = ggml_conv_2d(ctx, proj_out_w, h_, 1, 1, 0, 0, 1, 1); - h_ = ggml_add(ctx, - h_, ggml_reshape_4d(ctx, proj_out_b, 1, 1, proj_out_b->ne[0], 1)); // [N, in_channels, h, w] + h_ = ggml_nn_conv_2d(ctx, h_, proj_out_w, proj_out_b); // [N, in_channels, h, w] + h_ = ggml_add(ctx, h_, x); return h_; } @@ -2814,9 +2749,7 @@ struct Encoder { // x: [N, in_channels, h, w] // conv_in - auto h = ggml_conv_2d(ctx, conv_in_w, x, 1, 1, 1, 1, 1, 1); - h = ggml_add(ctx, - h, ggml_reshape_4d(ctx, conv_in_b, 1, 1, conv_in_b->ne[0], 1)); // [N, ch, h, w] + auto h = ggml_nn_conv_2d(ctx, x, conv_in_w, conv_in_b, 1, 1, 1, 1); // [N, ch, h, w] ggml_set_name(h, "b-start"); int len_mults = sizeof(ch_mult) / sizeof(int); for (int i = 0; i < len_mults; i++) { @@ -2832,20 +2765,11 @@ struct Encoder { h = mid.attn_1.forward(ctx, h); h = mid.block_2.forward(ctx, h); // [N, block_in, h, w] - // group norm 32 - h = ggml_group_norm_32(ctx, h); - h = ggml_add(ctx, - ggml_mul(ctx, h, ggml_reshape_4d(ctx, norm_out_w, 1, 1, norm_out_w->ne[0], 1)), - ggml_reshape_4d(ctx, norm_out_b, 1, 1, norm_out_b->ne[0], 1)); - - // silu - // silu + h = ggml_nn_group_norm(ctx, h, norm_out_w, norm_out_b); h = ggml_silu_inplace(ctx, h); // conv_out - h = ggml_conv_2d(ctx, conv_out_w, h, 1, 1, 1, 1, 1, 1); - h = ggml_add(ctx, - h, ggml_reshape_4d(ctx, conv_out_b, 1, 1, conv_out_b->ne[0], 1)); // [N, z_channels*2, h, w] + h = ggml_nn_conv_2d(ctx, h, conv_out_w, conv_out_b, 1, 1, 1, 1); // [N, z_channels*2, h, w] return h; } @@ -3007,9 +2931,7 @@ struct Decoder { struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* z) { // z: [N, z_channels, h, w] // conv_in - auto h = ggml_conv_2d(ctx, conv_in_w, z, 1, 1, 1, 1, 1, 1); - h = ggml_add(ctx, - h, ggml_reshape_4d(ctx, conv_in_b, 1, 1, conv_in_b->ne[0], 1)); // [N, block_in, h, w] + auto h = ggml_nn_conv_2d(ctx, z, conv_in_w, conv_in_b, 1, 1, 1, 1); // [N, block_in, h, w] h = mid.block_1.forward(ctx, h); h = mid.attn_1.forward(ctx, h); @@ -3026,19 +2948,11 @@ struct Decoder { } // group norm 32 - h = ggml_group_norm_32(ctx, h); - h = ggml_add(ctx, - ggml_mul(ctx, h, ggml_reshape_4d(ctx, norm_out_w, 1, 1, norm_out_w->ne[0], 1)), - ggml_reshape_4d(ctx, norm_out_b, 1, 1, norm_out_b->ne[0], 1)); - - // silu - // silu + h = ggml_nn_group_norm(ctx, h, norm_out_w, norm_out_b); h = ggml_silu_inplace(ctx, h); // conv_out - h = ggml_conv_2d(ctx, conv_out_w, h, 1, 1, 1, 1, 1, 1); - h = ggml_add(ctx, - h, ggml_reshape_4d(ctx, conv_out_b, 1, 1, conv_out_b->ne[0], 1)); // [N, out_ch, h, w] + h = ggml_nn_conv_2d(ctx, h, conv_out_w, conv_out_b, 1, 1, 1, 1); // [N, out_ch, h, w] return h; } }; @@ -3187,9 +3101,7 @@ struct AutoEncoderKL { struct ggml_tensor* decode(struct ggml_context* ctx0, struct ggml_tensor* z) { // z: [N, z_channels, h, w] // post_quant_conv - auto h = ggml_conv_2d(ctx0, post_quant_conv_w, z, 1, 1, 0, 0, 1, 1); - h = ggml_add(ctx0, - h, ggml_reshape_4d(ctx0, post_quant_conv_b, 1, 1, post_quant_conv_b->ne[0], 1)); // [N, z_channels, h, w] + auto h = ggml_nn_conv_2d(ctx0, z, post_quant_conv_w, post_quant_conv_b); // [N, z_channels, h, w] ggml_set_name(h, "bench-start"); h = decoder.forward(ctx0, h); ggml_set_name(h, "bench-end"); @@ -3200,10 +3112,7 @@ struct AutoEncoderKL { // x: [N, in_channels, h, w] auto h = encoder.forward(ctx0, x); // [N, 2*z_channels, h/8, w/8] // quant_conv - h = ggml_conv_2d(ctx0, quant_conv_w, h, 1, 1, 0, 0, 1, 1); - h = ggml_add(ctx0, - h, - ggml_reshape_4d(ctx0, quant_conv_b, 1, 1, quant_conv_b->ne[0], 1)); // [N, 2*embed_dim, h/8, w/8] + h = ggml_nn_conv_2d(ctx0, h, quant_conv_w, quant_conv_b); // [N, 2*embed_dim, h/8, w/8] ggml_set_name(h, "b-end"); return h; } @@ -3367,25 +3276,16 @@ struct TAEBlock { ggml_tensor* forward(ggml_context* ctx, ggml_tensor* x) { // conv(n_in, n_out) ggml_tensor* h; - h = ggml_conv_2d(ctx, conv_0_w, x, 1, 1, 1, 1, 1, 1); - h = ggml_add(ctx, h, ggml_reshape_4d(ctx, conv_0_b, 1, 1, conv_0_b->ne[0], 1)); - - // relu + h = ggml_nn_conv_2d(ctx, x, conv_0_w, conv_0_b, 1, 1, 1, 1); h = ggml_relu_inplace(ctx, h); - - h = ggml_conv_2d(ctx, conv_1_w, h, 1, 1, 1, 1, 1, 1); - h = ggml_add(ctx, h, ggml_reshape_4d(ctx, conv_1_b, 1, 1, conv_1_b->ne[0], 1)); - - // relu + h = ggml_nn_conv_2d(ctx, h, conv_1_w, conv_1_b, 1, 1, 1, 1); h = ggml_relu_inplace(ctx, h); - - h = ggml_conv_2d(ctx, conv_2_w, h, 1, 1, 1, 1, 1, 1); - h = ggml_add(ctx, h, ggml_reshape_4d(ctx, conv_2_b, 1, 1, conv_2_b->ne[0], 1)); + h = ggml_nn_conv_2d(ctx, h, conv_2_w, conv_2_b, 1, 1, 1, 1); // skip connection if (in_channels != out_channels) { // skip = nn.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity() - x = ggml_conv_2d(ctx, conv_skip_w, x, 1, 1, 1, 1, 1, 1); + x = ggml_nn_conv_2d(ctx, x, conv_skip_w, NULL, 1, 1, 1, 1); } h = ggml_add(ctx, h, x); @@ -3514,14 +3414,13 @@ struct TinyEncoder { ggml_tensor* forward(ggml_context* ctx, ggml_tensor* x) { // conv(3, 64) - auto z = ggml_conv_2d(ctx, conv_input_w, x, 1, 1, 1, 1, 1, 1); - z = ggml_add(ctx, z, ggml_reshape_4d(ctx, conv_input_b, 1, 1, conv_input_b->ne[0], 1)); + auto z = ggml_nn_conv_2d(ctx, x, conv_input_w, conv_input_b, 1, 1, 1, 1); // Block(64, 64) z = initial_block.forward(ctx, z); // conv(64, 64, stride=2, bias=False) - z = ggml_conv_2d(ctx, conv_1_w, z, 2, 2, 1, 1, 1, 1); + z = ggml_nn_conv_2d(ctx, z, conv_1_w, NULL, 2, 2, 1, 1); // Block(64, 64), Block(64, 64), Block(64, 64) for (int i = 0; i < num_blocks; i++) { @@ -3529,7 +3428,7 @@ struct TinyEncoder { } // conv(64, 64, stride=2, bias=False) - z = ggml_conv_2d(ctx, conv_2_w, z, 2, 2, 1, 1, 1, 1); + z = ggml_nn_conv_2d(ctx, z, conv_2_w, NULL, 2, 2, 1, 1); // Block(64, 64), Block(64, 64), Block(64, 64) for (int i = 0; i < num_blocks; i++) { @@ -3537,7 +3436,7 @@ struct TinyEncoder { } // conv(64, 64, stride=2, bias=False) - z = ggml_conv_2d(ctx, conv_3_w, z, 2, 2, 1, 1, 1, 1); + z = ggml_nn_conv_2d(ctx, z, conv_3_w, NULL, 2, 2, 1, 1); // Block(64, 64), Block(64, 64), Block(64, 64) for (int i = 0; i < num_blocks; i++) { @@ -3545,8 +3444,7 @@ struct TinyEncoder { } // conv(64, 4) - z = ggml_conv_2d(ctx, conv_final_w, z, 1, 1, 1, 1, 1, 1); - z = ggml_add(ctx, z, ggml_reshape_4d(ctx, conv_final_b, 1, 1, conv_final_b->ne[0], 1)); + z = ggml_nn_conv_2d(ctx, z, conv_final_w, conv_final_b, 1, 1, 1, 1); return z; } }; @@ -3694,8 +3592,7 @@ struct TinyDecoder { h = ggml_scale(ctx, h, in_scale_3); // conv(4, 64) - h = ggml_conv_2d(ctx, conv_input_w, h, 1, 1, 1, 1, 1, 1); - h = ggml_add(ctx, h, ggml_reshape_4d(ctx, conv_input_b, 1, 1, conv_input_b->ne[0], 1)); + h = ggml_nn_conv_2d(ctx, h, conv_input_w, conv_input_b, 1, 1, 1, 1); // nn.ReLU() h = ggml_relu_inplace(ctx, h); @@ -3709,7 +3606,7 @@ struct TinyDecoder { h = ggml_upscale(ctx, h, 2); // conv(64, 64, bias=False) - h = ggml_conv_2d(ctx, conv_1_w, h, 1, 1, 1, 1, 1, 1); + h = ggml_nn_conv_2d(ctx, h, conv_1_w, NULL, 1, 1, 1, 1); // Block(64, 64), Block(64, 64), Block(64, 64) for (int i = 0; i < num_blocks; i++) { @@ -3720,7 +3617,7 @@ struct TinyDecoder { h = ggml_upscale(ctx, h, 2); // conv(64, 64, bias=False) - h = ggml_conv_2d(ctx, conv_2_w, h, 1, 1, 1, 1, 1, 1); + h = ggml_nn_conv_2d(ctx, h, conv_2_w, NULL, 1, 1, 1, 1); // Block(64, 64), Block(64, 64), Block(64, 64) for (int i = 0; i < num_blocks; i++) { @@ -3731,14 +3628,13 @@ struct TinyDecoder { h = ggml_upscale(ctx, h, 2); // conv(64, 64, bias=False) - h = ggml_conv_2d(ctx, conv_3_w, h, 1, 1, 1, 1, 1, 1); + h = ggml_nn_conv_2d(ctx, h, conv_3_w, NULL, 1, 1, 1, 1); // Block(64, 64) h = final_block.forward(ctx, h); // conv(64, 3) - h = ggml_conv_2d(ctx, conv_final_w, h, 1, 1, 1, 1, 1, 1); - h = ggml_add(ctx, h, ggml_reshape_4d(ctx, conv_final_b, 1, 1, conv_final_b->ne[0], 1)); + h = ggml_nn_conv_2d(ctx, h, conv_final_w, conv_final_b, 1, 1, 1, 1); return h; } };