diff --git a/LICENSE.md b/LICENSE.md
new file mode 100644
index 00000000..261eeb9e
--- /dev/null
+++ b/LICENSE.md
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/README.md b/README.md
index ac37dd10..ce5c74ea 100644
--- a/README.md
+++ b/README.md
@@ -1,31 +1,49 @@
-# Kohya Trainer V4 Colab UI - VRAM 12GB
-### Best way to train Stable Diffusion model for peeps who didn't have good GPU
-
-Adapted to Google Colab based on [Kohya Guide](https://note.com/kohya_ss/n/nbf7ce8d80f29#c9d7ee61-5779-4436-b4e6-9053741c46bb)
-Adapted to Google Colab by [Linaqruf](https://github.com/Linaqruf)
-You can find latest notebook update [here](https://github.com/Linaqruf/kohya-trainer/blob/main/kohya-trainer.ipynb)
-
----
-## What is this?
----
-### **_Q: So what's differences between `Kohya Trainer` and other Stable Diffusion trainer out there?_**
-### A: **Kohya Trainer** have some new features like
-1. Using the U-Net learning
-2. Automatic captioning/tagging for every image automatically with BLIP/DeepDanbooru
-3. Implemented [NovelAI Aspect Ratio Bucketing Tool](https://github.com/NovelAI/novelai-aspect-ratio-bucketing) so you don't need to crop image dataset 512x512 ever again
-- Use the output of the second-to-last layer of CLIP (Text Encoder) instead of the last layer.
-- Learning at non-square resolutions (Aspect Ratio Bucketing) .
-- Extend token length from 75 to 225.
-4. By preparing a certain number of images (several hundred or more seems to be desirable), you can make learning even more flexible than with DreamBooth.
-5. It also support Hypernetwork learning
-6. `NEW!` 23/11 - Implemented Waifu Diffusion 1.4 Tagger for alternative DeepDanbooru to auto-tagging
-7.
-
-### **_Q: And what's differences between this notebook and other dreambooth notebook out there?_**
-### A: We're adding Quality of Life features such as:
-- Install **gallery-dl** to scrap images, so you can get your own dataset fast with google bandwidth
-- Huggingface Integration, here you can login to huggingface-hub and upload your trained model/dataset to huggingface
----
+# Kohya Trainer V6 - VRAM 12GB
+### The Best Way for People Without Good GPUs to Fine-Tune the Stable Diffusion Model
+
+This notebook has been adapted for use in Google Colab, based on the [Kohya Guide](https://note.com/kohya_ss/n/nbf7ce8d80f29#c9d7ee61-5779-4436-b4e6-9053741c46bb).
+The Colab adaptation is by [Linaqruf](https://github.com/Linaqruf).
+You can find the latest update to the notebook [here](https://github.com/Linaqruf/kohya-trainer/blob/main/kohya-trainer.ipynb).
+
+## Overview
+- Fine-tunes Stable Diffusion's U-Net using Diffusers.
+- Incorporates improvements from the NovelAI article, such as using the output of the penultimate layer of CLIP (the Text Encoder) instead of the last layer, and training at non-square resolutions with aspect ratio bucketing.
+- Extends the token length from 75 to 225 and offers automatic captioning and tagging with BLIP, DeepDanbooru, and WD14Tagger (see the dataset-preparation sketch below).
+- Supports hypernetwork learning and is compatible with Stable Diffusion v2.0 (base and 768/v).
+- Does not train the Text Encoder by default when fine-tuning the entire model, but an option to train it is available.
+- Makes learning even more flexible than DreamBooth when a certain number of images is prepared (several hundred or more seems to be desirable).
+
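+For orientation, the dataset-preparation flow in the Colab notebook roughly follows the sketch below. The commands mirror the notebook cells; the checkpoint path and the DeepDanbooru project directory are placeholder examples, and BLIP or WD14Tagger can be used in place of DeepDanbooru for tagging.
+```bash
+# Rough sketch of the dataset-preparation pipeline (paths are placeholders).
+# 1. Auto-tag every image in train_data and save the tags as .txt files.
+deepdanbooru evaluate train_data \
+  --project-path deepdanbooruv3 --allow-folder --save-txt
+# 2. Collect the generated tags into a metadata file.
+python merge_dd_tags_to_metadata.py train_data meta_cap_dd.json
+# 3. Optionally clean the captions/tags.
+python clean_captions_and_tags.py train_data meta_cap_dd.json meta_clean.json
+# 4. Pre-compute latents with aspect ratio bucketing (no manual 512x512 cropping needed).
+python prepare_buckets_latents.py train_data meta_cap_dd.json meta_lat.json checkpoint/model.ckpt \
+  --batch_size 4 --max_resolution 512,512 --mixed_precision no
+```
+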
+## Change Logs:
+##### v6 (6/12):
+- Temporary fix for an error when saving in the .safetensors format with some models. If you experienced this error with v5, please try v6.
+
+##### v5 (5/12):
+- Added support for the .safetensors format. Install safetensors with `pip install safetensors` and specify the `use_safetensors` option when saving.
+- Added the `log_prefix` option.
+- The cleaning script can now be used even when one of the captions or tags is missing.
+
+##### v4 (14/12):
+- The script name has changed to fine_tune.py.
+- Added the option `--train_text_encoder` to train the Text Encoder.
+- Added the option `--save_precision` to specify the data format of the saved checkpoint. Can be selected from float, fp16, or bf16.
+- Added the option `--save_state` to save the training state, including the optimizer state. Training can be resumed with the `--resume` option (see the example below).
+
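+As a rough illustration (not a prescribed configuration), the new v4 options slot into the usual `accelerate launch` call as shown below; the model, metadata, and output paths as well as the hyperparameter values are placeholders taken from the Colab defaults.
+```bash
+# Hedged example: how the v4 options fit into a fine_tune.py run.
+# --train_text_encoder  -> also train the Text Encoder (not trained by default)
+# --save_precision      -> data format of saved checkpoints (float / fp16 / bf16)
+# --save_state          -> save the training state (including optimizer); resume later with --resume <state dir>
+accelerate launch --num_cpu_threads_per_process 8 fine_tune.py \
+  --pretrained_model_name_or_path=checkpoint/model.ckpt \
+  --in_json meta_lat.json \
+  --train_data_dir=train_data \
+  --output_dir=fine_tuned \
+  --shuffle_caption --train_batch_size=1 --learning_rate=2e-6 \
+  --max_token_length=225 --clip_skip=2 \
+  --mixed_precision=fp16 --max_train_steps=5000 \
+  --use_8bit_adam --xformers --gradient_checkpointing \
+  --save_every_n_epochs=10 \
+  --train_text_encoder --save_precision=fp16 --save_state
+```
+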
+##### v3 (29/11):
+- Requires Diffusers 0.9.0. To update it, run `pip install -U diffusers[torch]==0.9.0`.
+- Supports Stable Diffusion v2.0. Use the `--v2` option when training (and when pre-acquiring latents). If you are using 768-v-ema.ckpt or stable-diffusion-2 instead of stable-diffusion-v2-base, also use the `--v_parameterization` option when training.
+- Added options to specify the minimum and maximum resolutions of the bucket when pre-acquiring latents.
+- Modified the loss calculation formula.
+- Added options for the learning rate scheduler.
+- Added support for downloading Diffusers models directly from Hugging Face and for saving during training.
+- The cleaning script can now be used even when only one of the captions or tags is missing.
+
+##### v2 (23/11):
+- Implemented the Waifu Diffusion 1.4 Tagger as an alternative to DeepDanbooru for auto-tagging.
+- Added a tagging script using WD14Tagger.
+- Fixed a bug that caused data to be shuffled twice.
+- Corrected spelling mistakes in the options for each script.
## Credit
[Kohya](https://twitter.com/kohya_ss) | Just for my part
+
diff --git a/colab_in_development/kohya-trainer-v3-deepdanbooru.ipynb b/colab_in_development/kohya-trainer-v3-deepdanbooru.ipynb
deleted file mode 100644
index 44fe87da..00000000
--- a/colab_in_development/kohya-trainer-v3-deepdanbooru.ipynb
+++ /dev/null
@@ -1,801 +0,0 @@
-{
- "nbformat": 4,
- "nbformat_minor": 0,
- "metadata": {
- "colab": {
- "provenance": [],
- "include_colab_link": true
- },
- "kernelspec": {
- "name": "python3",
- "display_name": "Python 3"
- },
- "language_info": {
- "name": "python"
- },
- "accelerator": "GPU",
- "gpuClass": "standard"
- },
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "view-in-github",
- "colab_type": "text"
- },
- "source": [
- ""
- ]
- },
- {
- "cell_type": "markdown",
- "source": [
- "#Kohya Trainer V3 - VRAM 12GB - DeepDanbooru\n",
- "###Best way to train Stable Diffusion model for peeps who didn't have good GPU"
- ],
- "metadata": {
- "id": "slgjeYgd6pWp"
- }
- },
- {
- "cell_type": "markdown",
- "source": [
- "Adapted to Google Colab based on [Kohya Guide](https://note.com/kohya_ss/n/nbf7ce8d80f29#c9d7ee61-5779-4436-b4e6-9053741c46bb)\n",
- "\n",
- "Adapted to Google Colab by [Linaqruf](https://github.com/Linaqruf)\n",
- "\n",
- "You can find latest notebook update [here](https://github.com/Linaqruf/kohya-trainer/blob/main/kohya-trainer-v3-deepdanbooru.ipynb)\n",
- "\n",
- "\n"
- ],
- "metadata": {
- "id": "gPgBR3KM6E-Z"
- }
- },
- {
- "cell_type": "markdown",
- "source": [
- "## What is this?\n"
- ],
- "metadata": {
- "id": "v3Qxv-rCXshE"
- }
- },
- {
- "cell_type": "markdown",
- "source": [
- "---\n",
- "#####**_Q: So what's differences between `Kohya Trainer` and other diffusers out there?_**\n",
- "#####A: **Kohya Trainer** have some new features like\n",
- "1. Using the U-Net learning\n",
- "2. Automatic captioning/tagging for every image automatically with BLIP/DeepDanbooru\n",
- "3. Read all captions/tags created and put them in metadata.json\n",
- "4. Implemented [NovelAI Aspect Ratio Bucketing Tool](https://github.com/NovelAI/novelai-aspect-ratio-bucketing) so you don't need to crop image dataset 512x512 ever again\n",
- "- Use the output of the second-to-last layer of CLIP (Text Encoder) instead of the last layer.\n",
- "- Learning at non-square resolutions (Aspect Ratio Bucketing) .\n",
- "- Extend token length from 75 to 225.\n",
- "5. By preparing a certain number of images (several hundred or more seems to be desirable), you can make learning even more flexible than with DreamBooth.\n",
- "6. It also support Hypernetwork learning\n",
- "7. `NEW!` Implemented Waifu Diffusion 1.4 Tagger for alternative DeepDanbooru to auto-tagging.\n",
- "\n",
- "#####**_Q: And what's differences between this notebook and other dreambooth notebook out there?_**\n",
- "#####A: We're adding Quality of Life features such as:\n",
- "- Install **gallery-dl** to scrap images, so you can get your own dataset fast with google bandwidth\n",
- "- Huggingface Integration, here you can login to huggingface-hub and upload your trained model/dataset to huggingface\n",
- "---"
- ],
- "metadata": {
- "id": "gSSojWxg7cFP"
- }
- },
- {
- "cell_type": "markdown",
- "source": [
- "#Install Dependencies"
- ],
- "metadata": {
- "id": "h3AuTNu6MFZk"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "#@title Install Diffuser\n",
- "%cd /content/\n",
- "!pip install --upgrade pip\n",
- "!pip install diffusers[torch]==0.7.2"
- ],
- "metadata": {
- "cellView": "form",
- "id": "Aq5cjtG5nJ3Y"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "code",
- "source": [
- "#@title Install Xformers (T4)\n",
- "%cd /content/\n",
- "!git clone https://github.com/openai/triton.git\n",
- "\n",
- "# Install Triton\n",
- "%cd /content/triton/python\n",
- "!pip install -e .\n",
- "\n",
- "# Install Xformers\n",
- "%pip install -qq https://github.com/camenduru/stable-diffusion-webui-colab/releases/download/0.0.14/xformers-0.0.14.dev0-cp37-cp37m-linux_x86_64.whl"
- ],
- "metadata": {
- "cellView": "form",
- "id": "Q_DPyXcDqv8J"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "source": [
- "#Install Kohya Trainer V3"
- ],
- "metadata": {
- "id": "tTVqCAgSmie4"
- }
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "cellView": "form",
- "id": "_u3q60di584x"
- },
- "outputs": [],
- "source": [
- "#@title Cloning Kohya Trainer V3\n",
- "%cd /content/\n",
- "!git clone https://github.com/Linaqruf/kohya-trainer"
- ]
- },
- {
- "cell_type": "code",
- "source": [
- "#@title Install Kohya Trainer V3 Requirement\n",
- "%cd /content/kohya-trainer\n",
- "!pip install -r requirements.txt"
- ],
- "metadata": {
- "cellView": "form",
- "id": "WNn0g1pnHfk5"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "source": [
- "#Danbooru Scraper"
- ],
- "metadata": {
- "id": "En9UUwGNMRMM"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "#@title Install gallery-dl library\n",
- "!pip install -U gallery-dl"
- ],
- "metadata": {
- "id": "dBi4pk7hy-Jg",
- "cellView": "form"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "code",
- "source": [
- "#@title Danbooru Scraper\n",
- "#@markdown **How this work?**\n",
- "\n",
- "#@markdown By using **gallery-dl** we can scrap or bulk download images on Internet, on this notebook we will scrap images from Danbooru using tag1 and tag2 as target scraping.\n",
- "%cd /content/kohya-trainer\n",
- "\n",
- "tag = \"hito_komoru \" #@param {type: \"string\"}\n",
- "tag2 = \"\" #@param {type: \"string\"}\n",
- "output_dir = \"/content/kohya-trainer/train_data\" \n",
- "\n",
- "if tag2 is not \"\":\n",
- " tag = tag + \"+\" + tag2\n",
- "else:\n",
- " tag = tag\n",
- "\n",
- "def danbooru_dl():\n",
- " !gallery-dl \"https://danbooru.donmai.us/posts?tags={tag}+&z=5\" -D {output_dir}\n",
- "\n",
- "danbooru_dl()\n",
- "\n",
- "#@markdown The output directory will be on /content/kohya-trainer/train_data. We also will use this folder as target folder for training next step.\n",
- "\n"
- ],
- "metadata": {
- "cellView": "form",
- "id": "Kt1GzntK_apb"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "source": [
- "#DeepDanbooru3 for Autotagger\n",
- "We will skip BLIP Captioning section and only used DeepDanbooru for Autotagging.\n",
- "\n",
- "If you still want to use BLIP, please refer to the original article [here](https://note.com/kohya_ss/n/nbf7ce8d80f29#c9d7ee61-5779-4436-b4e6-9053741c46bb)"
- ],
- "metadata": {
- "id": "cSbB9CeqMwbF"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "#@title Install DeepDanbooru\n",
- "%cd /content/\n",
- "!git clone https://github.com/KichangKim/DeepDanbooru kohya-trainer/deepdanbooru\n",
- "\n",
- "%cd /content/kohya-trainer/deepdanbooru\n",
- "!pip install -r requirements.txt\n",
- "!pip install ."
- ],
- "metadata": {
- "cellView": "form",
- "id": "AsLO2-REM8Yd"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "code",
- "source": [
- "#@title Install DeepDanbooru3 Model Weight\n",
- "%cd /content/kohya-trainer/deepdanbooru\n",
- "!wget -c https://github.com/KichangKim/deepdanbooru/releases/download/v3-20211112-sgd-e28/deepdanbooru-v3-20211112-sgd-e28.zip -O deepdanbooruv3.zip\n",
- "!mkdir deepdanbooruv3\n",
- "!mv deepdanbooruv3.zip deepdanbooruv3"
- ],
- "metadata": {
- "cellView": "form",
- "id": "p8Y1SWWwUO26"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "code",
- "source": [
- "#@title Unzip DeepDanbooru3 Model\n",
- "%cd /content/kohya-trainer/deepdanbooru/deepdanbooruv3\n",
- "!unzip deepdanbooruv3.zip \n",
- "!rm -rf deepdanbooruv3.zip"
- ],
- "metadata": {
- "cellView": "form",
- "id": "4H5vSQnFXhTO"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "code",
- "source": [
- "#@title Batch Tag Interrogating and save it as (.txt)\n",
- "%cd /content/kohya-trainer/deepdanbooru/deepdanbooruv3\n",
- "!deepdanbooru evaluate /content/kohya-trainer/train_data \\\n",
- " --project-path /content/kohya-trainer/deepdanbooru/deepdanbooruv3 \\\n",
- " --allow-folder \\\n",
- " --save-txt"
- ],
- "metadata": {
- "cellView": "form",
- "id": "hibZK5NPTjZQ"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "code",
- "source": [
- "#@title Create Metadata from tags collected\n",
- "%cd /content/kohya-trainer\n",
- "!python merge_dd_tags_to_metadata.py train_data meta_cap_dd.json"
- ],
- "metadata": {
- "cellView": "form",
- "id": "hz2Cmlf2ay9w"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "code",
- "source": [
- "#@title Clean Metadata.json (not sure it works)\n",
- "%cd /content/kohya-trainer\n",
- "!python clean_captions_and_tags.py train_data meta_cap_dd.json meta_clean.json"
- ],
- "metadata": {
- "cellView": "form",
- "id": "WFq28pPWjLpP"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "source": [
- "#Preparing Checkpoint"
- ],
- "metadata": {
- "id": "3gob9_OwTlwh"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "#@title Install Checkpoint\n",
- "%cd /content/kohya-trainer\n",
- "!mkdir checkpoint\n",
- "#@title Download Available Checkpoint\n",
- "\n",
- "def huggingface_checkpoint(url, checkpoint_name):\n",
- " #@markdown Insert your Huggingface token below\n",
- " user_token = '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' #@param {'type': 'string'}\n",
- " user_header = f\"\\\"Authorization: Bearer {user_token}\\\"\"\n",
- " !wget -c --header={user_header} {url} -O /content/kohya-trainer/checkpoint/{checkpoint_name}.ckpt\n",
- "\n",
- "def custom_checkpoint(url, checkpoint_name):\n",
- " !wget {url} -O /checkpoint/{checkpoint_name}.ckpt\n",
- "\n",
- "def install_checkpoint():\n",
- " #@markdown Choose the models you want:\n",
- " Animefull_Final_Pruned= False #@param {'type':'boolean'}\n",
- " Waifu_Diffusion_V1_3 = False #@param {'type':'boolean'}\n",
- " Anything_V3_0_Pruned = True #@param {'type':'boolean'}\n",
- "\n",
- " if Animefull_Final_Pruned:\n",
- " huggingface_checkpoint(\"https://huggingface.co/Linaqruf/personal_backup/resolve/main/animeckpt/model-pruned.ckpt\", \"Animefull_Final_Pruned\")\n",
- " if Waifu_Diffusion_V1_3:\n",
- " huggingface_checkpoint(\"https://huggingface.co/hakurei/waifu-diffusion-v1-3/resolve/main/wd-v1-3-float32.ckpt\", \"Waifu_Diffusion_V1_3\")\n",
- " if Anything_V3_0_Pruned:\n",
- " huggingface_checkpoint(\"https://huggingface.co/Linaqruf/anything-v3.0/resolve/main/Anything-V3.0-pruned.ckpt\", \"Anything_V3_0_Pruned\")\n",
- "\n",
- "install_checkpoint()"
- ],
- "metadata": {
- "cellView": "form",
- "id": "SoucgZQ6jgPQ"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "code",
- "source": [
- "#@title Download Custom Checkpoint\n",
- "#@markdown If your checkpoint aren't provided on the cell above, you can insert your own here.\n",
- "\n",
- "ckptName = \"\" #@param {'type': 'string'}\n",
- "ckptURL = \"\" #@param {'type': 'string'}\n",
- "\n",
- "def custom_checkpoint(url, name):\n",
- " !wget -c {url} -O /content/kohya-trainer/{name}.ckpt\n",
- "\n",
- "def install_checkpoint():\n",
- " if ckptName and ckptURL is not \"\" :\n",
- " custom_checkpoint(ckptName, ckptURL)\n",
- "\n",
- "install_checkpoint()"
- ],
- "metadata": {
- "id": "vrQ3_jbFTrgL",
- "cellView": "form"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "source": [
- "#Prepare Training"
- ],
- "metadata": {
- "id": "15xUbLvQNN28"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "#@title NovelAI Aspect Ratio Bucketing Script\n",
- "%cd /content/kohya-trainer\n",
- "\n",
- "model_dir= \"/content/kohya-trainer/checkpoint/Anything_V3_0_Pruned.ckpt\" #@param {'type' : 'string'} \n",
- "\n",
- "!python prepare_buckets_latents.py train_data meta_cap_dd.json meta_lat.json {model_dir} \\\n",
- " --batch_size 4 \\\n",
- " --max_resolution 512,512 \\\n",
- " --mixed_precision no"
- ],
- "metadata": {
- "cellView": "form",
- "id": "hhgatqF3leHJ"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "code",
- "source": [
- "#@title Set config for Accelerate\n",
- "#@markdown #Hint\n",
- "\n",
- "#@markdown 1. **In which compute environment are you running?** ([0] This machine, [1] AWS (Amazon SageMaker)): `0`\n",
- "#@markdown 2. **Which type of machine are you using?** ([0] No distributed training, [1] multi-CPU, [2] multi-GPU, [3] TPU [4] MPS): `0`\n",
- "#@markdown 3. **Do you want to run your training on CPU only (even if a GPU is available)?** [yes/NO]: `NO`\n",
- "#@markdown 4. **Do you want to use DeepSpeed?** [yes/NO]: `NO`\n",
- "#@markdown 5. **What GPU(s) (by id) should be used for training on this machine as a comma-seperated list?** [all] = `all`\n",
- "#@markdown 6. **Do you wish to use FP16 or BF16 (mixed precision)?** [NO/fp16/bf16]: `fp16`\n",
- "!accelerate config"
- ],
- "metadata": {
- "cellView": "form",
- "id": "RnjHb4wgD7vu"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "source": [
- "# Start Training\n",
- "\n"
- ],
- "metadata": {
- "id": "yHNbl3O_NSS0"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "#@title Training begin\n",
- "num_cpu_threads_per_process = 8 #@param {'type':'integer'}\n",
- "model_path =\"/content/kohya-trainer/checkpoint/Anything_V3_0_Pruned.ckpt\" #@param {'type':'string'}\n",
- "output_dir =\"/content/kohya-trainer/fine_tuned\" #@param {'type':'string'}\n",
- "train_batch_size = 1 #@param {type: \"slider\", min: 1, max: 10}\n",
- "learning_rate =\"2e-6\" #@param {'type':'string'}\n",
- "max_token_length = 225 #@param {'type':'integer'}\n",
- "clip_skip = 2 #@param {type: \"slider\", min: 1, max: 10}\n",
- "mixed_precision = \"fp16\" #@param [\"fp16\", \"bp16\"] {allow-input: false}\n",
- "max_train_steps = 5000 #@param {'type':'integer'}\n",
- "# save_precision = \"fp16\" #@param [\"fp16\", \"bp16\", \"float\"] {allow-input: false}\n",
- "save_every_n_epochs = 10 #@param {'type':'integer'}\n",
- "# gradient_accumulation_steps = 1 #@param {type: \"slider\", min: 1, max: 10}\n",
- "\n",
- "%cd /content/kohya-trainer\n",
- "!accelerate launch --num_cpu_threads_per_process {num_cpu_threads_per_process} fine_tune.py \\\n",
- " --pretrained_model_name_or_path={model_path} \\\n",
- " --in_json meta_lat.json \\\n",
- " --train_data_dir=train_data \\\n",
- " --output_dir={output_dir} \\\n",
- " --shuffle_caption \\\n",
- " --train_batch_size={train_batch_size} \\\n",
- " --learning_rate={learning_rate} \\\n",
- " --max_token_length={max_token_length} \\\n",
- " --clip_skip={clip_skip} \\\n",
- " --mixed_precision={mixed_precision} \\\n",
- " --max_train_steps={max_train_steps} \\\n",
- " --use_8bit_adam \\\n",
- " --xformers \\\n",
- " --gradient_checkpointing \\\n",
- " --save_every_n_epochs={save_every_n_epochs} \\\n",
- " --save_state #For Resume Training\n",
- " # --gradient_accumulation_steps {gradient_accumulation_steps} \\\n",
- " # --resume /content/kohya-trainer/checkpoint/last-state \\\n",
- " # --save_precision={save_precision} \\\n",
- "\n"
- ],
- "metadata": {
- "id": "X_Rd3Eh07xlA",
- "cellView": "form"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "source": [
- "#Miscellaneous"
- ],
- "metadata": {
- "id": "vqfgyL-thgdw"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "#@title Model Pruner\n",
- "#@markdown Do you want to Pruning model?\n",
- "\n",
- "prune = False #@param {'type':'boolean'}\n",
- "\n",
- "model_path = \"betabeet_5000_steps_2e-6.ckpt\" #@param {'type' : 'string'}\n",
- "if prune == True:\n",
- " import os\n",
- " if os.path.isfile('/content/prune-ckpt.py'):\n",
- " print(\"This folder already exists, will do a !git pull instead\\n\")\n",
- " \n",
- " else:\n",
- " !wget https://raw.githubusercontent.com/prettydeep/Dreambooth-SD-ckpt-pruning/main/prune-ckpt.py\n",
- "\n",
- "\n",
- " !python prune-ckpt.py --ckpt {model_path}\n",
- "\n"
- ],
- "metadata": {
- "cellView": "form",
- "id": "LUOG7BzQVLKp"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "code",
- "source": [
- "#@title Mount to Google Drive\n",
- "mount_drive= False #@param {'type':'boolean'}\n",
- "\n",
- "if mount_drive== True:\n",
- " from google.colab import drive\n",
- " drive.mount('/content/drive')"
- ],
- "metadata": {
- "id": "OuRqOSp2eU6t",
- "cellView": "form"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "source": [
- "#Huggingface_hub Integration"
- ],
- "metadata": {
- "id": "QtVP2le8PL2T"
- }
- },
- {
- "cell_type": "markdown",
- "source": [
- "##Instruction:\n",
- "0. Of course you need a Huggingface Account first\n",
- "1. Create your huggingface model repository\n",
- "2. Create huggingface token, go to `Profile > Access Tokens > New Token > Create a new access token` with the `Write` role.\n",
- "3. All cells below are checked `opt-out` by default so you need to uncheck it if you want to running the cells."
- ],
- "metadata": {
- "id": "tbKgmh_AO5NG"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "#@title Login to Huggingface hub\n",
- "#@markdown Opt-out this cell when run all\n",
- "from IPython.core.display import HTML\n",
- "\n",
- "opt_out= True #@param {'type':'boolean'}\n",
- "\n",
- "#@markdown Prepare your Huggingface token\n",
- "\n",
- "saved_token= \"\" #@param {'type': 'string'}\n",
- "\n",
- "if opt_out == False:\n",
- " from huggingface_hub import notebook_login\n",
- " notebook_login()\n",
- "\n",
- "else:\n",
-    "  display(HTML(f\"This cell will not running because you choose to opt-out this cell.\"))"
- ],
- "metadata": {
- "cellView": "form",
- "id": "Da7awoqAPJ3a"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "source": [
- "##Commit trained model to Huggingface"
- ],
- "metadata": {
- "id": "jypUkLWc48R_"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "#@title Clone Model\n",
- "from IPython.core.display import HTML\n",
- "\n",
- "#@markdown Opt-out this cell when run all\n",
- "opt_out= True #@param {'type':'boolean'}\n",
- "\n",
- "\n",
- "if opt_out == False:\n",
- " !pip install huggingface_hub\n",
- "\n",
- " %cd /content\n",
- "\n",
- " from huggingface_hub import notebook_login\n",
- "\n",
- " notebook_login()\n",
- "\n",
- " Repository_url = \"https://huggingface.co/Linaqruf/hitokomoru\" #@param {'type': 'string'}\n",
- " !git clone {Repository_url}\n",
- "\n",
- "else:\n",
- " display(HTML(f\"This cell will not running because you choose to opt-out this cell.\"))"
- ],
- "metadata": {
- "cellView": "form",
- "id": "182Law9oUiYN"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "code",
- "source": [
- "#@title Commit to Huggingface\n",
- "#@markdown Opt-out this cell when run all\n",
- "from IPython.core.display import HTML\n",
- "\n",
- "opt_out= True #@param {'type':'boolean'}\n",
- "\n",
- "if opt_out == False:\n",
- " %cd /content\n",
- " #@markdown Go to your model path\n",
- " model_path= \"hitokomoru\" #@param {'type': 'string'}\n",
- "\n",
- " #@markdown Your path look like /content/**model_path**\n",
- " #@markdown ___\n",
- " #@markdown #Git Commit\n",
- "\n",
- " #@markdown Set **git commit identity**\n",
- "\n",
- " email= \"your-email\" #@param {'type': 'string'}\n",
- " name= \"your-username\" #@param {'type': 'string'}\n",
- " #@markdown Set **commit message**\n",
- " commit_m= \"Push: hitokomoru-5000\" #@param {'type': 'string'}\n",
- "\n",
- " %cd \"/content/{model_path}\"\n",
- " !git lfs install\n",
- " !huggingface-cli lfs-enable-largefiles .\n",
- " !git add .\n",
- " !git lfs help smudge\n",
- " !git config --global user.email \"{email}\"\n",
- " !git config --global user.name \"{name}\"\n",
- " !git commit -m \"{commit_m}\"\n",
- " !git push\n",
- "\n",
- "else:\n",
- " display(HTML(f\"This cell will not running because you choose to opt-out this cell.\"))\n"
- ],
- "metadata": {
- "cellView": "form",
- "id": "87wG7QIZbtZE"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "source": [
- "##Commit dataset to huggingface"
- ],
- "metadata": {
- "id": "olP2yaK3OKcr"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "#@title Zip train_data\n",
- "\n",
- "%cd /content\n",
- "!zip -r /content/train_data /content/kohya-trainer/train_data"
- ],
- "metadata": {
- "id": "BZ8Nrx4-hoQQ",
- "cellView": "form"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "code",
- "source": [
- "#@title Clone Dataset\n",
- "from IPython.core.display import HTML\n",
- "\n",
- "#@markdown Opt-out this cell when run all\n",
- "opt_out= True #@param {'type':'boolean'}\n",
- "\n",
- "\n",
- "if opt_out == False:\n",
- " !pip install huggingface_hub\n",
- "\n",
- " %cd /content\n",
- "\n",
- " Repository_url = \"https://huggingface.co/datasets/Linaqruf/hitokomoru-tag\" #@param {'type': 'string'}\n",
- " !git clone {Repository_url}\n",
- "\n",
- "else:\n",
- " display(HTML(f\"This cell will not running because you choose to opt-out this cell.\"))"
- ],
- "metadata": {
- "cellView": "form",
- "id": "QhL6UgqDOURK"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "code",
- "source": [
- "#@title Commit to Huggingface\n",
- "#@markdown Opt-out this cell when run all\n",
- "from IPython.core.display import HTML\n",
- "\n",
- "opt_out= True #@param {'type':'boolean'}\n",
- "\n",
- "if opt_out == False:\n",
- " %cd /content\n",
- " #@markdown Go to your model path\n",
- " dataset_path= \"hitokomoru-tag\" #@param {'type': 'string'}\n",
- "\n",
- " #@markdown Your path look like /content/**dataset_path**\n",
- " #@markdown ___\n",
- " #@markdown #Git Commit\n",
- "\n",
- " #@markdown Set **git commit identity**\n",
- "\n",
- " email= \"your-email\" #@param {'type': 'string'}\n",
- " name= \"your-name\" #@param {'type': 'string'}\n",
- " #@markdown Set **commit message**\n",
- " commit_m= \"Push: hitokomoru-tag\" #@param {'type': 'string'}\n",
- "\n",
- " %cd \"/content/{dataset_path}\"\n",
- " !git lfs install\n",
- " !huggingface-cli lfs-enable-largefiles .\n",
- " !git add .\n",
- " !git lfs help smudge\n",
- " !git config --global user.email \"{email}\"\n",
- " !git config --global user.name \"{name}\"\n",
- " !git commit -m \"{commit_m}\"\n",
- " !git push\n",
- "\n",
- "else:\n",
- " display(HTML(f\"This cell will not running because you choose to opt-out this cell.\"))\n"
- ],
- "metadata": {
- "cellView": "form",
- "id": "abHLg4I0Os5T"
- },
- "execution_count": null,
- "outputs": []
- }
- ]
-}
\ No newline at end of file
diff --git a/colab_in_development/kohya-trainer-v3-for-resume-training.ipynb b/colab_in_development/kohya-trainer-v3-for-resume-training.ipynb
deleted file mode 100644
index 4faadb72..00000000
--- a/colab_in_development/kohya-trainer-v3-for-resume-training.ipynb
+++ /dev/null
@@ -1,682 +0,0 @@
-{
- "nbformat": 4,
- "nbformat_minor": 0,
- "metadata": {
- "colab": {
- "provenance": [],
- "include_colab_link": true
- },
- "kernelspec": {
- "name": "python3",
- "display_name": "Python 3"
- },
- "language_info": {
- "name": "python"
- },
- "gpuClass": "standard",
- "accelerator": "GPU"
- },
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "view-in-github",
- "colab_type": "text"
- },
- "source": [
- ""
- ]
- },
- {
- "cell_type": "markdown",
- "source": [
- "#Kohya Trainer V3 Resume Notebook\n",
- "###Notebook for resuming your latest training using [main notebook](https://colab.research.google.com/github/Linaqruf/kohya-trainer/blob/main/kohya-trainer-v3-stable.ipynb)"
- ],
- "metadata": {
- "id": "nhfqQvcc_Wur"
- }
- },
- {
- "cell_type": "markdown",
- "source": [
- "Adapted to Google Colab based on [Kohya Guide](https://note.com/kohya_ss/n/nbf7ce8d80f29#c9d7ee61-5779-4436-b4e6-9053741c46bb)\n",
- "\n",
- "Adapted to Google Colab by [Linaqruf](https://github.com/Linaqruf)\n",
- "\n",
- "You can find latest notebook update [here](https://github.com/Linaqruf/kohya-trainer/blob/main/notebook_variant/kohya-trainer-v3-for-resume-training.ipynb)\n",
- "\n",
- "\n"
- ],
- "metadata": {
- "id": "GJZuQhj8_pid"
- }
- },
- {
- "cell_type": "markdown",
- "source": [
- "#Install Dependencies"
- ],
- "metadata": {
- "id": "h3AuTNu6MFZk"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "#@title Install Diffuser\n",
- "%cd /content/\n",
- "!pip install --upgrade pip\n",
- "!pip install diffusers[torch]==0.7.2"
- ],
- "metadata": {
- "cellView": "form",
- "id": "Aq5cjtG5nJ3Y"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "code",
- "source": [
- "#@title Install xformers\n",
- "\n",
- "from IPython.display import clear_output\n",
- "import time\n",
- "from IPython.display import HTML\n",
- "from subprocess import getoutput\n",
- "import os\n",
- "\n",
- "s = getoutput('nvidia-smi')\n",
- "\n",
- "if 'T4' in s:\n",
- " gpu = 'T4'\n",
- "elif 'P100' in s:\n",
- " gpu = 'P100'\n",
- "elif 'V100' in s:\n",
- " gpu = 'V100'\n",
- "elif 'A100' in s:\n",
- " gpu = 'A100'\n",
- "\n",
- "if (gpu=='T4'):\n",
- " %pip install -qq https://github.com/camenduru/stable-diffusion-webui-colab/releases/download/0.0.14/xformers-0.0.14.dev0-cp37-cp37m-linux_x86_64.whl\n",
- "elif (gpu=='P100'):\n",
- " %pip install -q https://github.com/TheLastBen/fast-stable-diffusion/raw/main/precompiled/P100/xformers-0.0.13.dev0-py3-none-any.whl\n",
- "elif (gpu=='V100'):\n",
- " %pip install -q https://github.com/TheLastBen/fast-stable-diffusion/raw/main/precompiled/V100/xformers-0.0.13.dev0-py3-none-any.whl\n",
- "elif (gpu=='A100'):\n",
- " %pip install -q https://github.com/TheLastBen/fast-stable-diffusion/raw/main/precompiled/A100/xformers-0.0.13.dev0-py3-none-any.whl"
- ],
- "metadata": {
- "cellView": "form",
- "id": "Q_DPyXcDqv8J"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "source": [
- "#Install Kohya Trainer V3"
- ],
- "metadata": {
- "id": "tTVqCAgSmie4"
- }
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "cellView": "form",
- "id": "_u3q60di584x"
- },
- "outputs": [],
- "source": [
- "#@title Cloning Kohya Trainer v3\n",
- "%cd /content/\n",
- "!git clone https://github.com/Linaqruf/kohya-trainer"
- ]
- },
- {
- "cell_type": "code",
- "source": [
- "#@title Install Kohya Trainer v3 Requirement\n",
- "%cd /content/kohya-trainer\n",
- "!pip install -r requirements.txt"
- ],
- "metadata": {
- "cellView": "form",
- "id": "WNn0g1pnHfk5"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "source": [
- "#Preparing Datasets and Save State Folder\n"
- ],
- "metadata": {
- "id": "pvtdNJ9nerfp"
- }
- },
- {
- "cell_type": "markdown",
- "source": [
- "Make sure you saved your datasets on Huggingface. Your datasets on huggingface should be saved these necessary files:\n",
- "- Folder `last-state`\n",
- "- Folder `train_data`\n",
- "- File `meta_lat.json`\n",
- "\n"
- ],
- "metadata": {
- "id": "mnxoqiZa8fub"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "#@title Clone Dataset\n",
- "%cd /content\n",
- "\n",
- "dataset_url = \"https://huggingface.co/datasets/Linaqruf/hitokomoru-tag\" #@param {'type': 'string'}\n",
- "!git lfs install\n",
- "!git clone {dataset_url}\n",
- "\n"
- ],
- "metadata": {
- "id": "NwBkqbYRepXL",
- "cellView": "form"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "source": [
- "#Preparing Pre-trained Model"
- ],
- "metadata": {
- "id": "3gob9_OwTlwh"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "#@title Install Pre-trained Model \n",
- "%cd /content/kohya-trainer\n",
- "!mkdir checkpoint\n",
- "\n",
- "#@title Install Pre-trained Model \n",
- "\n",
- "installModels=[]\n",
- "\n",
- "\n",
- "#@markdown ### Available Model\n",
- "#@markdown Select one of available pretrained model to download:\n",
- "modelUrl = [\"\", \\\n",
- " \"https://huggingface.co/Linaqruf/personal_backup/resolve/main/animeckpt/model-pruned.ckpt\", \\\n",
- " \"https://huggingface.co/Linaqruf/personal_backup/resolve/main/animeckpt/modelsfw-pruned.ckpt\", \\\n",
- " \"https://huggingface.co/Linaqruf/anything-v3.0/resolve/main/Anything-V3.0-pruned-fp16.ckpt\", \\\n",
- " \"https://huggingface.co/Linaqruf/anything-v3.0/resolve/main/Anything-V3.0-pruned-fp32.ckpt\", \\\n",
- " \"https://huggingface.co/Linaqruf/anything-v3.0/resolve/main/Anything-V3.0-pruned.ckpt\", \\\n",
- " \"https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt\" \\\n",
- " \"https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.ckpt\", \\\n",
- " \"https://huggingface.co/hakurei/waifu-diffusion-v1-3/resolve/main/wd-v1-3-float32.ckpt\"]\n",
- "modelList = [\"\", \\\n",
- " \"Animefull-final-pruned\", \\\n",
- " \"Animesfw-final-pruned\", \\\n",
- " \"Anything-V3.0-pruned-fp16\", \\\n",
- " \"Anything-V3.0-pruned-fp32\", \\\n",
- " \"Anything-V3.0-pruned\", \\\n",
- " \"Stable-Diffusion-v1-4\", \\\n",
- " \"Stable-Diffusion-v1-5-pruned-emaonly\" \\\n",
- " \"Waifu-Diffusion-v1-3-fp32\"]\n",
- "modelName = \"Anything-V3.0-pruned-fp32\" #@param [\"\", \"Animefull-final-pruned\", \"Animesfw-final-pruned\", \"Anything-V3.0-pruned-fp16\", \"Anything-V3.0-pruned-fp32\", \"Anything-V3.0-pruned\", \"Stable-Diffusion-v1-4\", \"Stable-Diffusion-v1-5-pruned-emaonly\", \"Waifu-Diffusion-v1-3-fp32\"]\n",
- "\n",
- "#@markdown ### Custom model\n",
- "#@markdown The model URL should be a direct download link.\n",
- "customName = \"\" #@param {'type': 'string'}\n",
- "customUrl = \"\"#@param {'type': 'string'}\n",
- "\n",
- "if customName == \"\" or customUrl == \"\":\n",
- " pass\n",
- "else:\n",
- " installModels.append((customName, customUrl))\n",
- "\n",
- "if modelName != \"\":\n",
- " # Map model to URL\n",
- " installModels.append((modelName, modelUrl[modelList.index(modelName)]))\n",
- "\n",
- "def install_aria():\n",
- " if not os.path.exists('/usr/bin/aria2c'):\n",
- " !apt install -y -qq aria2\n",
- "\n",
- "def install(checkpoint_name, url):\n",
- " if url.startswith(\"https://drive.google.com\"):\n",
- " !gdown --fuzzy -O \"/content/kohya-trainer/checkpoint/{checkpoint_name}.ckpt\" \"{url}\"\n",
- " elif url.startswith(\"magnet:?\"):\n",
- " install_aria()\n",
- " !aria2c --summary-interval=10 -c -x 10 -k 1M -s 10 -o /content/kohya-trainer/checkpoint/{checkpoint_name}.ckpt \"{url}\"\n",
- " else:\n",
- " user_token = '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'\n",
- " user_header = f\"\\\"Authorization: Bearer {user_token}\\\"\"\n",
- " !wget -c --header={user_header} \"{url}\" -O /content/kohya-trainer/checkpoint/{checkpoint_name}.ckpt\n",
- "\n",
- "def install_checkpoint():\n",
- " for model in installModels:\n",
- " install(model[0], model[1])\n",
- "install_checkpoint()\n",
- "\n"
- ],
- "metadata": {
- "cellView": "form",
- "id": "SoucgZQ6jgPQ"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "source": [
- "#Prepare Training"
- ],
- "metadata": {
- "id": "15xUbLvQNN28"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "#@title Set config for `!Accelerate`\n",
- "#@markdown #Hint\n",
- "\n",
- "#@markdown 1. **In which compute environment are you running?** ([0] This machine, [1] AWS (Amazon SageMaker)): `0`\n",
- "#@markdown 2. **Which type of machine are you using?** ([0] No distributed training, [1] multi-CPU, [2] multi-GPU, [3] TPU [4] MPS): `0`\n",
- "#@markdown 3. **Do you want to run your training on CPU only (even if a GPU is available)?** [yes/NO]: `NO`\n",
- "#@markdown 4. **Do you want to use DeepSpeed?** [yes/NO]: `NO`\n",
- "#@markdown 5. **What GPU(s) (by id) should be used for training on this machine as a comma-seperated list?** [all] = `all`\n",
- "#@markdown 6. **Do you wish to use FP16 or BF16 (mixed precision)?** [NO/fp16/bf16]: `fp16`\n",
- "!accelerate config"
- ],
- "metadata": {
- "cellView": "form",
- "id": "RnjHb4wgD7vu"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "source": [
- "# Start Training"
- ],
- "metadata": {
- "id": "V6NhtBhk4kOZ"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "#@title Training begin\n",
- "num_cpu_threads_per_process = 8 #@param {'type':'integer'}\n",
- "pre_trained_model_path =\"/content/kohya-trainer/checkpoint/Anything_V3_0_Pruned.ckpt\" #@param {'type':'string'}\n",
- "meta_lat_json_dir = \"/content/granblue-fantasy-tag/meta_lat.json\" #@param {'type':'string'}\n",
- "train_data_dir = \"/content/granblue-fantasy-tag/train_data\" #@param {'type':'string'}\n",
- "output_dir =\"/content/kohya-trainer/fine_tuned\" #@param {'type':'string'}\n",
- "resume_path = \"/content/hitokomoru/last-state\" #@param {'type':'string'}\n",
- "train_batch_size = 1 #@param {type: \"slider\", min: 1, max: 10}\n",
- "learning_rate =\"2e-6\" #@param {'type':'string'}\n",
- "max_token_length = \"225\" #@param [\"150\", \"225\"] {allow-input: false}\n",
- "clip_skip = 2 #@param {type: \"slider\", min: 1, max: 10}\n",
- "mixed_precision = \"fp16\" #@param [\"fp16\", \"bf16\"] {allow-input: false}\n",
- "max_train_steps = 5000 #@param {'type':'integer'}\n",
- "# save_precision = \"fp16\" #@param [\"float\", \"fp16\", \"bf16\"] {allow-input: false}\n",
- "save_every_n_epochs = 10 #@param {'type':'integer'}\n",
- "gradient_accumulation_steps = 1 #@param {type: \"slider\", min: 1, max: 10}\n",
- "\n",
- "\n",
- "%cd /content/kohya-trainer\n",
- "!accelerate launch --num_cpu_threads_per_process {num_cpu_threads_per_process} fine_tune.py \\\n",
- " --pretrained_model_name_or_path={pre_trained_model_path} \\\n",
- " --in_json {meta_lat_json_dir} \\\n",
- " --train_data_dir={train_data_dir} \\\n",
- " --output_dir={output_dir} \\\n",
- " --shuffle_caption \\\n",
- " --train_batch_size={train_batch_size} \\\n",
- " --learning_rate={learning_rate} \\\n",
- " --max_token_length={max_token_length} \\\n",
- " --clip_skip={clip_skip} \\\n",
- " --mixed_precision={mixed_precision} \\\n",
- " --max_train_steps={max_train_steps} \\\n",
- " --use_8bit_adam \\\n",
- " --xformers \\\n",
- " --gradient_checkpointing \\\n",
- " --save_state \\\n",
- " --resume {resume_path} \\\n",
- " --gradient_accumulation_steps {gradient_accumulation_steps} \n",
- " # --save_precision={save_precision} \\\n"
- ],
- "metadata": {
- "id": "X_Rd3Eh07xlA",
- "cellView": "form"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "code",
- "source": [
- "#@title Model Pruner\n",
- "#@markdown Do you want to Pruning model?\n",
- "\n",
- "prune = False #@param {'type':'boolean'}\n",
- "\n",
- "model_path = \"/content/kohya-trainer/fine_tuned/last.ckpt\" #@param {'type' : 'string'}\n",
- "if prune == True:\n",
- " import os\n",
- " if os.path.isfile('/content/prune-ckpt.py'):\n",
- " pass\n",
- " else:\n",
- " !wget https://raw.githubusercontent.com/prettydeep/Dreambooth-SD-ckpt-pruning/main/prune-ckpt.py\n",
- "\n",
- "\n",
- " !python prune-ckpt.py --ckpt {model_path}\n",
- "\n"
- ],
- "metadata": {
- "cellView": "form",
- "id": "usmkmqEbgaRi"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "code",
- "source": [
- "#@title Mount to Google Drive\n",
- "mount_drive= False #@param {'type':'boolean'}\n",
- "\n",
- "if mount_drive== True:\n",
- " from google.colab import drive\n",
- " drive.mount('/content/drive')"
- ],
- "metadata": {
- "id": "OuRqOSp2eU6t",
- "cellView": "form"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "source": [
- "##Instruction:\n",
- "0. Of course you need a Huggingface Account first\n",
- "1. Create huggingface token, go to `Profile > Access Tokens > New Token > Create a new access token` with the `Write` role.\n",
- "2. All cells below are checked `opt-out` by default so you need to uncheck it if you want to running the cells."
- ],
- "metadata": {
- "id": "QtVP2le8PL2T"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "#@title Login to Huggingface hub\n",
- "#@markdown Opt-out this cell when run all\n",
- "from IPython.core.display import HTML\n",
- "\n",
- "opt_out= True #@param {'type':'boolean'}\n",
- "\n",
- "#@markdown Prepare your Huggingface token\n",
- "\n",
- "saved_token= \"save-your-write-token-here\" #@param {'type': 'string'}\n",
- "\n",
- "if opt_out == False:\n",
- " from huggingface_hub import notebook_login\n",
- " notebook_login()\n",
- "\n"
- ],
- "metadata": {
- "cellView": "form",
- "id": "Da7awoqAPJ3a"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "source": [
- "##Commit trained model to Huggingface"
- ],
- "metadata": {
- "id": "jypUkLWc48R_"
- }
- },
- {
- "cell_type": "markdown",
- "source": [
- "###Instruction:\n",
- "0. Create huggingface repository for model\n",
- "1. Clone your model to this colab session\n",
- "2. Move these necessary file to your repository to save your trained model to huggingface\n",
- "\n",
- ">in `content/kohya-trainer/fine-tuned`\n",
- "- File `epoch-nnnnn.ckpt` and/or\n",
- "- File `last.ckpt`, \n",
- "\n",
- "4. Commit your model to huggingface"
- ],
- "metadata": {
- "id": "nrulkTYg-JME"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "#@title Clone Model\n",
- "\n",
- "#@markdown Opt-out this cell when run all\n",
- "opt_out= True #@param {'type':'boolean'}\n",
- "\n",
- "if opt_out == False:\n",
- " %cd /content\n",
- " Repository_url = \"https://huggingface.co/Linaqruf/alphanime-diffusion\" #@param {'type': 'string'}\n",
- " !git clone {Repository_url}\n",
- "else:\n",
- " pass\n"
- ],
- "metadata": {
- "cellView": "form",
- "id": "182Law9oUiYN"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "code",
- "source": [
- "#@title `NEW!` Move trained model to cloned repository\n",
- "#@markdown Opt-out this cell when run all\n",
- "\n",
- "import shutil\n",
- "\n",
- "opt_out= True #@param {'type':'boolean'}\n",
- "#@markdown Fill necessary file/folder path in the textbox given below. You need to atleast already cloned models and/or datasets from huggingface.\n",
- "\n",
- "model_path = \"/content/granblue-fantasy\" #@param {'type' : 'string'}\n",
- "\n",
- "%cd /content/kohya-trainer\n",
- "#model\n",
- "last_pruned_ckpt = \"\" #@param {'type' : 'string'}\n",
- "last_ckpt = \"/content/kohya-trainer/fine_tuned/last.ckpt\" #@param {'type' : 'string'}\n",
- "\n",
- "if opt_out== False:\n",
- " if os.path.isfile(last_pruned_ckpt):\n",
- " shutil.move(last_pruned_ckpt,model_path)\n",
- " else:\n",
- " pass\n",
- "\n",
- " if os.path.isfile(last_ckpt):\n",
- " shutil.move(last_ckpt,model_path)\n",
- " else:\n",
- " pass\n",
- "\n",
- "else:\n",
- " pass\n"
- ],
- "metadata": {
- "cellView": "form",
- "id": "vCPfu6ss-QPT"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "code",
- "source": [
- "#@title Commit to Huggingface\n",
- "#@markdown Opt-out this cell when run all\n",
- "opt_out= True #@param {'type':'boolean'}\n",
- "\n",
- "if opt_out == False:\n",
- " %cd /content\n",
- " #@markdown Go to your model path\n",
- " model_path= \"alphanime-diffusion\" #@param {'type': 'string'}\n",
- "\n",
- " #@markdown Your path look like /content/**model_path**\n",
- " #@markdown ___\n",
- " #@markdown #Git Commit\n",
- "\n",
- " #@markdown Set **git commit identity**\n",
- "\n",
- " email= \"your-email\" #@param {'type': 'string'}\n",
- " name= \"your-name\" #@param {'type': 'string'}\n",
- " #@markdown Set **commit message**\n",
- " commit_m= \"this is commit message\" #@param {'type': 'string'}\n",
- "\n",
- " %cd \"/content/{model_path}\"\n",
- " !git lfs install\n",
- " !huggingface-cli lfs-enable-largefiles .\n",
- " !git add .\n",
- " !git lfs help smudge\n",
- " !git config --global user.email \"{email}\"\n",
- " !git config --global user.name \"{name}\"\n",
- " !git commit -m \"{commit_m}\"\n",
- " !git push\n",
- "\n",
- "else:\n",
- " pass"
- ],
- "metadata": {
- "cellView": "form",
- "id": "87wG7QIZbtZE"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "source": [
- "##Commit dataset to huggingface"
- ],
- "metadata": {
- "id": "olP2yaK3OKcr"
- }
- },
- {
- "cell_type": "markdown",
- "source": [
- "###Instruction:\n",
- "\n",
- "Move these necessary file to your datasets so that you can do resume training next time without rebuild your dataset with this notebook\n",
- "\n",
- ">in `content/kohya-trainer/fine-tuned`\n",
- "- Folder `last-state`\n",
- "\n",
- "If the old `last-state` folder exists, you can delete it or rename it to something else, such as `hews-5000-state` so that you can remember your last step each time you want to continue training.\n"
- ],
- "metadata": {
- "id": "jiSb0z2CVtc_"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "#@title `NEW!` Move datasets to cloned repository\n",
- "import shutil\n",
- "\n",
- "%cd /content/kohya-trainer\n",
- "#@markdown Opt-out this cell when run all\n",
- "\n",
- "opt_out= True #@param {'type':'boolean'}\n",
- "#@markdown Fill necessary file/folder path in the textbox given below. You need to atleast already cloned models and/or datasets from huggingface.\n",
- "\n",
- "datasets_path = \"/content/granblue-fantasy-tag\" #@param {'type' : 'string'}\n",
- "\n",
- "#datasets\n",
- "save_state_dir = \"/content/kohya-trainer/fine_tuned/last-state\" #@param {'type' : 'string'}\n",
- "\n",
- "if opt_out == False:\n",
- " if os.path.isdir(save_state_dir):\n",
- " shutil.move(save_state_dir,datasets_path)\n",
- " else:\n",
- " pass\n",
- "\n",
- "else:\n",
- " pass\n"
- ],
- "metadata": {
- "cellView": "form",
- "id": "Nkz2HoRYW3Ao"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "code",
- "source": [
- "#@title Commit to Huggingface\n",
- "#@markdown Opt-out this cell when run all\n",
- "\n",
- "opt_out= True #@param {'type':'boolean'}\n",
- "\n",
- "if opt_out == False:\n",
- " %cd /content\n",
- " #@markdown Go to your model path\n",
- " dataset_path= \"alphanime-diffusion-tag\" #@param {'type': 'string'}\n",
- "\n",
- " #@markdown Your path look like /content/**dataset_path**\n",
- " #@markdown ___\n",
- " #@markdown #Git Commit\n",
- "\n",
- " #@markdown Set **git commit identity**\n",
- "\n",
- " email= \"your-email\" #@param {'type': 'string'}\n",
- " name= \"your-name\" #@param {'type': 'string'}\n",
- " #@markdown Set **commit message**\n",
- " commit_m= \"this is commit message\" #@param {'type': 'string'}\n",
- "\n",
- " %cd \"/content/{dataset_path}\"\n",
- " !git lfs install\n",
- " !huggingface-cli lfs-enable-largefiles .\n",
- " !git add .\n",
- " !git lfs help smudge\n",
- " !git config --global user.email \"{email}\"\n",
- " !git config --global user.name \"{name}\"\n",
- " !git commit -m \"{commit_m}\"\n",
- " !git push\n",
- "\n",
- "else:\n",
- " pass"
- ],
- "metadata": {
- "id": "abHLg4I0Os5T",
- "cellView": "form"
- },
- "execution_count": null,
- "outputs": []
- }
- ]
-}
\ No newline at end of file
diff --git a/colab_in_development/kohya-trainer-v3-wd-1-4-tagger-for-RUN-ALL.ipynb b/colab_in_development/kohya-trainer-v3-wd-1-4-tagger-for-RUN-ALL.ipynb
deleted file mode 100644
index bd00d57a..00000000
--- a/colab_in_development/kohya-trainer-v3-wd-1-4-tagger-for-RUN-ALL.ipynb
+++ /dev/null
@@ -1,912 +0,0 @@
-{
- "nbformat": 4,
- "nbformat_minor": 0,
- "metadata": {
- "colab": {
- "provenance": [],
- "include_colab_link": true
- },
- "kernelspec": {
- "name": "python3",
- "display_name": "Python 3"
- },
- "language_info": {
- "name": "python"
- },
- "accelerator": "GPU",
- "gpuClass": "standard"
- },
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "view-in-github",
- "colab_type": "text"
- },
- "source": [
- ""
- ]
- },
- {
- "cell_type": "markdown",
- "source": [
- "#Kohya Trainer V3 - VRAM 12GB - `RUN ALL` Notebook\n",
- "###Best way to train Stable Diffusion model for peeps who didn't have good GPU"
- ],
- "metadata": {
- "id": "slgjeYgd6pWp"
- }
- },
- {
- "cell_type": "markdown",
- "source": [
- "Adapted to Google Colab based on [Kohya Guide](https://note.com/kohya_ss/n/nbf7ce8d80f29#c9d7ee61-5779-4436-b4e6-9053741c46bb)\n",
- "\n",
- "Adapted to Google Colab by [Linaqruf](https://github.com/Linaqruf)\n",
- "\n",
- "You can find latest notebook update [here](https://github.com/Linaqruf/kohya-trainer/blob/main/kohya-trainer-v3-wd-1-4-tagger-for-RUN-ALL.ipynb)\n",
- "\n",
- "\n"
- ],
- "metadata": {
- "id": "gPgBR3KM6E-Z"
- }
- },
- {
- "cell_type": "markdown",
- "source": [
- "## What is this?\n"
- ],
- "metadata": {
- "id": "v3Qxv-rCXshE"
- }
- },
- {
- "cell_type": "markdown",
- "source": [
- "---\n",
- "#####**_Q: So what's differences between `Kohya Trainer` and other diffusers out there?_**\n",
- "#####A: **Kohya Trainer** have some new features like\n",
- "1. Using the U-Net learning\n",
- "2. Automatic captioning/tagging for every image automatically with BLIP/DeepDanbooru\n",
- "3. Read all captions/tags created and put them in metadata.json\n",
- "4. Implemented [NovelAI Aspect Ratio Bucketing Tool](https://github.com/NovelAI/novelai-aspect-ratio-bucketing) so you don't need to crop image dataset 512x512 ever again\n",
- "- Use the output of the second-to-last layer of CLIP (Text Encoder) instead of the last layer.\n",
- "- Learning at non-square resolutions (Aspect Ratio Bucketing) .\n",
- "- Extend token length from 75 to 225.\n",
- "5. By preparing a certain number of images (several hundred or more seems to be desirable), you can make learning even more flexible than with DreamBooth.\n",
- "6. It also support Hypernetwork learning\n",
- "7. `NEW!` 23/11 - Implemented Waifu Diffusion 1.4 Tagger for alternative DeepDanbooru to auto-tagging.\n",
- "\n",
- "#####**_Q: And what's differences between this notebook and other dreambooth notebook out there?_**\n",
- "#####A: We're adding Quality of Life features such as:\n",
- "- Install **gallery-dl** to scrap images, so you can get your own dataset fast with google bandwidth\n",
- "- Huggingface Integration, here you can login to huggingface-hub and upload your trained model/dataset to huggingface\n",
- "---"
- ],
- "metadata": {
- "id": "gSSojWxg7cFP"
- }
- },
- {
- "cell_type": "markdown",
- "source": [
- "#Install Dependencies"
- ],
- "metadata": {
- "id": "h3AuTNu6MFZk"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "#@title Install Diffuser\n",
- "%cd /content/\n",
- "!pip install --upgrade pip\n",
- "!pip install diffusers[torch]==0.7.2"
- ],
- "metadata": {
- "id": "Aq5cjtG5nJ3Y",
- "cellView": "form"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "code",
- "source": [
- "#@title Install Xformers (T4)\n",
- "%cd /content/\n",
- "from IPython.display import clear_output\n",
- "import time\n",
- "from IPython.display import HTML\n",
- "from subprocess import getoutput\n",
- "import os\n",
- "\n",
- "s = getoutput('nvidia-smi')\n",
- "\n",
- "if 'T4' in s:\n",
- " gpu = 'T4'\n",
- "elif 'P100' in s:\n",
- " gpu = 'P100'\n",
- "elif 'V100' in s:\n",
- " gpu = 'V100'\n",
- "elif 'A100' in s:\n",
- " gpu = 'A100'\n",
- "\n",
- "if (gpu=='T4'):\n",
- " %pip install -qq https://github.com/camenduru/stable-diffusion-webui-colab/releases/download/0.0.14/xformers-0.0.14.dev0-cp37-cp37m-linux_x86_64.whl\n",
- "elif (gpu=='P100'):\n",
- " %pip install -q https://github.com/TheLastBen/fast-stable-diffusion/raw/main/precompiled/P100/xformers-0.0.13.dev0-py3-none-any.whl\n",
- "elif (gpu=='V100'):\n",
- " %pip install -q https://github.com/TheLastBen/fast-stable-diffusion/raw/main/precompiled/V100/xformers-0.0.13.dev0-py3-none-any.whl\n",
- "elif (gpu=='A100'):\n",
- " %pip install -q https://github.com/TheLastBen/fast-stable-diffusion/raw/main/precompiled/A100/xformers-0.0.13.dev0-py3-none-any.whl"
- ],
- "metadata": {
- "id": "Q_DPyXcDqv8J",
- "cellView": "form"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "source": [
- "#Install Kohya Trainer v3"
- ],
- "metadata": {
- "id": "tTVqCAgSmie4"
- }
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "_u3q60di584x",
- "cellView": "form"
- },
- "outputs": [],
- "source": [
- "#@title Cloning Kohya Trainer v3\n",
- "%cd /content/\n",
- "!git clone https://github.com/Linaqruf/kohya-trainer"
- ]
- },
- {
- "cell_type": "code",
- "source": [
- "#@title Install Kohya Trainer v3 Requirement\n",
- "%cd /content/kohya-trainer\n",
- "!pip install -r requirements.txt"
- ],
- "metadata": {
- "id": "WNn0g1pnHfk5",
- "cellView": "form"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "code",
- "source": [
- "#@title Set config for Accelerate\n",
- "#@markdown #Hint\n",
- "\n",
- "#@markdown 1. **In which compute environment are you running?** ([0] This machine, [1] AWS (Amazon SageMaker)): `0`\n",
- "#@markdown 2. **Which type of machine are you using?** ([0] No distributed training, [1] multi-CPU, [2] multi-GPU, [3] TPU [4] MPS): `0`\n",
- "#@markdown 3. **Do you want to run your training on CPU only (even if a GPU is available)?** [yes/NO]: `NO`\n",
- "#@markdown 4. **Do you want to use DeepSpeed?** [yes/NO]: `NO`\n",
- "#@markdown 5. **What GPU(s) (by id) should be used for training on this machine as a comma-seperated list?** [all] = `all`\n",
- "#@markdown 6. **Do you wish to use FP16 or BF16 (mixed precision)?** [NO/fp16/bf16]: `fp16`\n",
- "!accelerate config"
- ],
- "metadata": {
- "cellView": "form",
- "id": "z3kiEc1LNjHp"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "source": [
- "#Huggingface_hub Integration"
- ],
- "metadata": {
- "id": "HeOtiZoENxGi"
- }
- },
- {
- "cell_type": "markdown",
- "source": [
- "##Instruction:\n",
- "0. Of course you need a Huggingface Account first\n",
- "1. Create your huggingface model repository\n",
- "2. Create huggingface token, go to `Profile > Access Tokens > New Token > Create a new access token` with the `Write` role.\n",
- "3. All cells below are checked `opt-out` by default so you need to uncheck it if you want to running the cells."
- ],
- "metadata": {
- "id": "eSxZGxnfOGqz"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "#@title Login to Huggingface hub\n",
- "#@markdown Opt-out this cell when run all\n",
- "from IPython.core.display import HTML\n",
- "\n",
- "opt_out= False #@param {'type':'boolean'}\n",
- "\n",
- "#@markdown Prepare your Huggingface token\n",
- "\n",
- "saved_token= \"save-your-write-token-here\" #@param {'type': 'string'}\n",
- "\n",
- "if opt_out == False:\n",
- " from huggingface_hub import notebook_login\n",
- " notebook_login()\n",
- "\n",
- "else:\n",
- " display(HTML(f\"This cell will not running because you choose to opt-out this cell.\"))"
- ],
- "metadata": {
- "cellView": "form",
- "id": "yPUfQTD8N3fB"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "source": [
- "#Collecting datasets\n",
- "You can either upload your datasets to this notebook or use image scraper below to bulk download images from danbooru.\n",
- "\n",
- "If you want to use your own datasets, make sure to put them in a folder titled `train_data` in `content/kohya-trainer`\n",
- "\n",
- "This is to make the training process easier because the folder that will be used for training is in `content/kohya-trainer/train-data`."
- ],
- "metadata": {
- "id": "En9UUwGNMRMM"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "#@title Install `gallery-dl` library\n",
- "!pip install -U gallery-dl"
- ],
- "metadata": {
- "id": "dBi4pk7hy-Jg",
- "cellView": "form"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "code",
- "source": [
- "#@title Danbooru Scraper\n",
- "#@markdown **How this work?**\n",
- "\n",
- "#@markdown By using **gallery-dl** we can scrap or bulk download images on Internet, on this notebook we will scrap images from Danbooru using tag1 and tag2 as target scraping.\n",
- "%cd /content/kohya-trainer\n",
- "\n",
- "tag = \"minaba_hideo\" #@param {type: \"string\"}\n",
- "tag2 = \"granblue_fantasy\" #@param {type: \"string\"}\n",
- "output_dir = \"/content/kohya-trainer/train_data\" \n",
- "\n",
- "if tag2 is not \"\":\n",
- " tag = tag + \"+\" + tag2\n",
- "else:\n",
- " tag = tag\n",
- "\n",
- "print(tag)\n",
- "\n",
- "def danbooru_dl():\n",
- " !gallery-dl \"https://danbooru.donmai.us/posts?tags={tag}+&z=5\" -D {output_dir}\n",
- "\n",
- "danbooru_dl()\n",
- "\n",
- "#@markdown The output directory will be on /content/kohya-trainer/train_data. We also will use this folder as target folder for training next step.\n",
- "\n"
- ],
- "metadata": {
- "id": "Kt1GzntK_apb",
- "cellView": "form"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "source": [
- "#`(NEW)` Waifu Diffusion 1.4 Autotagger"
- ],
- "metadata": {
- "id": "SoPUJaTpTusz"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "#@title Install Tensorflow\n",
- "%cd /content/\n",
- "!pip install tensorflow"
- ],
- "metadata": {
- "cellView": "form",
- "id": "POJhWn28XrPs"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "code",
- "source": [
- "#@title Download Weight\n",
- "%cd /content/kohya-trainer/\n",
- "\n",
- "import os\n",
- "import shutil\n",
- "\n",
- "def huggingface_dl(url, weight):\n",
- " user_token = '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'\n",
- " user_header = f\"\\\"Authorization: Bearer {user_token}\\\"\"\n",
- " !wget -c --header={user_header} {url} -O /content/kohya-trainer/wd14tagger-weight/{weight}\n",
- "\n",
- "def download_weight():\n",
- " !mkdir /content/kohya-trainer/wd14tagger-weight/\n",
- " huggingface_dl(\"https://huggingface.co/Linaqruf/personal_backup/resolve/main/wd14tagger-weight/wd14Tagger.zip\", \"wd14Tagger.zip\")\n",
- " \n",
- " !unzip /content/kohya-trainer/wd14tagger-weight/wd14Tagger.zip -d /content/kohya-trainer/wd14tagger-weight\n",
- "\n",
- " # Destination path \n",
- " destination = '/content/kohya-trainer/wd14tagger-weight'\n",
- "\n",
- " if os.path.isfile('/content/kohya-trainer/tag_images_by_wd14_tagger.py'):\n",
- " # Move the content of \n",
- " # source to destination \n",
- " shutil.move(\"tag_images_by_wd14_tagger.py\", destination) \n",
- " else:\n",
- " pass\n",
- "\n",
- "download_weight()"
- ],
- "metadata": {
- "cellView": "form",
- "id": "WDSlAEHzT2Im"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "code",
- "source": [
- "#@title Start Autotagger\n",
- "%cd /content/kohya-trainer/wd14tagger-weight\n",
- "!python tag_images_by_wd14_tagger.py --batch_size 4 /content/kohya-trainer/train_data\n",
- "\n",
- "#@markdown Args list:\n",
- "#@markdown - `--train_data_dir` : directory for training images\n",
- "#@markdown - `--model` : model path to load\n",
- "#@markdown - `--tag_csv` : csv file for tag\n",
- "#@markdown - `--thresh` : threshold of confidence to add a tag\n",
- "#@markdown - `--batch_size` : batch size in inference\n",
- "#@markdown - `--model` : model path to load\n",
- "#@markdown - `--caption_extension` : extension of caption file\n",
- "#@markdown - `--debug` : debug mode\n"
- ],
- "metadata": {
- "id": "hibZK5NPTjZQ",
- "cellView": "form"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "code",
- "source": [
- "#@title Create Metadata.json\n",
- "%cd /content/kohya-trainer\n",
- "!python merge_dd_tags_to_metadata.py /content/drive/MyDrive/train_data meta_cap_dd.json"
- ],
- "metadata": {
- "id": "hz2Cmlf2ay9w",
- "cellView": "form"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "source": [
- "#Preparing Checkpoint"
- ],
- "metadata": {
- "id": "3gob9_OwTlwh"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "#@title Install Checkpoint\n",
- "%cd /content/kohya-trainer\n",
- "!mkdir checkpoint\n",
- "#@title Download Available Checkpoint\n",
- "\n",
- "def huggingface_checkpoint(url, checkpoint_name):\n",
- " #@markdown Insert your Huggingface token below\n",
- " user_token = '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' #@param {'type': 'string'}\n",
- " user_header = f\"\\\"Authorization: Bearer {user_token}\\\"\"\n",
- " !wget -c --header={user_header} {url} -O /content/kohya-trainer/checkpoint/{checkpoint_name}.ckpt\n",
- "\n",
- "def custom_checkpoint(url, checkpoint_name):\n",
- " !wget {url} -O /checkpoint/{checkpoint_name}.ckpt\n",
- "\n",
- "def install_checkpoint():\n",
- " #@markdown Choose the models you want:\n",
- " Animefull_Final_Pruned= False #@param {'type':'boolean'}\n",
- " Waifu_Diffusion_V1_3 = False #@param {'type':'boolean'}\n",
- " Anything_V3_0_Pruned = True #@param {'type':'boolean'}\n",
- "\n",
- " if Animefull_Final_Pruned:\n",
- " huggingface_checkpoint(\"https://huggingface.co/Linaqruf/personal_backup/resolve/main/animeckpt/model-pruned.ckpt\", \"Animefull_Final_Pruned\")\n",
- " if Waifu_Diffusion_V1_3:\n",
- " huggingface_checkpoint(\"https://huggingface.co/hakurei/waifu-diffusion-v1-3/resolve/main/wd-v1-3-float32.ckpt\", \"Waifu_Diffusion_V1_3\")\n",
- " if Anything_V3_0_Pruned:\n",
- " huggingface_checkpoint(\"https://huggingface.co/Linaqruf/anything-v3.0/resolve/main/Anything-V3.0-pruned.ckpt\", \"Anything_V3_0_Pruned\")\n",
- "\n",
- "install_checkpoint()"
- ],
- "metadata": {
- "id": "SoucgZQ6jgPQ",
- "cellView": "form"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "code",
- "source": [
- "#@title Download Custom Checkpoint\n",
- "#@markdown If your checkpoint aren't provided on the cell above, you can insert your own here.\n",
- "\n",
- "ckptName = \"\" #@param {'type': 'string'}\n",
- "ckptURL = \"\" #@param {'type': 'string'}\n",
- "\n",
- "def custom_checkpoint(url, name):\n",
- " !wget -c {url} -O /content/kohya-trainer/{name}.ckpt\n",
- "\n",
- "def install_checkpoint():\n",
- " if ckptName and ckptURL is not \"\" :\n",
- " custom_checkpoint(ckptName, ckptURL)\n",
- "\n",
- "install_checkpoint()"
- ],
- "metadata": {
- "id": "vrQ3_jbFTrgL",
- "cellView": "form"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "source": [
- "#Prepare Training"
- ],
- "metadata": {
- "id": "15xUbLvQNN28"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "#@title NovelAI Aspect Ratio Bucketing Script\n",
- "%cd /content/kohya-trainer\n",
- "\n",
- "model_dir= \"/content/kohya-trainer/checkpoint/Anything_V3_0_Pruned.ckpt\" #@param {'type' : 'string'} \n",
- "\n",
- "!python prepare_buckets_latents.py train_data meta_cap_dd.json meta_lat.json {model_dir} \\\n",
- " --batch_size 4 \\\n",
- " --max_resolution 512,512 \\\n",
- " --mixed_precision no"
- ],
- "metadata": {
- "id": "hhgatqF3leHJ",
- "cellView": "form"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "source": [
- "# Start Training\n",
- "\n"
- ],
- "metadata": {
- "id": "yHNbl3O_NSS0"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "#@title Clone Dataset\n",
- "from IPython.core.display import HTML\n",
- "\n",
- "#@markdown Opt-out this cell when run all\n",
- "opt_out= False #@param {'type':'boolean'}\n",
- "\n",
- "\n",
- "if opt_out == False:\n",
- " %cd /content\n",
- " Repository_url = \"https://huggingface.co/datasets/Linaqruf/granblue-fantasy-tag\" #@param {'type': 'string'}\n",
- " !git clone {Repository_url}\n",
- "\n",
- "else:\n",
- " display(HTML(f\"This cell will not running because you choose to opt-out this cell.\"))"
- ],
- "metadata": {
- "id": "IqAPuG-4MlyE",
- "cellView": "form"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "code",
- "source": [
- "#@title `NEW!` Move datasets to cloned repository (for backup)\n",
- "#@markdown Fill necessary file/folder path in the textbox given below. You need to atleast already cloned models and/or datasets from huggingface.\n",
- "import shutil\n",
- "\n",
- "datasets_path = \"/content/granblue-fantasy-tag\" #@param {'type' : 'string'}\n",
- "\n",
- "%cd /content/kohya-trainer\n",
- "\n",
- "train_data = \"/content/drive/MyDrive/train_data\"\n",
- "meta_cap_dd = \"/content/kohya-trainer/meta_cap_dd.json\"\n",
- "meta_lat = \"/content/kohya-trainer/meta_lat.json\"\n",
- "\n",
- "\n",
- "if os.path.isdir(train_data):\n",
- " shutil.move(train_data,datasets_path)\n",
- "else:\n",
- " pass\n",
- "\n",
- "if os.path.isdir(meta_cap_dd):\n",
- " shutil.move(meta_cap_dd,datasets_path)\n",
- "else:\n",
- " pass\n",
- "\n",
- "if os.path.isfile(meta_lat):\n",
- " shutil.move(meta_lat,datasets_path)\n",
- "else:\n",
- " pass\n",
- "# shutil.move(last_ckpt,model_path)\n",
- "# shutil.move(save_state,datasets_path)\n",
- "# shutil.move(train_data,datasets_path)\n",
- "# shutil.move(meta_cap_dd,datasets_path)\n",
- "# shutil.move(meta_lat,datasets_path)\n"
- ],
- "metadata": {
- "id": "t9OhmPtDMzUe",
- "cellView": "form"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "code",
- "source": [
- "#@title Commit datasets to Huggingface\n",
- "#@markdown Opt-out this cell when run all\n",
- "from IPython.core.display import HTML\n",
- "\n",
- "opt_out= False #@param {'type':'boolean'}\n",
- "\n",
- "if opt_out == False:\n",
- " %cd /content\n",
- " #@markdown Go to your model path\n",
- " dataset_path= \"granblue-fantasy-tag\" #@param {'type': 'string'}\n",
- "\n",
- " #@markdown Your path look like /content/**dataset_path**\n",
- " #@markdown ___\n",
- " #@markdown #Git Commit\n",
- "\n",
- " #@markdown Set **git commit identity**\n",
- "\n",
- " email= \"furqanil.taqwa@gmail.com\" #@param {'type': 'string'}\n",
- " name= \"Linaqruf\" #@param {'type': 'string'}\n",
- " #@markdown Set **commit message**\n",
- " commit_m= \"post: granblue datasets\" #@param {'type': 'string'}\n",
- "\n",
- " %cd \"/content/{dataset_path}\"\n",
- " !git lfs install\n",
- " !huggingface-cli lfs-enable-largefiles .\n",
- " !git add .\n",
- " !git lfs help smudge\n",
- " !git config --global user.email \"{email}\"\n",
- " !git config --global user.name \"{name}\"\n",
- " !git commit -m \"{commit_m}\"\n",
- " !git push\n",
- "\n",
- "else:\n",
- " display(HTML(f\"This cell will not running because you choose to opt-out this cell.\"))\n"
- ],
- "metadata": {
- "id": "jlV7y3IeNHJz",
- "cellView": "form"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "code",
- "source": [
- "#@title Training begin\n",
- "num_cpu_threads_per_process = 8 #@param {'type':'integer'}\n",
- "pre_trained_model_path =\"/content/kohya-trainer/checkpoint/Anything_V3_0_Pruned.ckpt\" #@param {'type':'string'}\n",
- "meta_lat_json_dir = \"/content/granblue-fantasy-tag/meta_lat.json\" #@param {'type':'string'}\n",
- "train_data_dir = \"/content/granblue-fantasy-tag/train_data\" #@param {'type':'string'}\n",
- "output_dir =\"/content/kohya-trainer/fine_tuned\" #@param {'type':'string'}\n",
- "train_batch_size = 1 #@param {type: \"slider\", min: 1, max: 10}\n",
- "learning_rate =\"2e-6\" #@param {'type':'string'}\n",
- "max_token_length = 225 #@param {'type':'integer'}\n",
- "clip_skip = 2 #@param {type: \"slider\", min: 1, max: 10}\n",
- "mixed_precision = \"fp16\" #@param [\"fp16\", \"bp16\"] {allow-input: false}\n",
- "max_train_steps = 10000 #@param {'type':'integer'}\n",
- "# save_precision = \"fp16\" #@param [\"fp16\", \"bp16\", \"float\"] {allow-input: false}\n",
- "save_every_n_epochs = 0 #@param {'type':'integer'}\n",
- "# gradient_accumulation_steps = 1 #@param {type: \"slider\", min: 1, max: 10}\n",
- "\n",
- "%cd /content/kohya-trainer\n",
- "!accelerate launch --num_cpu_threads_per_process {num_cpu_threads_per_process} fine_tune.py \\\n",
- " --pretrained_model_name_or_path={pre_trained_model_path} \\\n",
- " --in_json {meta_lat_json_dir} \\\n",
- " --train_data_dir={train_data_dir} \\\n",
- " --output_dir={output_dir} \\\n",
- " --shuffle_caption \\\n",
- " --train_batch_size={train_batch_size} \\\n",
- " --learning_rate={learning_rate} \\\n",
- " --max_token_length={max_token_length} \\\n",
- " --clip_skip={clip_skip} \\\n",
- " --mixed_precision={mixed_precision} \\\n",
- " --max_train_steps={max_train_steps} \\\n",
- " --use_8bit_adam \\\n",
- " --xformers \\\n",
- " --gradient_checkpointing \\\n",
- " --save_state #For Resume Training\n",
- " # --resume /content/granblue-fantasy-tag/last-state \n",
- " # --save_precision={save_precision} \\\n",
- " # --gradient_accumulation_steps {gradient_accumulation_steps} \\\n",
- " \n"
- ],
- "metadata": {
- "id": "4yNHC9FAOPGs",
- "cellView": "form"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "source": [
- "#Miscellaneous"
- ],
- "metadata": {
- "id": "vqfgyL-thgdw"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "#@title Model Pruner\n",
- "#@markdown Do you want to Pruning model?\n",
- "\n",
- "prune = True #@param {'type':'boolean'}\n",
- "\n",
- "model_path = \"/content/kohya-trainer/fine_tuned/last.ckpt\" #@param {'type' : 'string'}\n",
- "if prune == True:\n",
- " import os\n",
- " if os.path.isfile('/content/prune-ckpt.py'):\n",
- " pass\n",
- " else:\n",
- " !wget https://raw.githubusercontent.com/prettydeep/Dreambooth-SD-ckpt-pruning/main/prune-ckpt.py\n",
- "\n",
- "\n",
- " !python prune-ckpt.py --ckpt {model_path}\n",
- "\n"
- ],
- "metadata": {
- "cellView": "form",
- "id": "DWjy8ubtOTed"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "code",
- "source": [
- "#@title Mount to Google Drive\n",
- "mount_drive= False #@param {'type':'boolean'}\n",
- "\n",
- "if mount_drive== True:\n",
- " from google.colab import drive\n",
- " drive.mount('/content/drive')"
- ],
- "metadata": {
- "id": "OuRqOSp2eU6t",
- "cellView": "form"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "source": [
- "##Commit trained model to Huggingface"
- ],
- "metadata": {
- "id": "jypUkLWc48R_"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "#@title Clone Model\n",
- "from IPython.core.display import HTML\n",
- "\n",
- "#@markdown Opt-out this cell when run all\n",
- "opt_out= False #@param {'type':'boolean'}\n",
- "\n",
- "\n",
- "if opt_out == False:\n",
- " %cd /content\n",
- " Repository_url = \"https://huggingface.co/Linaqruf/granblue-fantasy\" #@param {'type': 'string'}\n",
- " !git clone {Repository_url}\n",
- "\n",
- "else:\n",
- " display(HTML(f\"This cell will not running because you choose to opt-out this cell.\"))"
- ],
- "metadata": {
- "id": "182Law9oUiYN",
- "cellView": "form"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "code",
- "source": [
- "#@title `NEW!` Move trained model and save state to cloned repository\n",
- "#@markdown Fill necessary file/folder path in the textbox given below. You need to atleast already cloned models and/or datasets from huggingface.\n",
- "import shutil\n",
- "\n",
- "model_path = \"/content/granblue-fantasy\" #@param {'type' : 'string'}\n",
- "\n",
- "%cd /content/kohya-trainer\n",
- "#model\n",
- "last_pruned_ckpt = \"/content/kohya-trainer/fine_tuned/last-pruned.ckpt\" #@param {'type' : 'string'}\n",
- "last_ckpt = \"/content/kohya-trainer/fine_tuned/last.ckpt\" #@param {'type' : 'string'}\n",
- "\n",
- "#datasets\n",
- "save_state = \"/content/kohya-trainer/fine_tuned/last-state\" #@param {'type' : 'string'}\n",
- "\n",
- "if os.path.isfile(last_pruned_ckpt):\n",
- " shutil.move(last_pruned_ckpt,model_path)\n",
- "else:\n",
- " pass\n",
- "\n",
- "if os.path.isfile(last_ckpt):\n",
- " shutil.move(last_ckpt,model_path)\n",
- "else:\n",
- " pass\n",
- "\n",
- "if os.path.isdir(save_state):\n",
- " shutil.move(save_state,datasets_path)\n",
- "else:\n",
- " pass\n"
- ],
- "metadata": {
- "id": "SHmaokxXOmYG",
- "cellView": "form"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "code",
- "source": [
- "#@title Commit Trained Model to Huggingface\n",
- "#@markdown Opt-out this cell when run all\n",
- "from IPython.core.display import HTML\n",
- "\n",
- "opt_out= False #@param {'type':'boolean'}\n",
- "\n",
- "if opt_out == False:\n",
- " %cd /content\n",
- " #@markdown Go to your model path\n",
- " model_path= \"granblue-fantasy\" #@param {'type': 'string'}\n",
- "\n",
- " #@markdown Your path look like /content/**model_path**\n",
- " #@markdown ___\n",
- " #@markdown #Git Commit\n",
- "\n",
- " #@markdown Set **git commit identity**\n",
- "\n",
- " email= \"furqanil.taqwa@gmail.com\" #@param {'type': 'string'}\n",
- " name= \"Linaqruf\" #@param {'type': 'string'}\n",
- " #@markdown Set **commit message**\n",
- " commit_m= \"push granblue fantasy model 20k\" #@param {'type': 'string'}\n",
- "\n",
- " %cd \"/content/{model_path}\"\n",
- " !git lfs install\n",
- " !huggingface-cli lfs-enable-largefiles .\n",
- " !git add .\n",
- " !git lfs help smudge\n",
- " !git config --global user.email \"{email}\"\n",
- " !git config --global user.name \"{name}\"\n",
- " !git commit -m \"{commit_m}\"\n",
- " !git push\n",
- "\n",
- "else:\n",
- " display(HTML(f\"This cell will not running because you choose to opt-out this cell.\"))\n"
- ],
- "metadata": {
- "cellView": "form",
- "id": "CWNxpSLzOpeT"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "code",
- "source": [
- "#@title Commit Datasets to Huggingface Again \n",
- "#@markdown Opt-out this cell when run all\n",
- "from IPython.core.display import HTML\n",
- "\n",
- "opt_out= False #@param {'type':'boolean'}\n",
- "\n",
- "if opt_out == False:\n",
- " %cd /content\n",
- " #@markdown Go to your model path\n",
- " dataset_path= \"granblue-fantasy-tag\" #@param {'type': 'string'}\n",
- "\n",
- " #@markdown Your path look like /content/**dataset_path**\n",
- " #@markdown ___\n",
- " #@markdown #Git Commit\n",
- "\n",
- " #@markdown Set **git commit identity**\n",
- "\n",
- " email= \"furqanil.taqwa@gmail.com\" #@param {'type': 'string'}\n",
- " name= \"Linaqruf\" #@param {'type': 'string'}\n",
- " #@markdown Set **commit message**\n",
- " commit_m= \"push save state\" #@param {'type': 'string'}\n",
- "\n",
- " %cd \"/content/{dataset_path}\"\n",
- " !git lfs install\n",
- " !huggingface-cli lfs-enable-largefiles .\n",
- " !git add .\n",
- " !git lfs help smudge\n",
- " !git config --global user.email \"{email}\"\n",
- " !git config --global user.name \"{name}\"\n",
- " !git commit -m \"{commit_m}\"\n",
- " !git push\n",
- "\n",
- "else:\n",
- " display(HTML(f\"This cell will not running because you choose to opt-out this cell.\"))\n"
- ],
- "metadata": {
- "id": "flKzXY4fTUeq",
- "cellView": "form"
- },
- "execution_count": null,
- "outputs": []
- }
- ]
-}
\ No newline at end of file
diff --git a/diffuser_fine_tuning/clean_captions_and_tags.py b/diffuser_fine_tuning/clean_captions_and_tags.py
new file mode 100644
index 00000000..76ede349
--- /dev/null
+++ b/diffuser_fine_tuning/clean_captions_and_tags.py
@@ -0,0 +1,125 @@
+# This script is licensed under the Apache License 2.0
+# (c) 2022 Kohya S. @kohya_ss
+
+import argparse
+import glob
+import os
+import json
+
+from tqdm import tqdm
+
+
+def clean_tags(image_key, tags):
+ # replace '_' to ' '
+ tags = tags.replace('_', ' ')
+
+ # remove rating: DeepDanbooru only
+ tokens = tags.split(", rating")
+ if len(tokens) == 1:
+ # this is the WD14 tagger case, so no message is printed
+ # print("no rating:")
+ # print(f"{image_key} {tags}")
+ pass
+ else:
+ if len(tokens) > 2:
+ print("multiple ratings:")
+ print(f"{image_key} {tags}")
+ tags = tokens[0]
+
+ return tags
+
+
+# Replacements are searched for and applied in order, from top to bottom
+# ('string to search for', 'replacement string')
+CAPTION_REPLACEMENTS = [
+ ('anime anime', 'anime'),
+ ('young ', ''),
+ ('anime girl', 'girl'),
+ ('cartoon female', 'girl'),
+ ('cartoon lady', 'girl'),
+ ('cartoon character', 'girl'), # a or ~s
+ ('cartoon woman', 'girl'),
+ ('cartoon women', 'girls'),
+ ('cartoon girl', 'girl'),
+ ('anime female', 'girl'),
+ ('anime lady', 'girl'),
+ ('anime character', 'girl'), # a or ~s
+ ('anime woman', 'girl'),
+ ('anime women', 'girls'),
+ ('lady', 'girl'),
+ ('female', 'girl'),
+ ('woman', 'girl'),
+ ('women', 'girls'),
+ ('people', 'girls'),
+ ('person', 'girl'),
+ ('a cartoon figure', 'a figure'),
+ ('a cartoon image', 'an image'),
+ ('a cartoon picture', 'a picture'),
+ ('an anime cartoon image', 'an image'),
+ ('a cartoon anime drawing', 'a drawing'),
+ ('a cartoon drawing', 'a drawing'),
+ ('girl girl', 'girl'),
+]
+
+
+def clean_caption(caption):
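+ # each replacement is applied repeatedly until the caption stops changing,
+ # so stacked phrases such as "anime anime girl" are collapsed completely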
+ for rf, rt in CAPTION_REPLACEMENTS:
+ replaced = True
+ while replaced:
+ bef = caption
+ caption = caption.replace(rf, rt)
+ replaced = bef != caption
+ return caption
+
+
+def main(args):
+ image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + glob.glob(os.path.join(args.train_data_dir, "*.png"))
+ print(f"found {len(image_paths)} images.")
+
+ if os.path.exists(args.in_json):
+ print(f"loading existing metadata: {args.in_json}")
+ with open(args.in_json, "rt", encoding='utf-8') as f:
+ metadata = json.load(f)
+ else:
+ print("no metadata / メタデータファイルがありません")
+ return
+
+ print("cleaning captions and tags.")
+ for image_path in tqdm(image_paths):
+ tags_path = os.path.splitext(image_path)[0] + '.txt'
+ with open(tags_path, "rt", encoding='utf-8') as f:
+ tags = f.readlines()[0].strip()
+
+ image_key = os.path.splitext(os.path.basename(image_path))[0]
+ if image_key not in metadata:
+ print(f"image not in metadata / メタデータに画像がありません: {image_path}")
+ return
+
+ tags = metadata[image_key].get('tags')
+ if tags is None:
+ print(f"image does not have tags / メタデータにタグがありません: {image_path}")
+ else:
+ metadata[image_key]['tags'] = clean_tags(image_key, tags)
+
+ caption = metadata[image_key].get('caption')
+ if caption is None:
+ print(f"image does not have caption / メタデータにキャプションがありません: {image_path}")
+ else:
+ metadata[image_key]['caption'] = clean_caption(caption)
+
+ # write out the metadata and finish
+ print(f"writing metadata: {args.out_json}")
+ with open(args.out_json, "wt", encoding='utf-8') as f:
+ json.dump(metadata, f, indent=2)
+ print("done!")
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
+ parser.add_argument("in_json", type=str, help="metadata file to input / 読み込むメタデータファイル")
+ parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
+ # parser.add_argument("--debug", action="store_true", help="debug mode")
+
+ args = parser.parse_args()
+ main(args)
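+
+# Example invocation (file names are illustrative):
+#   python clean_captions_and_tags.py train_data meta_cap_dd.json meta_clean_dd.json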
diff --git a/diffuser_fine_tuning/diffusers_fine_tuning_v1.zip b/diffuser_fine_tuning/diffusers_fine_tuning_v1.zip
deleted file mode 100644
index 1dfb122f..00000000
Binary files a/diffuser_fine_tuning/diffusers_fine_tuning_v1.zip and /dev/null differ
diff --git a/diffuser_fine_tuning/diffusers_fine_tuning_v2.zip b/diffuser_fine_tuning/diffusers_fine_tuning_v2.zip
deleted file mode 100644
index e027a3d1..00000000
Binary files a/diffuser_fine_tuning/diffusers_fine_tuning_v2.zip and /dev/null differ
diff --git a/diffuser_fine_tuning/diffusers_fine_tuning_v3.zip b/diffuser_fine_tuning/diffusers_fine_tuning_v3.zip
deleted file mode 100644
index 1eb74b5a..00000000
Binary files a/diffuser_fine_tuning/diffusers_fine_tuning_v3.zip and /dev/null differ
diff --git a/diffuser_fine_tuning/diffusers_fine_tuning_v4.zip b/diffuser_fine_tuning/diffusers_fine_tuning_v4.zip
deleted file mode 100644
index f846184e..00000000
Binary files a/diffuser_fine_tuning/diffusers_fine_tuning_v4.zip and /dev/null differ
diff --git a/diffuser_fine_tuning/fine_tune.py b/diffuser_fine_tuning/fine_tune.py
new file mode 100644
index 00000000..e17d1bd9
--- /dev/null
+++ b/diffuser_fine_tuning/fine_tune.py
@@ -0,0 +1,1005 @@
+# v2: select precision for saved checkpoint
+# v3: add logging for tensorboard, fix to shuffle=False in DataLoader (shuffling is in dataset)
+# v4: support SD2.0, add lr scheduler options, supports save_every_n_epochs and save_state for DiffUsers model
+# v5: refactor to use model_util, support safetensors, add settings to use Diffusers' xformers, add log prefix
+
+
+# Like train_dreambooth.py, this script is licensed under the Apache License 2.0
+# License:
+# Copyright 2022 Kohya S. @kohya_ss
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# License of included scripts:
+
+# Diffusers: ASL 2.0 https://github.com/huggingface/diffusers/blob/main/LICENSE
+
+# Memory efficient attention:
+# based on https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/memory_efficient_attention_pytorch/flash_attention.py
+# MIT https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/LICENSE
+
+import argparse
+import math
+import os
+import random
+import json
+import importlib
+import time
+
+from tqdm import tqdm
+import torch
+from accelerate import Accelerator
+from accelerate.utils import set_seed
+from transformers import CLIPTokenizer
+import diffusers
+from diffusers import DDPMScheduler, StableDiffusionPipeline
+import numpy as np
+from einops import rearrange
+from torch import einsum
+
+import model_util
+
+# Tokenizer: use the pre-packaged tokenizer rather than loading it from the checkpoint
+TOKENIZER_PATH = "openai/clip-vit-large-patch14"
+V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # only the tokenizer is taken from here
+
+# checkpoint file name templates
+EPOCH_STATE_NAME = "epoch-{:06d}-state"
+LAST_STATE_NAME = "last-state"
+
+LAST_DIFFUSERS_DIR_NAME = "last"
+EPOCH_DIFFUSERS_DIR_NAME = "epoch-{:06d}"
+
+
+def collate_fn(examples):
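+ # the dataset already returns a complete batch per item (the DataLoader uses batch_size=1), so simply unwrap it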
+ return examples[0]
+
+
+class FineTuningDataset(torch.utils.data.Dataset):
+ def __init__(self, metadata, train_data_dir, batch_size, tokenizer, max_token_length, shuffle_caption, dataset_repeats, debug) -> None:
+ super().__init__()
+
+ self.metadata = metadata
+ self.train_data_dir = train_data_dir
+ self.batch_size = batch_size
+ self.tokenizer: CLIPTokenizer = tokenizer
+ self.max_token_length = max_token_length
+ self.shuffle_caption = shuffle_caption
+ self.debug = debug
+
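+ # +2 leaves room for the BOS and EOS tokens around the requested token budget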
+ self.tokenizer_max_length = self.tokenizer.model_max_length if max_token_length is None else max_token_length + 2
+
+ print("make buckets")
+
+ # first, collect the set of bucket resolutions
+ self.bucket_resos = set()
+ for img_md in metadata.values():
+ if 'train_resolution' in img_md:
+ self.bucket_resos.add(tuple(img_md['train_resolution']))
+ self.bucket_resos = list(self.bucket_resos)
+ self.bucket_resos.sort()
+ print(f"number of buckets: {len(self.bucket_resos)}")
+
+ reso_to_index = {}
+ for i, reso in enumerate(self.bucket_resos):
+ reso_to_index[reso] = i
+
+ # assign each image to its bucket
+ self.buckets = [[] for _ in range(len(self.bucket_resos))]
+ n = 1 if dataset_repeats is None else dataset_repeats
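+ # each image key is added n times, so dataset_repeats acts as simple oversampling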
+ images_count = 0
+ for image_key, img_md in metadata.items():
+ if 'train_resolution' not in img_md:
+ continue
+ if not os.path.exists(os.path.join(self.train_data_dir, image_key + '.npz')):
+ continue
+
+ reso = tuple(img_md['train_resolution'])
+ for _ in range(n):
+ self.buckets[reso_to_index[reso]].append(image_key)
+ images_count += n
+
+ # build (bucket, batch) indices for lookup
+ self.buckets_indices = []
+ for bucket_index, bucket in enumerate(self.buckets):
+ batch_count = int(math.ceil(len(bucket) / self.batch_size))
+ for batch_index in range(batch_count):
+ self.buckets_indices.append((bucket_index, batch_index))
+
+ self.shuffle_buckets()
+ self._length = len(self.buckets_indices)
+ self.images_count = images_count
+
+ def show_buckets(self):
+ for i, (reso, bucket) in enumerate(zip(self.bucket_resos, self.buckets)):
+ print(f"bucket {i}: resolution {reso}, count: {len(bucket)}")
+
+ def shuffle_buckets(self):
+ random.shuffle(self.buckets_indices)
+ for bucket in self.buckets:
+ random.shuffle(bucket)
+
+ def load_latent(self, image_key):
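+ # latents are precomputed per image by the latent-preparation step and cached as .npz files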
+ return np.load(os.path.join(self.train_data_dir, image_key + '.npz'))['arr_0']
+
+ def __len__(self):
+ return self._length
+
+ def __getitem__(self, index):
+ if index == 0:
+ self.shuffle_buckets()
+
+ bucket = self.buckets[self.buckets_indices[index][0]]
+ image_index = self.buckets_indices[index][1] * self.batch_size
+
+ input_ids_list = []
+ latents_list = []
+ captions = []
+ for image_key in bucket[image_index:image_index + self.batch_size]:
+ img_md = self.metadata[image_key]
+ caption = img_md.get('caption')
+ tags = img_md.get('tags')
+
+ if caption is None:
+ caption = tags
+ elif tags is not None and len(tags) > 0:
+ caption = caption + ', ' + tags
+ assert caption is not None and len(caption) > 0, f"caption or tag is required / キャプションまたはタグは必須です:{image_key}"
+
+ latents = self.load_latent(image_key)
+
+ if self.shuffle_caption:
+ tokens = caption.strip().split(",")
+ random.shuffle(tokens)
+ caption = ",".join(tokens).strip()
+
+ captions.append(caption)
+
+ input_ids = self.tokenizer(caption, padding="max_length", truncation=True,
+ max_length=self.tokenizer_max_length, return_tensors="pt").input_ids
+
+ if self.tokenizer_max_length > self.tokenizer.model_max_length:
+ input_ids = input_ids.squeeze(0)
+ iids_list = []
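+ # the v1 CLIP tokenizer reuses EOS as the PAD token, while the v2 tokenizer has a separate PAD token,
+ # so this check distinguishes the two tokenizers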
+ if self.tokenizer.pad_token_id == self.tokenizer.eos_token_id:
+ # v1
+ # above 77 tokens the sequence looks like "<BOS> .... <EOS> <EOS> <EOS>" (227 tokens in total), so convert it into three "<BOS> ... <EOS>" chunks
+ # AUTOMATIC1111's implementation seems to split on ',', but keep it simple here for now
+ for i in range(1, self.tokenizer_max_length - self.tokenizer.model_max_length + 2, self.tokenizer.model_max_length - 2): # (1, 152, 75)
+ ids_chunk = (input_ids[0].unsqueeze(0),
+ input_ids[i:i + self.tokenizer.model_max_length - 2],
+ input_ids[-1].unsqueeze(0))
+ ids_chunk = torch.cat(ids_chunk)
+ iids_list.append(ids_chunk)
+ else:
+ # v2
+ # above 77 tokens the sequence looks like "<BOS> .... <EOS> <PAD> <PAD> ..." (227 tokens in total), so convert it into three "<BOS> ... <EOS> <PAD> <PAD>" chunks
+ for i in range(1, self.tokenizer_max_length - self.tokenizer.model_max_length + 2, self.tokenizer.model_max_length - 2):
+ ids_chunk = (input_ids[0].unsqueeze(0), # BOS
+ input_ids[i:i + self.tokenizer.model_max_length - 2],
+ input_ids[-1].unsqueeze(0)) # PAD or EOS
+ ids_chunk = torch.cat(ids_chunk)
+
+ # if the chunk ends with "<EOS> <PAD>" or "<PAD> <PAD>", nothing needs to be done
+ # if it ends with "x <PAD/EOS>", change the last token to <EOS> (no effective change when it was already "x <EOS>")
+ if ids_chunk[-2] != self.tokenizer.eos_token_id and ids_chunk[-2] != self.tokenizer.pad_token_id:
+ ids_chunk[-1] = self.tokenizer.eos_token_id
+ # if the chunk starts with "<BOS> <PAD> ...", change it to "<BOS> <EOS> <PAD> ..."
+ if ids_chunk[1] == self.tokenizer.pad_token_id:
+ ids_chunk[1] = self.tokenizer.eos_token_id
+
+ iids_list.append(ids_chunk)
+
+ input_ids = torch.stack(iids_list) # 3,77
+
+ input_ids_list.append(input_ids)
+ latents_list.append(torch.FloatTensor(latents))
+
+ example = {}
+ example['input_ids'] = torch.stack(input_ids_list)
+ example['latents'] = torch.stack(latents_list)
+ if self.debug:
+ example['image_keys'] = bucket[image_index:image_index + self.batch_size]
+ example['captions'] = captions
+ return example
+
+
+def save_hypernetwork(output_file, hypernetwork):
+ state_dict = hypernetwork.get_state_dict()
+ torch.save(state_dict, output_file)
+
+
+def train(args):
+ fine_tuning = args.hypernetwork_module is None # fine tuning or hypernetwork training
+
+ # validate the remaining option settings
+ if args.v_parameterization and not args.v2:
+ print("v_parameterization should be with v2 / v1でv_parameterizationを使用することは想定されていません")
+ if args.v2 and args.clip_skip is not None:
+ print("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません")
+
+ # check the model format options
+ # since v11, downloading directly from Diffusers is also supported (models requiring authentication are not), and intermediate saving of Diffusers models is supported as well
+ use_stable_diffusion_format = os.path.isfile(args.pretrained_model_name_or_path)
+
+ # initialize the random seed
+ if args.seed is not None:
+ set_seed(args.seed)
+
+ # load the metadata
+ if os.path.exists(args.in_json):
+ print(f"loading existing metadata: {args.in_json}")
+ with open(args.in_json, "rt", encoding='utf-8') as f:
+ metadata = json.load(f)
+ else:
+ print(f"no metadata / メタデータファイルがありません: {args.in_json}")
+ return
+
+ # load the tokenizer
+ print("prepare tokenizer")
+ if args.v2:
+ tokenizer = CLIPTokenizer.from_pretrained(V2_STABLE_DIFFUSION_PATH, subfolder="tokenizer")
+ else:
+ tokenizer = CLIPTokenizer.from_pretrained(TOKENIZER_PATH)
+
+ if args.max_token_length is not None:
+ print(f"update token length: {args.max_token_length}")
+
+ # prepare the dataset
+ print("prepare dataset")
+ train_dataset = FineTuningDataset(metadata, args.train_data_dir, args.train_batch_size,
+ tokenizer, args.max_token_length, args.shuffle_caption, args.dataset_repeats, args.debug_dataset)
+
+ print(f"Total dataset length / データセットの長さ: {len(train_dataset)}")
+ print(f"Total images / 画像数: {train_dataset.images_count}")
+ if args.debug_dataset:
+ train_dataset.show_buckets()
+ i = 0
+ for example in train_dataset:
+ print(f"image: {example['image_keys']}")
+ print(f"captions: {example['captions']}")
+ print(f"latents: {example['latents'].shape}")
+ print(f"input_ids: {example['input_ids'].shape}")
+ print(example['input_ids'])
+ i += 1
+ if i >= 8:
+ break
+ return
+
+ # prepare the accelerator
+ print("prepare accelerator")
+ if args.logging_dir is None:
+ log_with = None
+ logging_dir = None
+ else:
+ log_with = "tensorboard"
+ log_prefix = "" if args.log_prefix is None else args.log_prefix
+ logging_dir = args.logging_dir + "/" + log_prefix + time.strftime('%Y%m%d%H%M%S', time.localtime())
+ accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps,
+ mixed_precision=args.mixed_precision, log_with=log_with, logging_dir=logging_dir)
+
+ # prepare dtypes matching the mixed-precision setting and cast as needed
+ weight_dtype = torch.float32
+ if args.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif args.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ save_dtype = None
+ if args.save_precision == "fp16":
+ save_dtype = torch.float16
+ elif args.save_precision == "bf16":
+ save_dtype = torch.bfloat16
+ elif args.save_precision == "float":
+ save_dtype = torch.float32
+
+ # load the model
+ if use_stable_diffusion_format:
+ print("load StableDiffusion checkpoint")
+ text_encoder, _, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.pretrained_model_name_or_path)
+ else:
+ print("load Diffusers pretrained models")
+ pipe = StableDiffusionPipeline.from_pretrained(args.pretrained_model_name_or_path, tokenizer=None, safety_checker=None)
+ # , torch_dtype=weight_dtype) # specifying torch_dtype here causes an error during training
+ text_encoder = pipe.text_encoder
+ unet = pipe.unet
+ del pipe
+
+ # helper to set the Diffusers xformers flag
+ def set_diffusers_xformers_flag(model, valid):
+ # model.set_use_memory_efficient_attention_xformers(valid) # likely to be removed in the next release
+ # the pipeline is supposed to look up set_use_memory_efficient_attention_xformers recursively on its own,
+ # but there is no pipeline when only the U-Net is used, so copy that recursive logic here
+
+ # Recursively walk through all the children.
+ # Any children which exposes the set_use_memory_efficient_attention_xformers method
+ # gets the message
+ def fn_recursive_set_mem_eff(module: torch.nn.Module):
+ if hasattr(module, "set_use_memory_efficient_attention_xformers"):
+ module.set_use_memory_efficient_attention_xformers(valid)
+
+ for child in module.children():
+ fn_recursive_set_mem_eff(child)
+
+ fn_recursive_set_mem_eff(model)
+
+ # wire xformers / memory-efficient attention into the model
+ if args.diffusers_xformers:
+ print("Use xformers by Diffusers")
+ set_diffusers_xformers_flag(unet, True)
+ else:
+ # the Windows build of xformers cannot train in float, so a configuration without xformers must remain possible
+ print("Disable Diffusers' xformers")
+ set_diffusers_xformers_flag(unet, False)
+ replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
+
+ if not fine_tuning:
+ # Hypernetwork
+ print("import hypernetwork module:", args.hypernetwork_module)
+ hyp_module = importlib.import_module(args.hypernetwork_module)
+
+ hypernetwork = hyp_module.Hypernetwork()
+
+ if args.hypernetwork_weights is not None:
+ print("load hypernetwork weights from:", args.hypernetwork_weights)
+ hyp_sd = torch.load(args.hypernetwork_weights, map_location='cpu')
+ success = hypernetwork.load_from_state_dict(hyp_sd)
+ assert success, "hypernetwork weights loading failed."
+
+ print("apply hypernetwork")
+ hypernetwork.apply_to_diffusers(None, text_encoder, unet)
+
+ # prepare for training: put the models into the proper state
+ training_models = []
+ if fine_tuning:
+ if args.gradient_checkpointing:
+ unet.enable_gradient_checkpointing()
+ training_models.append(unet)
+
+ if args.train_text_encoder:
+ print("enable text encoder training")
+ if args.gradient_checkpointing:
+ text_encoder.gradient_checkpointing_enable()
+ training_models.append(text_encoder)
+ else:
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
+ text_encoder.requires_grad_(False) # the text encoder is not trained
+ text_encoder.eval()
+ else:
+ unet.to(accelerator.device) # , dtype=weight_dtype) # specifying a dtype here breaks training
+ unet.requires_grad_(False)
+ unet.eval()
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
+ text_encoder.requires_grad_(False)
+ text_encoder.eval()
+ training_models.append(hypernetwork)
+
+ for m in training_models:
+ m.requires_grad_(True)
+ params = []
+ for m in training_models:
+ params.extend(m.parameters())
+ params_to_optimize = params
+
+ # prepare the classes needed for training
+ print("prepare optimizer, data loader etc.")
+
+ # use 8-bit Adam if requested
+ if args.use_8bit_adam:
+ try:
+ import bitsandbytes as bnb
+ except ImportError:
+ raise ImportError("No bitsandbytes / bitsandbytesがインストールされていないようです")
+ print("use 8-bit Adam optimizer")
+ optimizer_class = bnb.optim.AdamW8bit
+ else:
+ optimizer_class = torch.optim.AdamW
+
+ # both the Diffusers DreamBooth and DreamBooth SD scripts appear to use the default betas and weight decay, so those options are omitted for now
+ optimizer = optimizer_class(params_to_optimize, lr=args.learning_rate)
+
+ # prepare the dataloader
+ # number of DataLoader worker processes: 0 means the main process
+ n_workers = min(8, os.cpu_count() - 1) # cpu_count - 1, capped at 8
+ train_dataloader = torch.utils.data.DataLoader(
+ train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers)
+
+ # prepare the lr scheduler
+ lr_scheduler = diffusers.optimization.get_scheduler(
+ args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps * args.gradient_accumulation_steps)
+
+ # accelerator takes care of wrapping these appropriately
+ if fine_tuning:
+ if args.train_text_encoder:
+ unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ unet, text_encoder, optimizer, train_dataloader, lr_scheduler)
+ else:
+ unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
+ else:
+ unet, hypernetwork, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+ unet, hypernetwork, optimizer, train_dataloader, lr_scheduler)
+
+ # TODO compare the dtype given in the accelerate config with the one passed as an option and warn if they differ
+
+ # resume from a saved state if requested
+ if args.resume is not None:
+ print(f"resume training from state: {args.resume}")
+ accelerator.load_state(args.resume)
+
+ # compute the number of epochs
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+ num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+ # train
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+ print("running training / 学習開始")
+ print(f" num examples / サンプル数: {train_dataset.images_count}")
+ print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
+ print(f" num epochs / epoch数: {num_train_epochs}")
+ print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
+ print(f" total train batch size (with parallel & distributed & accumulation) / 総バッチサイズ(並列学習、勾配合計含む): {total_batch_size}")
+ print(f" gradient ccumulation steps / 勾配を合計するステップ数 = {args.gradient_accumulation_steps}")
+ print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
+
+ progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
+ global_step = 0
+
+ # updated in v4: clip_sample=False
+ # Diffusers' train_dreambooth.py was changed to take this from the config, which results in clip_sample=False, so match that here
+ # the existing 1.4/1.5/2.0 checkpoints all share the same scheduler config (aside from the class name)
+ # on a closer look at the source, this does not matter during training anyway
+ noise_scheduler = DDPMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear",
+ num_train_timesteps=1000, clip_sample=False)
+
+ if accelerator.is_main_process:
+ accelerator.init_trackers("finetuning" if fine_tuning else "hypernetwork")
+
+ # the following is copied almost verbatim from train_dreambooth.py
+ for epoch in range(num_train_epochs):
+ print(f"epoch {epoch+1}/{num_train_epochs}")
+ for m in training_models:
+ m.train()
+
+ loss_total = 0
+ for step, batch in enumerate(train_dataloader):
+ with accelerator.accumulate(training_models[0]): # does not seem to support multiple models, but leave it like this for now
+ latents = batch["latents"].to(accelerator.device)
+ latents = latents * 0.18215
+ b_size = latents.shape[0]
+
+ # with torch.no_grad():
+ with torch.set_grad_enabled(args.train_text_encoder):
+ # Get the text embedding for conditioning
+ input_ids = batch["input_ids"].to(accelerator.device)
+ input_ids = input_ids.reshape((-1, tokenizer.model_max_length)) # batch_size*3, 77
+
+ if args.clip_skip is None:
+ encoder_hidden_states = text_encoder(input_ids)[0]
+ else:
+ enc_out = text_encoder(input_ids, output_hidden_states=True, return_dict=True)
+ encoder_hidden_states = enc_out['hidden_states'][-args.clip_skip]
+ encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states)
+
+ # bs*3, 77, 768 or 1024
+ encoder_hidden_states = encoder_hidden_states.reshape((b_size, -1, encoder_hidden_states.shape[-1]))
+
+ if args.max_token_length is not None:
+ if args.v2:
+ # v2: rebuild the three "<BOS> ... <EOS> <PAD> ..." chunks back into one "<BOS> ... <EOS> <PAD> ..." sequence; honestly not sure this implementation is right
+ states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # <BOS>
+ for i in range(1, args.max_token_length, tokenizer.model_max_length):
+ chunk = encoder_hidden_states[:, i:i + tokenizer.model_max_length - 2] # from just after <BOS> up to just before the last token
+ if i > 0:
+ for j in range(len(chunk)):
+ if input_ids[j, 1] == tokenizer.eos_token: # an empty prompt, i.e. the "<BOS> <EOS> <PAD> ..." pattern
+ chunk[j, 0] = chunk[j, 1] # copy the value of the following <PAD>
+ states_list.append(chunk) # from just after <BOS> to just before <EOS>
+ states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # either <EOS> or <PAD>
+ encoder_hidden_states = torch.cat(states_list, dim=1)
+ else:
+ # v1: rebuild the three "<BOS> ... <EOS>" chunks back into one "<BOS> ... <EOS>" sequence
+ states_list = [encoder_hidden_states[:, 0].unsqueeze(1)] # <BOS>
+ for i in range(1, args.max_token_length, tokenizer.model_max_length):
+ states_list.append(encoder_hidden_states[:, i:i + tokenizer.model_max_length - 2]) # from just after <BOS> to just before <EOS>
+ states_list.append(encoder_hidden_states[:, -1].unsqueeze(1)) # <EOS>
+ encoder_hidden_states = torch.cat(states_list, dim=1)
+
+ # Sample noise that we'll add to the latents
+ noise = torch.randn_like(latents, device=latents.device)
+
+ # Sample a random timestep for each image
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device)
+ timesteps = timesteps.long()
+
+ # Add noise to the latents according to the noise magnitude at each timestep
+ # (this is the forward diffusion process)
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+ # Predict the noise residual
+ noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+
+ if args.v_parameterization:
+ # v-parameterization training
+
+ # as of 11/29 the v-prediction code has been committed to Diffusers but not released yet, so use our own code
+ # the implementation appears to be the same
+
+ # what we would like to write:
+ # target = noise_scheduler.get_v(latents, noise, timesteps)
+
+ # the code in StabilityAI's ddpm.py:
+ # elif self.parameterization == "v":
+ # target = self.get_v(x_start, noise, t)
+ # ...
+ # def get_v(self, x, noise, t):
+ # return (
+ # extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise -
+ # extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x
+ # )
+
+ # the code in scheduling_ddim.py:
+ # elif self.config.prediction_type == "v_prediction":
+ # pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
+ # # predict V
+ # model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
+
+ # so this should be equivalent:
+ alpha_prod_t = noise_scheduler.alphas_cumprod[timesteps]
+ beta_prod_t = 1 - alpha_prod_t
+ alpha_prod_t = torch.reshape(alpha_prod_t, (len(alpha_prod_t), 1, 1, 1)) # reshape because it apparently does not broadcast otherwise
+ beta_prod_t = torch.reshape(beta_prod_t, (len(beta_prod_t), 1, 1, 1))
+ target = (alpha_prod_t ** 0.5) * noise - (beta_prod_t ** 0.5) * latents
+ else:
+ target = noise
+
+ loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean")
+
+ accelerator.backward(loss)
+ if accelerator.sync_gradients:
+ params_to_clip = []
+ for m in training_models:
+ params_to_clip.extend(m.parameters())
+ accelerator.clip_grad_norm_(params_to_clip, 1.0) # args.max_grad_norm)
+
+ optimizer.step()
+ lr_scheduler.step()
+ optimizer.zero_grad(set_to_none=True)
+
+ # Checks if the accelerator has performed an optimization step behind the scenes
+ if accelerator.sync_gradients:
+ progress_bar.update(1)
+ global_step += 1
+
+ current_loss = loss.detach().item() # this is a mean, so batch size should not matter
+ if args.logging_dir is not None:
+ logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
+ accelerator.log(logs, step=global_step)
+
+ loss_total += current_loss
+ avr_loss = loss_total / (step+1)
+ logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
+ progress_bar.set_postfix(**logs)
+
+ if global_step >= args.max_train_steps:
+ break
+
+ if args.logging_dir is not None:
+ logs = {"epoch_loss": loss_total / len(train_dataloader)}
+ accelerator.log(logs, step=epoch+1)
+
+ accelerator.wait_for_everyone()
+
+ if args.save_every_n_epochs is not None:
+ if (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs:
+ print("saving checkpoint.")
+ os.makedirs(args.output_dir, exist_ok=True)
+ ckpt_file = os.path.join(args.output_dir, model_util.get_epoch_ckpt_name(args.use_safetensors, epoch + 1))
+
+ if fine_tuning:
+ if use_stable_diffusion_format:
+ model_util.save_stable_diffusion_checkpoint(
+ args.v2, ckpt_file, accelerator.unwrap_model(text_encoder), accelerator.unwrap_model(unet),
+ args.pretrained_model_name_or_path, epoch + 1, global_step, save_dtype)
+ else:
+ out_dir = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(epoch + 1))
+ os.makedirs(out_dir, exist_ok=True)
+ model_util.save_diffusers_checkpoint(args.v2, out_dir, accelerator.unwrap_model(text_encoder),
+ accelerator.unwrap_model(unet), args.pretrained_model_name_or_path)
+ else:
+ save_hypernetwork(ckpt_file, accelerator.unwrap_model(hypernetwork))
+
+ if args.save_state:
+ print("saving state.")
+ accelerator.save_state(os.path.join(args.output_dir, EPOCH_STATE_NAME.format(epoch + 1)))
+
+ is_main_process = accelerator.is_main_process
+ if is_main_process:
+ if fine_tuning:
+ unet = accelerator.unwrap_model(unet)
+ text_encoder = accelerator.unwrap_model(text_encoder)
+ else:
+ hypernetwork = accelerator.unwrap_model(hypernetwork)
+
+ accelerator.end_training()
+
+ if args.save_state:
+ print("saving last state.")
+ accelerator.save_state(os.path.join(args.output_dir, LAST_STATE_NAME))
+
+ del accelerator  # delete this here because we need the memory for saving below
+
+ if is_main_process:
+ os.makedirs(args.output_dir, exist_ok=True)
+ if fine_tuning:
+ if use_stable_diffusion_format:
+ ckpt_file = os.path.join(args.output_dir, model_util.get_last_ckpt_name(args.use_safetensors))
+ print(f"save trained model as StableDiffusion checkpoint to {ckpt_file}")
+ model_util.save_stable_diffusion_checkpoint(
+ args.v2, ckpt_file, text_encoder, unet, args.pretrained_model_name_or_path, epoch, global_step, save_dtype)
+ else:
+ # Create the pipeline using the trained modules and save it.
+ print(f"save trained model as Diffusers to {args.output_dir}")
+ out_dir = os.path.join(args.output_dir, LAST_DIFFUSERS_DIR_NAME)
+ os.makedirs(out_dir, exist_ok=True)
+ model_util.save_diffusers_checkpoint(args.v2, out_dir, text_encoder, unet, args.pretrained_model_name_or_path)
+ else:
+ ckpt_file = os.path.join(args.output_dir, model_util.get_last_ckpt_name(args.use_safetensors))
+ print(f"save trained model to {ckpt_file}")
+ save_hypernetwork(ckpt_file, hypernetwork)
+
+ print("model saved.")
+
+
+# region module replacement
+"""
+Module replacement for speedup
+"""
+
+# CrossAttention that uses FlashAttention
+# based on https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/memory_efficient_attention_pytorch/flash_attention.py
+# LICENSE MIT https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/LICENSE
+
+# constants
+
+EPSILON = 1e-6
+
+# helper functions
+
+
+def exists(val):
+ return val is not None
+
+
+def default(val, d):
+ return val if exists(val) else d
+
+# flash attention forwards and backwards
+
+# https://arxiv.org/abs/2205.14135
+
+
+class FlashAttentionFunction(torch.autograd.function.Function):
+ @staticmethod
+ @torch.no_grad()
+ def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
+ """ Algorithm 2 in the paper """
+
+ device = q.device
+ dtype = q.dtype
+ max_neg_value = -torch.finfo(q.dtype).max
+ qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
+
+ o = torch.zeros_like(q)
+ all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device)
+ all_row_maxes = torch.full((*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device)
+
+ scale = (q.shape[-1] ** -0.5)
+
+ if not exists(mask):
+ mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size)
+ else:
+ mask = rearrange(mask, 'b n -> b 1 1 n')
+ mask = mask.split(q_bucket_size, dim=-1)
+
+ row_splits = zip(
+ q.split(q_bucket_size, dim=-2),
+ o.split(q_bucket_size, dim=-2),
+ mask,
+ all_row_sums.split(q_bucket_size, dim=-2),
+ all_row_maxes.split(q_bucket_size, dim=-2),
+ )
+
+ for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits):
+ q_start_index = ind * q_bucket_size - qk_len_diff
+
+ col_splits = zip(
+ k.split(k_bucket_size, dim=-2),
+ v.split(k_bucket_size, dim=-2),
+ )
+
+ for k_ind, (kc, vc) in enumerate(col_splits):
+ k_start_index = k_ind * k_bucket_size
+
+ attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc) * scale
+
+ if exists(row_mask):
+ attn_weights.masked_fill_(~row_mask, max_neg_value)
+
+ if causal and q_start_index < (k_start_index + k_bucket_size - 1):
+ causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool,
+ device=device).triu(q_start_index - k_start_index + 1)
+ attn_weights.masked_fill_(causal_mask, max_neg_value)
+
+ block_row_maxes = attn_weights.amax(dim=-1, keepdims=True)
+ attn_weights -= block_row_maxes
+ exp_weights = torch.exp(attn_weights)
+
+ if exists(row_mask):
+ exp_weights.masked_fill_(~row_mask, 0.)
+
+ block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp(min=EPSILON)
+
+ new_row_maxes = torch.maximum(block_row_maxes, row_maxes)
+
+ exp_values = einsum('... i j, ... j d -> ... i d', exp_weights, vc)
+
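+ # online softmax update: rescale the running output and row sums by exp(old_max - new_max)
+ # before folding in this block's contribution, keeping the accumulation numerically stable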
+ exp_row_max_diff = torch.exp(row_maxes - new_row_maxes)
+ exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes)
+
+ new_row_sums = exp_row_max_diff * row_sums + exp_block_row_max_diff * block_row_sums
+
+ oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_((exp_block_row_max_diff / new_row_sums) * exp_values)
+
+ row_maxes.copy_(new_row_maxes)
+ row_sums.copy_(new_row_sums)
+
+ ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size)
+ ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes)
+
+ return o
+
+ @staticmethod
+ @torch.no_grad()
+ def backward(ctx, do):
+ """ Algorithm 4 in the paper """
+
+ causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args
+ q, k, v, o, l, m = ctx.saved_tensors
+
+ device = q.device
+
+ max_neg_value = -torch.finfo(q.dtype).max
+ qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
+
+ dq = torch.zeros_like(q)
+ dk = torch.zeros_like(k)
+ dv = torch.zeros_like(v)
+
+ row_splits = zip(
+ q.split(q_bucket_size, dim=-2),
+ o.split(q_bucket_size, dim=-2),
+ do.split(q_bucket_size, dim=-2),
+ mask,
+ l.split(q_bucket_size, dim=-2),
+ m.split(q_bucket_size, dim=-2),
+ dq.split(q_bucket_size, dim=-2)
+ )
+
+ for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits):
+ q_start_index = ind * q_bucket_size - qk_len_diff
+
+ col_splits = zip(
+ k.split(k_bucket_size, dim=-2),
+ v.split(k_bucket_size, dim=-2),
+ dk.split(k_bucket_size, dim=-2),
+ dv.split(k_bucket_size, dim=-2),
+ )
+
+ for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits):
+ k_start_index = k_ind * k_bucket_size
+
+ attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc) * scale
+
+ if causal and q_start_index < (k_start_index + k_bucket_size - 1):
+ causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool,
+ device=device).triu(q_start_index - k_start_index + 1)
+ attn_weights.masked_fill_(causal_mask, max_neg_value)
+
+ exp_attn_weights = torch.exp(attn_weights - mc)
+
+ if exists(row_mask):
+ exp_attn_weights.masked_fill_(~row_mask, 0.)
+
+ p = exp_attn_weights / lc
+
+ dv_chunk = einsum('... i j, ... i d -> ... j d', p, doc)
+ dp = einsum('... i d, ... j d -> ... i j', doc, vc)
+
+ D = (doc * oc).sum(dim=-1, keepdims=True)
+ ds = p * scale * (dp - D)
+
+ dq_chunk = einsum('... i j, ... j d -> ... i d', ds, kc)
+ dk_chunk = einsum('... i j, ... i d -> ... j d', ds, qc)
+
+ dqc.add_(dq_chunk)
+ dkc.add_(dk_chunk)
+ dvc.add_(dv_chunk)
+
+ return dq, dk, dv, None, None, None, None
+
+
+def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers):
+ if mem_eff_attn:
+ replace_unet_cross_attn_to_memory_efficient()
+ elif xformers:
+ replace_unet_cross_attn_to_xformers()
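+
+# Illustrative usage (a minimal sketch; `unet` is assumed to be a UNet2DConditionModel instance):
+#   replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
+# mem_eff_attn takes precedence when both flags are set; with neither set, nothing is replaced.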
+
+
+def replace_unet_cross_attn_to_memory_efficient():
+ print("Replace CrossAttention.forward to use FlashAttention (not xformers)")
+ flash_func = FlashAttentionFunction
+
+ def forward_flash_attn(self, x, context=None, mask=None):
+ q_bucket_size = 512
+ k_bucket_size = 1024
+
+ h = self.heads
+ q = self.to_q(x)
+
+ context = context if context is not None else x
+ context = context.to(x.dtype)
+
+ if hasattr(self, 'hypernetwork') and self.hypernetwork is not None:
+ context_k, context_v = self.hypernetwork.forward(x, context)
+ context_k = context_k.to(x.dtype)
+ context_v = context_v.to(x.dtype)
+ else:
+ context_k = context
+ context_v = context
+
+ k = self.to_k(context_k)
+ v = self.to_v(context_v)
+ del context, x
+
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v))
+
+ out = flash_func.apply(q, k, v, mask, False, q_bucket_size, k_bucket_size)
+
+ out = rearrange(out, 'b h n d -> b n (h d)')
+
+ # diffusers 0.6.0
+ if type(self.to_out) is torch.nn.Sequential:
+ return self.to_out(out)
+
+ # diffusers 0.7.0~ -- to_out is no longer a Sequential (did they really have to change this?)
+ out = self.to_out[0](out)
+ out = self.to_out[1](out)
+ return out
+
+ diffusers.models.attention.CrossAttention.forward = forward_flash_attn
+
+
+def replace_unet_cross_attn_to_xformers():
+ print("Replace CrossAttention.forward to use xformers")
+ try:
+ import xformers.ops
+ except ImportError:
+ raise ImportError("No xformers / xformersがインストールされていないようです")
+
+ def forward_xformers(self, x, context=None, mask=None):
+ h = self.heads
+ q_in = self.to_q(x)
+
+ context = default(context, x)
+ context = context.to(x.dtype)
+
+ if hasattr(self, 'hypernetwork') and self.hypernetwork is not None:
+ context_k, context_v = self.hypernetwork.forward(x, context)
+ context_k = context_k.to(x.dtype)
+ context_v = context_v.to(x.dtype)
+ else:
+ context_k = context
+ context_v = context
+
+ k_in = self.to_k(context_k)
+ v_in = self.to_v(context_v)
+
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in))
+ del q_in, k_in, v_in
+
+ q = q.contiguous()
+ k = k.contiguous()
+ v = v.contiguous()
+ out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)  # automatically picks the best available kernel
+
+ out = rearrange(out, 'b n h d -> b n (h d)', h=h)
+
+ # diffusers 0.6.0
+ if type(self.to_out) is torch.nn.Sequential:
+ return self.to_out(out)
+
+ # diffusers 0.7.0~
+ out = self.to_out[0](out)
+ out = self.to_out[1](out)
+ return out
+
+ diffusers.models.attention.CrossAttention.forward = forward_xformers
+# endregion
+
+
+if __name__ == '__main__':
+ # torch.cuda.set_per_process_memory_fraction(0.48)
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--v2", action='store_true',
+ help='load Stable Diffusion v2.0 model / Stable Diffusion 2.0のモデルを読み込む')
+ parser.add_argument("--v_parameterization", action='store_true',
+ help='enable v-parameterization training / v-parameterization学習を有効にする')
+ parser.add_argument("--pretrained_model_name_or_path", type=str, default=None,
+ help="pretrained model to train, directory to Diffusers model or StableDiffusion checkpoint / 学習元モデル、Diffusers形式モデルのディレクトリまたはStableDiffusionのckptファイル")
+ parser.add_argument("--in_json", type=str, default=None, help="metadata file to input / 読みこむメタデータファイル")
+ parser.add_argument("--shuffle_caption", action="store_true",
+ help="shuffle comma-separated caption when fine tuning / fine tuning時にコンマで区切られたcaptionの各要素をshuffleする")
+ parser.add_argument("--train_data_dir", type=str, default=None, help="directory for train images / 学習画像データのディレクトリ")
+ parser.add_argument("--dataset_repeats", type=int, default=None, help="num times to repeat dataset / 学習にデータセットを繰り返す回数")
+ parser.add_argument("--output_dir", type=str, default=None,
+ help="directory to output trained model, save as same format as input / 学習後のモデル出力先ディレクトリ(入力と同じ形式で保存)")
+ parser.add_argument("--use_safetensors", action='store_true',
+ help="use safetensors format for StableDiffusion checkpoint / StableDiffusionのcheckpointをsafetensors形式で保存する")
+ parser.add_argument("--train_text_encoder", action="store_true", help="train text encoder / text encoderも学習する")
+ parser.add_argument("--hypernetwork_module", type=str, default=None,
+ help='train hypernetwork instead of fine tuning, module to use / fine tuningの代わりにHypernetworkの学習をする場合、そのモジュール')
+ parser.add_argument("--hypernetwork_weights", type=str, default=None,
+ help='hypernetwork weights to initialize for additional training / Hypernetworkの学習時に読み込む重み(Hypernetworkの追加学習)')
+ parser.add_argument("--save_every_n_epochs", type=int, default=None,
+ help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存する")
+ parser.add_argument("--save_state", action="store_true",
+ help="save training state additionally (including optimizer states etc.) / optimizerなど学習状態も含めたstateを追加で保存する")
+ parser.add_argument("--resume", type=str, default=None,
+ help="saved state to resume training / 学習再開するモデルのstate")
+ parser.add_argument("--max_token_length", type=int, default=None, choices=[None, 150, 225],
+ help="max token length of text encoder (default for 75, 150 or 225) / text encoderのトークンの最大長(未指定で75、150または225が指定可)")
+ parser.add_argument("--train_batch_size", type=int, default=1,
+ help="batch size for training / 学習時のバッチサイズ")
+ parser.add_argument("--use_8bit_adam", action="store_true",
+ help="use 8bit Adam optimizer (requires bitsandbytes) / 8bit Adamオプティマイザを使う(bitsandbytesのインストールが必要)")
+ parser.add_argument("--mem_eff_attn", action="store_true",
+ help="use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う")
+ parser.add_argument("--xformers", action="store_true",
+ help="use xformers for CrossAttention / CrossAttentionにxformersを使う")
+ parser.add_argument("--diffusers_xformers", action='store_true',
+ help='use xformers by diffusers (Hypernetworks don\'t work) / Diffusersでxformersを使用する(Hypernetwork利用不可)')
+ parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率")
+ parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数")
+ parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed")
+ parser.add_argument("--gradient_checkpointing", action="store_true",
+ help="enable gradient checkpointing / grandient checkpointingを有効にする")
+ parser.add_argument("--gradient_accumulation_steps", type=int, default=1,
+ help="Number of updates steps to accumulate before performing a backward/update pass / 学習時に逆伝播をする前に勾配を合計するステップ数")
+ parser.add_argument("--mixed_precision", type=str, default="no",
+ choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度")
+ parser.add_argument("--save_precision", type=str, default=None,
+ choices=[None, "float", "fp16", "bf16"], help="precision in saving (available in StableDiffusion checkpoint) / 保存時に精度を変更して保存する(StableDiffusion形式での保存時のみ有効)")
+ parser.add_argument("--clip_skip", type=int, default=None,
+ help="use output of nth layer from back of text encoder (n>=1) / text encoderの後ろからn番目の層の出力を用いる(nは1以上)")
+ parser.add_argument("--debug_dataset", action="store_true",
+ help="show images for debugging (do not train) / デバッグ用に学習データを画面表示する(学習は行わない)")
+ parser.add_argument("--logging_dir", type=str, default=None,
+ help="enable logging and output TensorBoard log to this directory / ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する")
+ parser.add_argument("--log_prefix", type=str, default=None, help="add prefix for each log directory / ログディレクトリ名の先頭に追加する文字列")
+ parser.add_argument("--lr_scheduler", type=str, default="constant",
+ help="scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts, polynomial, constant (default), constant_with_warmup")
+ parser.add_argument("--lr_warmup_steps", type=int, default=0,
+ help="Number of steps for the warmup in the lr scheduler (default is 0) / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)")
+
+ args = parser.parse_args()
+ train(args)
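+
+# Illustrative invocation (a sketch only; the script name and all paths are placeholders):
+#   accelerate launch this_script.py \
+#     --pretrained_model_name_or_path <model.ckpt or Diffusers dir> \
+#     --in_json <metadata.json> --train_data_dir <train_images> --output_dir <out_dir> \
+#     --train_batch_size 1 --learning_rate 2e-6 --max_train_steps 1600 \
+#     --mixed_precision fp16 --use_8bit_adam --xformers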
diff --git a/diffuser_fine_tuning/hypernetwork_nai.py b/diffuser_fine_tuning/hypernetwork_nai.py
new file mode 100644
index 00000000..dcaaa714
--- /dev/null
+++ b/diffuser_fine_tuning/hypernetwork_nai.py
@@ -0,0 +1,96 @@
+# NAI compatible
+
+import torch
+
+
+class HypernetworkModule(torch.nn.Module):
+ def __init__(self, dim, multiplier=1.0):
+ super().__init__()
+
+ linear1 = torch.nn.Linear(dim, dim * 2)
+ linear2 = torch.nn.Linear(dim * 2, dim)
+ linear1.weight.data.normal_(mean=0.0, std=0.01)
+ linear1.bias.data.zero_()
+ linear2.weight.data.normal_(mean=0.0, std=0.01)
+ linear2.bias.data.zero_()
+ linears = [linear1, linear2]
+
+ self.linear = torch.nn.Sequential(*linears)
+ self.multiplier = multiplier
+
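+ # residual two-layer MLP: the output is x plus the MLP output scaled by `multiplier`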
+ def forward(self, x):
+ return x + self.linear(x) * self.multiplier
+
+
+class Hypernetwork(torch.nn.Module):
+ enable_sizes = [320, 640, 768, 1280]
+ # return self.modules[Hypernetwork.enable_sizes.index(size)]
+
+ def __init__(self, multiplier=1.0) -> None:
+ super().__init__()
+ self.modules = []
+ for size in Hypernetwork.enable_sizes:
+ self.modules.append((HypernetworkModule(size, multiplier), HypernetworkModule(size, multiplier)))
+ self.register_module(f"{size}_0", self.modules[-1][0])
+ self.register_module(f"{size}_1", self.modules[-1][1])
+
+ def apply_to_stable_diffusion(self, text_encoder, vae, unet):
+ blocks = unet.input_blocks + [unet.middle_block] + unet.output_blocks
+ for block in blocks:
+ for subblk in block:
+ if 'SpatialTransformer' in str(type(subblk)):
+ for tf_block in subblk.transformer_blocks:
+ for attn in [tf_block.attn1, tf_block.attn2]:
+ size = attn.context_dim
+ if size in Hypernetwork.enable_sizes:
+ attn.hypernetwork = self
+ else:
+ attn.hypernetwork = None
+
+ def apply_to_diffusers(self, text_encoder, vae, unet):
+ blocks = unet.down_blocks + [unet.mid_block] + unet.up_blocks
+ for block in blocks:
+ if hasattr(block, 'attentions'):
+ for subblk in block.attentions:
+ if 'SpatialTransformer' in str(type(subblk)) or 'Transformer2DModel' in str(type(subblk)): # 0.6.0 and 0.7~
+ for tf_block in subblk.transformer_blocks:
+ for attn in [tf_block.attn1, tf_block.attn2]:
+ size = attn.to_k.in_features
+ if size in Hypernetwork.enable_sizes:
+ attn.hypernetwork = self
+ else:
+ attn.hypernetwork = None
+ return True # TODO error checking
+
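+ # returns a (context_k, context_v) pair; the patched CrossAttention.forward in the training
+ # script feeds these through to_k / to_v in place of the raw context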
+ def forward(self, x, context):
+ size = context.shape[-1]
+ assert size in Hypernetwork.enable_sizes
+ module = self.modules[Hypernetwork.enable_sizes.index(size)]
+ return module[0].forward(context), module[1].forward(context)
+
+ def load_from_state_dict(self, state_dict):
+ # old ver to new ver
+ changes = {
+ 'linear1.bias': 'linear.0.bias',
+ 'linear1.weight': 'linear.0.weight',
+ 'linear2.bias': 'linear.1.bias',
+ 'linear2.weight': 'linear.1.weight',
+ }
+ for key_from, key_to in changes.items():
+ if key_from in state_dict:
+ state_dict[key_to] = state_dict[key_from]
+ del state_dict[key_from]
+
+ for size, sd in state_dict.items():
+ if type(size) == int:
+ self.modules[Hypernetwork.enable_sizes.index(size)][0].load_state_dict(sd[0], strict=True)
+ self.modules[Hypernetwork.enable_sizes.index(size)][1].load_state_dict(sd[1], strict=True)
+ return True
+
+ def get_state_dict(self):
+ state_dict = {}
+ for i, size in enumerate(Hypernetwork.enable_sizes):
+ sd0 = self.modules[i][0].state_dict()
+ sd1 = self.modules[i][1].state_dict()
+ state_dict[size] = [sd0, sd1]
+ return state_dict
diff --git a/diffuser_fine_tuning/make_captions.py b/diffuser_fine_tuning/make_captions.py
new file mode 100644
index 00000000..44f1e53b
--- /dev/null
+++ b/diffuser_fine_tuning/make_captions.py
@@ -0,0 +1,97 @@
+# This script is licensed under the Apache License 2.0
+# (c) 2022 Kohya S. @kohya_ss
+
+import argparse
+import glob
+import os
+import json
+
+from PIL import Image
+from tqdm import tqdm
+import numpy as np
+import torch
+from torchvision import transforms
+from torchvision.transforms.functional import InterpolationMode
+from models.blip import blip_decoder
+# from Salesforce_BLIP.models.blip import blip_decoder
+
+DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+
+def main(args):
+ image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + glob.glob(os.path.join(args.train_data_dir, "*.png"))
+ print(f"found {len(image_paths)} images.")
+
+ print(f"loading BLIP caption: {args.caption_weights}")
+ image_size = 384
+ model = blip_decoder(pretrained=args.caption_weights, image_size=image_size, vit='large')
+ model.eval()
+ model = model.to(DEVICE)
+ print("BLIP loaded")
+
+ # a square resize seems questionable, but that is what the original source does
+ transform = transforms.Compose([
+ transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
+ transforms.ToTensor(),
+ transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
+ ])
+
+ # run captioning
+ def run_batch(path_imgs):
+ imgs = torch.stack([im for _, im in path_imgs]).to(DEVICE)
+
+ with torch.no_grad():
+ if args.beam_search:
+ captions = model.generate(imgs, sample=False, num_beams=args.num_beams,
+ max_length=args.max_length, min_length=args.min_length)
+ else:
+ captions = model.generate(imgs, sample=True, top_p=args.top_p, max_length=args.max_length, min_length=args.min_length)
+
+ for (image_path, _), caption in zip(path_imgs, captions):
+ with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding='utf-8') as f:
+ f.write(caption + "\n")
+ if args.debug:
+ print(image_path, caption)
+
+ b_imgs = []
+ for image_path in tqdm(image_paths):
+ raw_image = Image.open(image_path)
+ if raw_image.mode != "RGB":
+ print(f"convert image mode {raw_image.mode} to RGB: {image_path}")
+ raw_image = raw_image.convert("RGB")
+
+ image = transform(raw_image)
+ b_imgs.append((image_path, image))
+ if len(b_imgs) >= args.batch_size:
+ run_batch(b_imgs)
+ b_imgs.clear()
+ if len(b_imgs) > 0:
+ run_batch(b_imgs)
+
+ print("done!")
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
+ parser.add_argument("caption_weights", type=str,
+ help="BLIP caption weights (model_large_caption.pth) / BLIP captionの重みファイル(model_large_caption.pth)")
+ parser.add_argument("--caption_extention", type=str, default=None,
+ help="extension of caption file (for backward compatibility) / 出力されるキャプションファイルの拡張子(スペルミスしていたのを残してあります)")
+ parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 出力されるキャプションファイルの拡張子")
+ parser.add_argument("--beam_search", action="store_true",
+ help="use beam search (default Nucleus sampling) / beam searchを使う(このオプション未指定時はNucleus sampling)")
+ parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
+ parser.add_argument("--num_beams", type=int, default=1, help="num of beams in beam search /beam search時のビーム数(多いと精度が上がるが時間がかかる)")
+ parser.add_argument("--top_p", type=float, default=0.9, help="top_p in Nucleus sampling / Nucleus sampling時のtop_p")
+ parser.add_argument("--max_length", type=int, default=75, help="max length of caption / captionの最大長")
+ parser.add_argument("--min_length", type=int, default=5, help="min length of caption / captionの最小長")
+ parser.add_argument("--debug", action="store_true", help="debug mode")
+
+ args = parser.parse_args()
+
+ # restore the value from the misspelled option (kept for backward compatibility)
+ if args.caption_extention is not None:
+ args.caption_extension = args.caption_extention
+
+ main(args)
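+
+# Illustrative invocation (a sketch; paths are placeholders):
+#   python make_captions.py <train_data_dir> model_large_caption.pth --batch_size 8 --beam_search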
diff --git a/diffuser_fine_tuning/merge_captions_to_metadata.py b/diffuser_fine_tuning/merge_captions_to_metadata.py
new file mode 100644
index 00000000..a50d2bdd
--- /dev/null
+++ b/diffuser_fine_tuning/merge_captions_to_metadata.py
@@ -0,0 +1,68 @@
+# This script is licensed under the Apache License 2.0
+# (c) 2022 Kohya S. @kohya_ss
+
+import argparse
+import glob
+import os
+import json
+
+from tqdm import tqdm
+
+
+def main(args):
+ image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + glob.glob(os.path.join(args.train_data_dir, "*.png"))
+ print(f"found {len(image_paths)} images.")
+
+ if args.in_json is not None:
+ print(f"loading existing metadata: {args.in_json}")
+ with open(args.in_json, "rt", encoding='utf-8') as f:
+ metadata = json.load(f)
+ print("captions for existing images will be overwritten / 既存の画像のキャプションは上書きされます")
+ else:
+ print("new metadata will be created / 新しいメタデータファイルが作成されます")
+ metadata = {}
+
+ print("merge caption texts to metadata json.")
+ for image_path in tqdm(image_paths):
+ caption_path = os.path.splitext(image_path)[0] + args.caption_extension
+ with open(caption_path, "rt", encoding='utf-8') as f:
+ caption = f.readlines()[0].strip()
+
+ image_key = os.path.splitext(os.path.basename(image_path))[0]
+ if image_key not in metadata:
+ # if args.verify_caption:
+ # print(f"image not in metadata / メタデータに画像がありません: {image_path}")
+ # return
+ metadata[image_key] = {}
+ # elif args.verify_caption and 'caption' not in metadata[image_key]:
+ # print(f"no caption in metadata / メタデータにcaptionがありません: {image_path}")
+ # return
+
+ metadata[image_key]['caption'] = caption
+ if args.debug:
+ print(image_key, caption)
+
+ # write out the metadata and we're done
+ print(f"writing metadata: {args.out_json}")
+ with open(args.out_json, "wt", encoding='utf-8') as f:
+ json.dump(metadata, f, indent=2)
+ print("done!")
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
+ parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
+ parser.add_argument("--in_json", type=str, help="metadata file to input / 読み込むメタデータファイル")
+ parser.add_argument("--caption_extention", type=str, default=None,
+ help="extension of caption file (for backward compatibility) / 読み込むキャプションファイルの拡張子(スペルミスしていたのを残してあります)")
+ parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption file / 読み込むキャプションファイルの拡張子")
+ parser.add_argument("--debug", action="store_true", help="debug mode")
+
+ args = parser.parse_args()
+
+ # restore the value from the misspelled option (kept for backward compatibility)
+ if args.caption_extention is not None:
+ args.caption_extension = args.caption_extention
+
+ main(args)
diff --git a/diffuser_fine_tuning/merge_dd_tags_to_metadata.py b/diffuser_fine_tuning/merge_dd_tags_to_metadata.py
new file mode 100644
index 00000000..6436e6ae
--- /dev/null
+++ b/diffuser_fine_tuning/merge_dd_tags_to_metadata.py
@@ -0,0 +1,61 @@
+# This script is licensed under the Apache License 2.0
+# (c) 2022 Kohya S. @kohya_ss
+
+import argparse
+import glob
+import os
+import json
+
+from tqdm import tqdm
+
+
+def main(args):
+ image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + glob.glob(os.path.join(args.train_data_dir, "*.png"))
+ print(f"found {len(image_paths)} images.")
+
+ if args.in_json is not None:
+ print(f"loading existing metadata: {args.in_json}")
+ with open(args.in_json, "rt", encoding='utf-8') as f:
+ metadata = json.load(f)
+ print("tags data for existing images will be overwritten / 既存の画像のタグは上書きされます")
+ else:
+ print("new metadata will be created / 新しいメタデータファイルが作成されます")
+ metadata = {}
+
+ print("merge tags to metadata json.")
+ for image_path in tqdm(image_paths):
+ tags_path = os.path.splitext(image_path)[0] + '.txt'
+ with open(tags_path, "rt", encoding='utf-8') as f:
+ tags = f.readlines()[0].strip()
+
+ image_key = os.path.splitext(os.path.basename(image_path))[0]
+ if image_key not in metadata:
+ # if args.verify_caption:
+ # print(f"image not in metadata / メタデータに画像がありません: {image_path}")
+ # return
+ metadata[image_key] = {}
+ # elif args.verify_caption and 'caption' not in metadata[image_key]:
+ # print(f"no caption in metadata / メタデータにcaptionがありません: {image_path}")
+ # return
+
+ metadata[image_key]['tags'] = tags
+ if args.debug:
+ print(image_key, tags)
+
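+ # after chaining merge_captions_to_metadata.py and this script via --in_json, each entry
+ # ends up roughly like: {"image_key": {"caption": "...", "tags": "tag1, tag2, ..."}}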
+ # write out the metadata and we're done
+ print(f"writing metadata: {args.out_json}")
+ with open(args.out_json, "wt", encoding='utf-8') as f:
+ json.dump(metadata, f, indent=2)
+ print("done!")
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
+ parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
+ parser.add_argument("--in_json", type=str, help="metadata file to input / 読み込むメタデータファイル")
+ # parser.add_argument("--verify_caption", action="store_true", help="verify caption exists / メタデータにすでにcaptionが存在することを確認する")
+ parser.add_argument("--debug", action="store_true", help="debug mode")
+
+ args = parser.parse_args()
+ main(args)
diff --git a/diffuser_fine_tuning/model_util.py b/diffuser_fine_tuning/model_util.py
new file mode 100644
index 00000000..74650bf4
--- /dev/null
+++ b/diffuser_fine_tuning/model_util.py
@@ -0,0 +1,1166 @@
+# v1: split from train_db_fixed.py.
+# v2: support safetensors
+
+import math
+import os
+import torch
+from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig
+from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel
+from safetensors.torch import load_file, save_file
+
+# model parameters of the Diffusers version of Stable Diffusion
+NUM_TRAIN_TIMESTEPS = 1000
+BETA_START = 0.00085
+BETA_END = 0.0120
+
+UNET_PARAMS_MODEL_CHANNELS = 320
+UNET_PARAMS_CHANNEL_MULT = [1, 2, 4, 4]
+UNET_PARAMS_ATTENTION_RESOLUTIONS = [4, 2, 1]
+UNET_PARAMS_IMAGE_SIZE = 32 # unused
+UNET_PARAMS_IN_CHANNELS = 4
+UNET_PARAMS_OUT_CHANNELS = 4
+UNET_PARAMS_NUM_RES_BLOCKS = 2
+UNET_PARAMS_CONTEXT_DIM = 768
+UNET_PARAMS_NUM_HEADS = 8
+
+VAE_PARAMS_Z_CHANNELS = 4
+VAE_PARAMS_RESOLUTION = 256
+VAE_PARAMS_IN_CHANNELS = 3
+VAE_PARAMS_OUT_CH = 3
+VAE_PARAMS_CH = 128
+VAE_PARAMS_CH_MULT = [1, 2, 4, 4]
+VAE_PARAMS_NUM_RES_BLOCKS = 2
+
+# V2
+V2_UNET_PARAMS_ATTENTION_HEAD_DIM = [5, 10, 20, 20]
+V2_UNET_PARAMS_CONTEXT_DIM = 1024
+
+
+# region StableDiffusion -> Diffusers conversion code
+# copied and modified from convert_original_stable_diffusion_to_diffusers (ASL 2.0)
+
+
+def shave_segments(path, n_shave_prefix_segments=1):
+ """
+ Removes segments. Positive values shave the first segments, negative shave the last segments.
+ """
+ if n_shave_prefix_segments >= 0:
+ return ".".join(path.split(".")[n_shave_prefix_segments:])
+ else:
+ return ".".join(path.split(".")[:n_shave_prefix_segments])
+
+
+def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
+ """
+ Updates paths inside resnets to the new naming scheme (local renaming)
+ """
+ mapping = []
+ for old_item in old_list:
+ new_item = old_item.replace("in_layers.0", "norm1")
+ new_item = new_item.replace("in_layers.2", "conv1")
+
+ new_item = new_item.replace("out_layers.0", "norm2")
+ new_item = new_item.replace("out_layers.3", "conv2")
+
+ new_item = new_item.replace("emb_layers.1", "time_emb_proj")
+ new_item = new_item.replace("skip_connection", "conv_shortcut")
+
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
+
+ mapping.append({"old": old_item, "new": new_item})
+
+ return mapping
+
+
+def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
+ """
+ Updates paths inside resnets to the new naming scheme (local renaming)
+ """
+ mapping = []
+ for old_item in old_list:
+ new_item = old_item
+
+ new_item = new_item.replace("nin_shortcut", "conv_shortcut")
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
+
+ mapping.append({"old": old_item, "new": new_item})
+
+ return mapping
+
+
+def renew_attention_paths(old_list, n_shave_prefix_segments=0):
+ """
+ Updates paths inside attentions to the new naming scheme (local renaming)
+ """
+ mapping = []
+ for old_item in old_list:
+ new_item = old_item
+
+ # new_item = new_item.replace('norm.weight', 'group_norm.weight')
+ # new_item = new_item.replace('norm.bias', 'group_norm.bias')
+
+ # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
+ # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
+
+ # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
+
+ mapping.append({"old": old_item, "new": new_item})
+
+ return mapping
+
+
+def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
+ """
+ Updates paths inside attentions to the new naming scheme (local renaming)
+ """
+ mapping = []
+ for old_item in old_list:
+ new_item = old_item
+
+ new_item = new_item.replace("norm.weight", "group_norm.weight")
+ new_item = new_item.replace("norm.bias", "group_norm.bias")
+
+ new_item = new_item.replace("q.weight", "query.weight")
+ new_item = new_item.replace("q.bias", "query.bias")
+
+ new_item = new_item.replace("k.weight", "key.weight")
+ new_item = new_item.replace("k.bias", "key.bias")
+
+ new_item = new_item.replace("v.weight", "value.weight")
+ new_item = new_item.replace("v.bias", "value.bias")
+
+ new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
+ new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
+
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
+
+ mapping.append({"old": old_item, "new": new_item})
+
+ return mapping
+
+
+def assign_to_checkpoint(
+ paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
+):
+ """
+ This does the final conversion step: take locally converted weights and apply a global renaming
+ to them. It splits attention layers, and takes into account additional replacements
+ that may arise.
+
+ Assigns the weights to the new checkpoint.
+ """
+ assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
+
+ # Splits the attention layers into three variables.
+ if attention_paths_to_split is not None:
+ for path, path_map in attention_paths_to_split.items():
+ old_tensor = old_checkpoint[path]
+ channels = old_tensor.shape[0] // 3
+
+ target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
+
+ num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
+
+ old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
+ query, key, value = old_tensor.split(channels // num_heads, dim=1)
+
+ checkpoint[path_map["query"]] = query.reshape(target_shape)
+ checkpoint[path_map["key"]] = key.reshape(target_shape)
+ checkpoint[path_map["value"]] = value.reshape(target_shape)
+
+ for path in paths:
+ new_path = path["new"]
+
+ # These have already been assigned
+ if attention_paths_to_split is not None and new_path in attention_paths_to_split:
+ continue
+
+ # Global renaming happens here
+ new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
+ new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
+ new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")
+
+ if additional_replacements is not None:
+ for replacement in additional_replacements:
+ new_path = new_path.replace(replacement["old"], replacement["new"])
+
+ # proj_attn.weight has to be converted from conv 1D to linear
+ if "proj_attn.weight" in new_path:
+ checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
+ else:
+ checkpoint[new_path] = old_checkpoint[path["old"]]
+
+
+def conv_attn_to_linear(checkpoint):
+ keys = list(checkpoint.keys())
+ attn_keys = ["query.weight", "key.weight", "value.weight"]
+ for key in keys:
+ if ".".join(key.split(".")[-2:]) in attn_keys:
+ if checkpoint[key].ndim > 2:
+ checkpoint[key] = checkpoint[key][:, :, 0, 0]
+ elif "proj_attn.weight" in key:
+ if checkpoint[key].ndim > 2:
+ checkpoint[key] = checkpoint[key][:, :, 0]
+
+
+def linear_transformer_to_conv(checkpoint):
+ keys = list(checkpoint.keys())
+ tf_keys = ["proj_in.weight", "proj_out.weight"]
+ for key in keys:
+ if ".".join(key.split(".")[-2:]) in tf_keys:
+ if checkpoint[key].ndim == 2:
+ checkpoint[key] = checkpoint[key].unsqueeze(2).unsqueeze(2)
+
+
+def convert_ldm_unet_checkpoint(v2, checkpoint, config):
+ """
+ Takes a state dict and a config, and returns a converted checkpoint.
+ """
+
+ # extract state_dict for UNet
+ unet_state_dict = {}
+ unet_key = "model.diffusion_model."
+ keys = list(checkpoint.keys())
+ for key in keys:
+ if key.startswith(unet_key):
+ unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
+
+ new_checkpoint = {}
+
+ new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
+ new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
+ new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
+ new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
+
+ new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
+ new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
+
+ new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
+ new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
+ new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
+ new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
+
+ # Retrieves the keys for the input blocks only
+ num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
+ input_blocks = {
+ layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}." in key]
+ for layer_id in range(num_input_blocks)
+ }
+
+ # Retrieves the keys for the middle blocks only
+ num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
+ middle_blocks = {
+ layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}." in key]
+ for layer_id in range(num_middle_blocks)
+ }
+
+ # Retrieves the keys for the output blocks only
+ num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
+ output_blocks = {
+ layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}." in key]
+ for layer_id in range(num_output_blocks)
+ }
+
+ for i in range(1, num_input_blocks):
+ block_id = (i - 1) // (config["layers_per_block"] + 1)
+ layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
+
+ resnets = [
+ key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
+ ]
+ attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
+
+ if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
+ f"input_blocks.{i}.0.op.weight"
+ )
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
+ f"input_blocks.{i}.0.op.bias"
+ )
+
+ paths = renew_resnet_paths(resnets)
+ meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
+ assign_to_checkpoint(
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+ )
+
+ if len(attentions):
+ paths = renew_attention_paths(attentions)
+ meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
+ assign_to_checkpoint(
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+ )
+
+ resnet_0 = middle_blocks[0]
+ attentions = middle_blocks[1]
+ resnet_1 = middle_blocks[2]
+
+ resnet_0_paths = renew_resnet_paths(resnet_0)
+ assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)
+
+ resnet_1_paths = renew_resnet_paths(resnet_1)
+ assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)
+
+ attentions_paths = renew_attention_paths(attentions)
+ meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
+ assign_to_checkpoint(
+ attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+ )
+
+ for i in range(num_output_blocks):
+ block_id = i // (config["layers_per_block"] + 1)
+ layer_in_block_id = i % (config["layers_per_block"] + 1)
+ output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
+ output_block_list = {}
+
+ for layer in output_block_layers:
+ layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
+ if layer_id in output_block_list:
+ output_block_list[layer_id].append(layer_name)
+ else:
+ output_block_list[layer_id] = [layer_name]
+
+ if len(output_block_list) > 1:
+ resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
+ attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
+
+ resnet_0_paths = renew_resnet_paths(resnets)
+ paths = renew_resnet_paths(resnets)
+
+ meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
+ assign_to_checkpoint(
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+ )
+
+ # original:
+ # if ["conv.weight", "conv.bias"] in output_block_list.values():
+ # index = list(output_block_list.values()).index(["conv.weight", "conv.bias"])
+
+ # make this independent of the bias/weight ordering; there is probably a better way
+ for l in output_block_list.values():
+ l.sort()
+
+ if ["conv.bias", "conv.weight"] in output_block_list.values():
+ index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
+ f"output_blocks.{i}.{index}.conv.bias"
+ ]
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
+ f"output_blocks.{i}.{index}.conv.weight"
+ ]
+
+ # Clear attentions as they have been attributed above.
+ if len(attentions) == 2:
+ attentions = []
+
+ if len(attentions):
+ paths = renew_attention_paths(attentions)
+ meta_path = {
+ "old": f"output_blocks.{i}.1",
+ "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
+ }
+ assign_to_checkpoint(
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+ )
+ else:
+ resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
+ for path in resnet_0_paths:
+ old_path = ".".join(["output_blocks", str(i), path["old"]])
+ new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
+
+ new_checkpoint[new_path] = unet_state_dict[old_path]
+
+ # in SD v2 the 1x1 conv2d layers were changed to linear, so convert linear -> conv
+ if v2:
+ linear_transformer_to_conv(new_checkpoint)
+
+ return new_checkpoint
+
+
+def convert_ldm_vae_checkpoint(checkpoint, config):
+ # extract state dict for VAE
+ vae_state_dict = {}
+ vae_key = "first_stage_model."
+ keys = list(checkpoint.keys())
+ for key in keys:
+ if key.startswith(vae_key):
+ vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
+ # if len(vae_state_dict) == 0:
+ # # the passed checkpoint is a VAE state_dict, not a checkpoint loaded from a .ckpt
+ # vae_state_dict = checkpoint
+
+ new_checkpoint = {}
+
+ new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
+ new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
+ new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
+ new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
+ new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
+ new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]
+
+ new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
+ new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
+ new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
+ new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
+ new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
+ new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]
+
+ new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
+ new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
+ new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
+ new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
+
+ # Retrieves the keys for the encoder down blocks only
+ num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
+ down_blocks = {
+ layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
+ }
+
+ # Retrieves the keys for the decoder up blocks only
+ num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
+ up_blocks = {
+ layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
+ }
+
+ for i in range(num_down_blocks):
+ resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
+
+ if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
+ new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
+ f"encoder.down.{i}.downsample.conv.weight"
+ )
+ new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
+ f"encoder.down.{i}.downsample.conv.bias"
+ )
+
+ paths = renew_vae_resnet_paths(resnets)
+ meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
+
+ mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
+ num_mid_res_blocks = 2
+ for i in range(1, num_mid_res_blocks + 1):
+ resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
+
+ paths = renew_vae_resnet_paths(resnets)
+ meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
+
+ mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
+ paths = renew_vae_attention_paths(mid_attentions)
+ meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
+ conv_attn_to_linear(new_checkpoint)
+
+ for i in range(num_up_blocks):
+ block_id = num_up_blocks - 1 - i
+ resnets = [
+ key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
+ ]
+
+ if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
+ new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
+ f"decoder.up.{block_id}.upsample.conv.weight"
+ ]
+ new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
+ f"decoder.up.{block_id}.upsample.conv.bias"
+ ]
+
+ paths = renew_vae_resnet_paths(resnets)
+ meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
+
+ mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
+ num_mid_res_blocks = 2
+ for i in range(1, num_mid_res_blocks + 1):
+ resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
+
+ paths = renew_vae_resnet_paths(resnets)
+ meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
+
+ mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
+ paths = renew_vae_attention_paths(mid_attentions)
+ meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
+ conv_attn_to_linear(new_checkpoint)
+ return new_checkpoint
+
+
+def create_unet_diffusers_config(v2):
+ """
+ Creates a config for the diffusers based on the config of the LDM model.
+ """
+ # unet_params = original_config.model.params.unet_config.params
+
+ block_out_channels = [UNET_PARAMS_MODEL_CHANNELS * mult for mult in UNET_PARAMS_CHANNEL_MULT]
+
+ down_block_types = []
+ resolution = 1
+ for i in range(len(block_out_channels)):
+ block_type = "CrossAttnDownBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "DownBlock2D"
+ down_block_types.append(block_type)
+ if i != len(block_out_channels) - 1:
+ resolution *= 2
+
+ up_block_types = []
+ for i in range(len(block_out_channels)):
+ block_type = "CrossAttnUpBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "UpBlock2D"
+ up_block_types.append(block_type)
+ resolution //= 2
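+ # for the SD UNet this yields down_block_types = 3x CrossAttnDownBlock2D + DownBlock2D
+ # and up_block_types = UpBlock2D + 3x CrossAttnUpBlock2D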
+
+ config = dict(
+ sample_size=UNET_PARAMS_IMAGE_SIZE,
+ in_channels=UNET_PARAMS_IN_CHANNELS,
+ out_channels=UNET_PARAMS_OUT_CHANNELS,
+ down_block_types=tuple(down_block_types),
+ up_block_types=tuple(up_block_types),
+ block_out_channels=tuple(block_out_channels),
+ layers_per_block=UNET_PARAMS_NUM_RES_BLOCKS,
+ cross_attention_dim=UNET_PARAMS_CONTEXT_DIM if not v2 else V2_UNET_PARAMS_CONTEXT_DIM,
+ attention_head_dim=UNET_PARAMS_NUM_HEADS if not v2 else V2_UNET_PARAMS_ATTENTION_HEAD_DIM,
+ )
+
+ return config
+
+
+def create_vae_diffusers_config():
+ """
+ Creates a config for the diffusers based on the config of the LDM model.
+ """
+ # vae_params = original_config.model.params.first_stage_config.params.ddconfig
+ # _ = original_config.model.params.first_stage_config.params.embed_dim
+ block_out_channels = [VAE_PARAMS_CH * mult for mult in VAE_PARAMS_CH_MULT]
+ down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
+ up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
+
+ config = dict(
+ sample_size=VAE_PARAMS_RESOLUTION,
+ in_channels=VAE_PARAMS_IN_CHANNELS,
+ out_channels=VAE_PARAMS_OUT_CH,
+ down_block_types=tuple(down_block_types),
+ up_block_types=tuple(up_block_types),
+ block_out_channels=tuple(block_out_channels),
+ latent_channels=VAE_PARAMS_Z_CHANNELS,
+ layers_per_block=VAE_PARAMS_NUM_RES_BLOCKS,
+ )
+ return config
+
+
+def convert_ldm_clip_checkpoint_v1(checkpoint):
+ keys = list(checkpoint.keys())
+ text_model_dict = {}
+ for key in keys:
+ if key.startswith("cond_stage_model.transformer"):
+ text_model_dict[key[len("cond_stage_model.transformer."):]] = checkpoint[key]
+ return text_model_dict
+
+
+def convert_ldm_clip_checkpoint_v2(checkpoint, max_length):
+ # the layouts differ to an annoying degree!
+ def convert_key(key):
+ if not key.startswith("cond_stage_model"):
+ return None
+
+ # common conversion
+ key = key.replace("cond_stage_model.model.transformer.", "text_model.encoder.")
+ key = key.replace("cond_stage_model.model.", "text_model.")
+
+ if "resblocks" in key:
+ # resblocks conversion
+ key = key.replace(".resblocks.", ".layers.")
+ if ".ln_" in key:
+ key = key.replace(".ln_", ".layer_norm")
+ elif ".mlp." in key:
+ key = key.replace(".c_fc.", ".fc1.")
+ key = key.replace(".c_proj.", ".fc2.")
+ elif '.attn.out_proj' in key:
+ key = key.replace(".attn.out_proj.", ".self_attn.out_proj.")
+ elif '.attn.in_proj' in key:
+ key = None  # special case, handled separately below
+ else:
+ raise ValueError(f"unexpected key in SD: {key}")
+ elif '.positional_embedding' in key:
+ key = key.replace(".positional_embedding", ".embeddings.position_embedding.weight")
+ elif '.text_projection' in key:
+ key = None  # apparently unused???
+ elif '.logit_scale' in key:
+ key = None  # apparently unused???
+ elif '.token_embedding' in key:
+ key = key.replace(".token_embedding.weight", ".embeddings.token_embedding.weight")
+ elif '.ln_final' in key:
+ key = key.replace(".ln_final", ".final_layer_norm")
+ return key
+
+ keys = list(checkpoint.keys())
+ new_sd = {}
+ for key in keys:
+ # remove resblocks 23
+ if '.resblocks.23.' in key:
+ continue
+ new_key = convert_key(key)
+ if new_key is None:
+ continue
+ new_sd[new_key] = checkpoint[key]
+
+ # convert the attention weights
+ for key in keys:
+ if '.resblocks.23.' in key:
+ continue
+ if '.resblocks' in key and '.attn.in_proj_' in key:
+ # split into three (q, k, v)
+ values = torch.chunk(checkpoint[key], 3)
+
+ key_suffix = ".weight" if "weight" in key else ".bias"
+ key_pfx = key.replace("cond_stage_model.model.transformer.resblocks.", "text_model.encoder.layers.")
+ key_pfx = key_pfx.replace("_weight", "")
+ key_pfx = key_pfx.replace("_bias", "")
+ key_pfx = key_pfx.replace(".attn.in_proj", ".self_attn.")
+ new_sd[key_pfx + "q_proj" + key_suffix] = values[0]
+ new_sd[key_pfx + "k_proj" + key_suffix] = values[1]
+ new_sd[key_pfx + "v_proj" + key_suffix] = values[2]
+
+ # add position_ids
+ new_sd["text_model.embeddings.position_ids"] = torch.Tensor([list(range(max_length))]).to(torch.int64)
+ return new_sd
+
+# endregion
+
+
+# region Diffusers -> StableDiffusion conversion code
+# copied and modified from convert_diffusers_to_original_stable_diffusion (ASL 2.0)
+
+def conv_transformer_to_linear(checkpoint):
+ keys = list(checkpoint.keys())
+ tf_keys = ["proj_in.weight", "proj_out.weight"]
+ for key in keys:
+ if ".".join(key.split(".")[-2:]) in tf_keys:
+ if checkpoint[key].ndim > 2:
+ checkpoint[key] = checkpoint[key][:, :, 0, 0]
+
+
+def convert_unet_state_dict_to_sd(v2, unet_state_dict):
+ unet_conversion_map = [
+ # (stable-diffusion, HF Diffusers)
+ ("time_embed.0.weight", "time_embedding.linear_1.weight"),
+ ("time_embed.0.bias", "time_embedding.linear_1.bias"),
+ ("time_embed.2.weight", "time_embedding.linear_2.weight"),
+ ("time_embed.2.bias", "time_embedding.linear_2.bias"),
+ ("input_blocks.0.0.weight", "conv_in.weight"),
+ ("input_blocks.0.0.bias", "conv_in.bias"),
+ ("out.0.weight", "conv_norm_out.weight"),
+ ("out.0.bias", "conv_norm_out.bias"),
+ ("out.2.weight", "conv_out.weight"),
+ ("out.2.bias", "conv_out.bias"),
+ ]
+
+ unet_conversion_map_resnet = [
+ # (stable-diffusion, HF Diffusers)
+ ("in_layers.0", "norm1"),
+ ("in_layers.2", "conv1"),
+ ("out_layers.0", "norm2"),
+ ("out_layers.3", "conv2"),
+ ("emb_layers.1", "time_emb_proj"),
+ ("skip_connection", "conv_shortcut"),
+ ]
+
+ unet_conversion_map_layer = []
+ for i in range(4):
+ # loop over downblocks/upblocks
+
+ for j in range(2):
+ # loop over resnets/attentions for downblocks
+ hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
+ sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
+ unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
+
+ if i < 3:
+ # no attention layers in down_blocks.3
+ hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
+ sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
+ unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
+
+ for j in range(3):
+ # loop over resnets/attentions for upblocks
+ hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
+ sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
+ unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
+
+ if i > 0:
+ # no attention layers in up_blocks.0
+ hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
+ sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
+ unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
+
+ if i < 3:
+ # no downsample in down_blocks.3
+ hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
+ sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
+ unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
+
+ # no upsample in up_blocks.3
+ hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
+ sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
+ unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
+
+ hf_mid_atn_prefix = "mid_block.attentions.0."
+ sd_mid_atn_prefix = "middle_block.1."
+ unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
+
+ for j in range(2):
+ hf_mid_res_prefix = f"mid_block.resnets.{j}."
+ sd_mid_res_prefix = f"middle_block.{2*j}."
+ unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
+
+ # buyer beware: this is a *brittle* function,
+ # and correct output requires that all of these pieces interact in
+ # the exact order in which I have arranged them.
+ mapping = {k: k for k in unet_state_dict.keys()}
+ for sd_name, hf_name in unet_conversion_map:
+ mapping[hf_name] = sd_name
+ for k, v in mapping.items():
+ if "resnets" in k:
+ for sd_part, hf_part in unet_conversion_map_resnet:
+ v = v.replace(hf_part, sd_part)
+ mapping[k] = v
+ for k, v in mapping.items():
+ for sd_part, hf_part in unet_conversion_map_layer:
+ v = v.replace(hf_part, sd_part)
+ mapping[k] = v
+ new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
+
+ if v2:
+ conv_transformer_to_linear(new_state_dict)
+
+ return new_state_dict
+
+
+# ================#
+# VAE Conversion #
+# ================#
+
+def reshape_weight_for_sd(w):
+ # convert HF linear weights to SD conv2d weights
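+ # e.g. a (512, 512) linear weight becomes a (512, 512, 1, 1) conv2d weight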
+ return w.reshape(*w.shape, 1, 1)
+
+
+def convert_vae_state_dict(vae_state_dict):
+ vae_conversion_map = [
+ # (stable-diffusion, HF Diffusers)
+ ("nin_shortcut", "conv_shortcut"),
+ ("norm_out", "conv_norm_out"),
+ ("mid.attn_1.", "mid_block.attentions.0."),
+ ]
+
+ for i in range(4):
+ # down_blocks have two resnets
+ for j in range(2):
+ hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}."
+ sd_down_prefix = f"encoder.down.{i}.block.{j}."
+ vae_conversion_map.append((sd_down_prefix, hf_down_prefix))
+
+ if i < 3:
+ hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0."
+ sd_downsample_prefix = f"down.{i}.downsample."
+ vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))
+
+ hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
+ sd_upsample_prefix = f"up.{3-i}.upsample."
+ vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))
+
+ # up_blocks have three resnets
+ # also, up blocks in hf are numbered in reverse from sd
+ for j in range(3):
+ hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
+ sd_up_prefix = f"decoder.up.{3-i}.block.{j}."
+ vae_conversion_map.append((sd_up_prefix, hf_up_prefix))
+
+ # this part accounts for mid blocks in both the encoder and the decoder
+ for i in range(2):
+ hf_mid_res_prefix = f"mid_block.resnets.{i}."
+ sd_mid_res_prefix = f"mid.block_{i+1}."
+ vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
+
+ vae_conversion_map_attn = [
+ # (stable-diffusion, HF Diffusers)
+ ("norm.", "group_norm."),
+ ("q.", "query."),
+ ("k.", "key."),
+ ("v.", "value."),
+ ("proj_out.", "proj_attn."),
+ ]
+
+ mapping = {k: k for k in vae_state_dict.keys()}
+ for k, v in mapping.items():
+ for sd_part, hf_part in vae_conversion_map:
+ v = v.replace(hf_part, sd_part)
+ mapping[k] = v
+ for k, v in mapping.items():
+ if "attentions" in k:
+ for sd_part, hf_part in vae_conversion_map_attn:
+ v = v.replace(hf_part, sd_part)
+ mapping[k] = v
+ new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
+ weights_to_convert = ["q", "k", "v", "proj_out"]
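+ # the SD VAE implements attention q/k/v/proj_out as 1x1 convs while Diffusers uses Linear layers, hence the reshape below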
+ for k, v in new_state_dict.items():
+ for weight_name in weights_to_convert:
+ if f"mid.attn_1.{weight_name}.weight" in k:
+ # print(f"Reshaping {k} for SD format")
+ new_state_dict[k] = reshape_weight_for_sd(v)
+
+ return new_state_dict
+
+
+# endregion
+
+# region custom model loading and saving
+
+def is_safetensors(path):
+ return os.path.splitext(path)[1].lower() == '.safetensors'
+
+
+def load_checkpoint_with_text_encoder_conversion(ckpt_path):
+ # support models whose text encoder keys are stored differently (missing 'text_model')
+ TEXT_ENCODER_KEY_REPLACEMENTS = [
+ ('cond_stage_model.transformer.embeddings.', 'cond_stage_model.transformer.text_model.embeddings.'),
+ ('cond_stage_model.transformer.encoder.', 'cond_stage_model.transformer.text_model.encoder.'),
+ ('cond_stage_model.transformer.final_layer_norm.', 'cond_stage_model.transformer.text_model.final_layer_norm.')
+ ]
+
+ if is_safetensors(ckpt_path):
+ checkpoint = None
+ state_dict = load_file(ckpt_path, "cpu")
+ else:
+ checkpoint = torch.load(ckpt_path, map_location="cpu")
+ if "state_dict" in checkpoint:
+ state_dict = checkpoint["state_dict"]
+ else:
+ state_dict = checkpoint
+ checkpoint = None
+
+ key_reps = []
+ for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS:
+ for key in state_dict.keys():
+ if key.startswith(rep_from):
+ new_key = rep_to + key[len(rep_from):]
+ key_reps.append((key, new_key))
+
+ for key, new_key in key_reps:
+ state_dict[new_key] = state_dict[key]
+ del state_dict[key]
+
+ return checkpoint, state_dict
+
+
+# TODO the behavior of the dtype option is questionable and needs checking; creating the text_encoder in the specified dtype is unverified
+def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, dtype=None):
+ _, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path)
+ if dtype is not None:
+ for k, v in state_dict.items():
+ if type(v) is torch.Tensor:
+ state_dict[k] = v.to(dtype)
+
+ # Convert the UNet2DConditionModel model.
+ unet_config = create_unet_diffusers_config(v2)
+ converted_unet_checkpoint = convert_ldm_unet_checkpoint(v2, state_dict, unet_config)
+
+ unet = UNet2DConditionModel(**unet_config)
+ info = unet.load_state_dict(converted_unet_checkpoint)
+ print("loading u-net:", info)
+
+ # Convert the VAE model.
+ vae_config = create_vae_diffusers_config()
+ converted_vae_checkpoint = convert_ldm_vae_checkpoint(state_dict, vae_config)
+
+ vae = AutoencoderKL(**vae_config)
+ info = vae.load_state_dict(converted_vae_checkpoint)
+ print("loadint vae:", info)
+
+ # convert text_model
+ if v2:
+ converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v2(state_dict, 77)
+ cfg = CLIPTextConfig(
+ vocab_size=49408,
+ hidden_size=1024,
+ intermediate_size=4096,
+ num_hidden_layers=23,
+ num_attention_heads=16,
+ max_position_embeddings=77,
+ hidden_act="gelu",
+ layer_norm_eps=1e-05,
+ dropout=0.0,
+ attention_dropout=0.0,
+ initializer_range=0.02,
+ initializer_factor=1.0,
+ pad_token_id=1,
+ bos_token_id=0,
+ eos_token_id=2,
+ model_type="clip_text_model",
+ projection_dim=512,
+ torch_dtype="float32",
+ transformers_version="4.25.0.dev0",
+ )
+ text_model = CLIPTextModel._from_config(cfg)
+ info = text_model.load_state_dict(converted_text_encoder_checkpoint)
+ else:
+ converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v1(state_dict)
+ text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
+ info = text_model.load_state_dict(converted_text_encoder_checkpoint)
+ print("loading text encoder:", info)
+
+ return text_model, vae, unet
+
+
+def convert_text_encoder_state_dict_to_sd_v2(checkpoint, make_dummy_weights=False):
+ def convert_key(key):
+ # remove position_ids
+ if ".position_ids" in key:
+ return None
+
+ # common
+ key = key.replace("text_model.encoder.", "transformer.")
+ key = key.replace("text_model.", "")
+ if "layers" in key:
+ # resblocks conversion
+ key = key.replace(".layers.", ".resblocks.")
+ if ".layer_norm" in key:
+ key = key.replace(".layer_norm", ".ln_")
+ elif ".mlp." in key:
+ key = key.replace(".fc1.", ".c_fc.")
+ key = key.replace(".fc2.", ".c_proj.")
+ elif '.self_attn.out_proj' in key:
+ key = key.replace(".self_attn.out_proj.", ".attn.out_proj.")
+ elif '.self_attn.' in key:
+ key = None # q/k/v projections are special-cased and handled later
+ else:
+ raise ValueError(f"unexpected key in Diffusers model: {key}")
+ elif '.position_embedding' in key:
+ key = key.replace("embeddings.position_embedding.weight", "positional_embedding")
+ elif '.token_embedding' in key:
+ key = key.replace("embeddings.token_embedding.weight", "token_embedding.weight")
+ elif 'final_layer_norm' in key:
+ key = key.replace("final_layer_norm", "ln_final")
+ return key
+
+ keys = list(checkpoint.keys())
+ new_sd = {}
+ for key in keys:
+ new_key = convert_key(key)
+ if new_key is None:
+ continue
+ new_sd[new_key] = checkpoint[key]
+
+ # convert the attention weights
+ for key in keys:
+ if 'layers' in key and 'q_proj' in key:
+ # concatenate the three projections (q, k, v)
+ key_q = key
+ key_k = key.replace("q_proj", "k_proj")
+ key_v = key.replace("q_proj", "v_proj")
+
+ value_q = checkpoint[key_q]
+ value_k = checkpoint[key_k]
+ value_v = checkpoint[key_v]
+ value = torch.cat([value_q, value_k, value_v])
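+ # open_clip's attention keeps q, k, v stacked along dim 0 as in_proj_weight / in_proj_bias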
+
+ new_key = key.replace("text_model.encoder.layers.", "transformer.resblocks.")
+ new_key = new_key.replace(".self_attn.q_proj.", ".attn.in_proj_")
+ new_sd[new_key] = value
+
+ # whether to fabricate dummy weights for the final layer etc.
+ if make_dummy_weights:
+ print("make dummy weights for resblock.23, text_projection and logit scale.")
+ keys = list(new_sd.keys())
+ for key in keys:
+ if key.startswith("transformer.resblocks.22."):
+ new_sd[key.replace(".22.", ".23.")] = new_sd[key]
+
+ # create weights that do not exist in the Diffusers model
+ new_sd['text_projection'] = torch.ones((1024, 1024), dtype=new_sd[keys[0]].dtype, device=new_sd[keys[0]].device)
+ new_sd['logit_scale'] = torch.tensor(1)
+
+ return new_sd
+
+
+def save_stable_diffusion_checkpoint(v2, output_file, text_encoder, unet, ckpt_path, epochs, steps, save_dtype=None, vae=None):
+ if ckpt_path is not None:
+ # reuse epoch/step from the checkpoint; also reload it here (including the VAE) e.g. when the VAE is not in memory
+ checkpoint, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path)
+ if checkpoint is None: # safetensors file or a ckpt that only holds a state_dict
+ checkpoint = {}
+ strict = False
+ else:
+ strict = True
+ if "state_dict" in state_dict:
+ del state_dict["state_dict"]
+ else:
+ # create a new checkpoint from scratch
+ checkpoint = {}
+ state_dict = {}
+ strict = False
+
+ def update_sd(prefix, sd):
+ for k, v in sd.items():
+ key = prefix + k
+ assert not strict or key in state_dict, f"Illegal key in save SD: {key}"
+ if save_dtype is not None:
+ v = v.detach().clone().to("cpu").to(save_dtype)
+ state_dict[key] = v
+
+ # Convert the UNet model
+ unet_state_dict = convert_unet_state_dict_to_sd(v2, unet.state_dict())
+ update_sd("model.diffusion_model.", unet_state_dict)
+
+ # Convert the text encoder model
+ if v2:
+ make_dummy = ckpt_path is None # without a reference checkpoint, insert dummy weights (e.g. duplicate the last layer from the previous one)
+ text_enc_dict = convert_text_encoder_state_dict_to_sd_v2(text_encoder.state_dict(), make_dummy)
+ update_sd("cond_stage_model.model.", text_enc_dict)
+ else:
+ text_enc_dict = text_encoder.state_dict()
+ update_sd("cond_stage_model.transformer.", text_enc_dict)
+
+ # Convert the VAE
+ if vae is not None:
+ vae_dict = convert_vae_state_dict(vae.state_dict())
+ update_sd("first_stage_model.", vae_dict)
+
+ # Put together new checkpoint
+ key_count = len(state_dict.keys())
+ new_ckpt = {'state_dict': state_dict}
+
+ if 'epoch' in checkpoint:
+ epochs += checkpoint['epoch']
+ if 'global_step' in checkpoint:
+ steps += checkpoint['global_step']
+
+ new_ckpt['epoch'] = epochs
+ new_ckpt['global_step'] = steps
+
+ if is_safetensors(output_file):
+ # TODO should values other than Tensors be removed from the dict?
+ save_file(state_dict, output_file)
+ else:
+ torch.save(new_ckpt, output_file)
+
+ return key_count
+
+
+def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, pretrained_model_name_or_path, vae=None):
+ if vae is None:
+ vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")
+ pipeline = StableDiffusionPipeline(
+ unet=unet,
+ text_encoder=text_encoder,
+ vae=vae,
+ scheduler=DDIMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler"),
+ tokenizer=CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer"),
+ safety_checker=None,
+ feature_extractor=None,
+ requires_safety_checker=None,
+ )
+ pipeline.save_pretrained(output_dir)
+
+
+VAE_PREFIX = "first_stage_model."
+
+
+def load_vae(vae_id, dtype):
+ print(f"load VAE: {vae_id}")
+ if os.path.isdir(vae_id) or not os.path.isfile(vae_id):
+ # Diffusers local/remote
+ try:
+ vae = AutoencoderKL.from_pretrained(vae_id, subfolder=None, torch_dtype=dtype)
+ except EnvironmentError as e:
+ print(f"exception occurs in loading vae: {e}")
+ print("retry with subfolder='vae'")
+ vae = AutoencoderKL.from_pretrained(vae_id, subfolder="vae", torch_dtype=dtype)
+ return vae
+
+ # local
+ vae_config = create_vae_diffusers_config()
+
+ if vae_id.endswith(".bin"):
+ # SD 1.5 VAE on Huggingface
+ vae_sd = torch.load(vae_id, map_location="cpu")
+ converted_vae_checkpoint = vae_sd
+ else:
+ # StableDiffusion
+ vae_model = torch.load(vae_id, map_location="cpu")
+ vae_sd = vae_model['state_dict']
+
+ # vae only or full model
+ full_model = False
+ for vae_key in vae_sd:
+ if vae_key.startswith(VAE_PREFIX):
+ full_model = True
+ break
+ if not full_model:
+ sd = {}
+ for key, value in vae_sd.items():
+ sd[VAE_PREFIX + key] = value
+ vae_sd = sd
+ del sd
+
+ # Convert the VAE model.
+ converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_sd, vae_config)
+
+ vae = AutoencoderKL(**vae_config)
+ vae.load_state_dict(converted_vae_checkpoint)
+ return vae
+
+
+def get_epoch_ckpt_name(use_safetensors, epoch):
+ return f"epoch-{epoch:06d}" + (".safetensors" if use_safetensors else ".ckpt")
+
+
+def get_last_ckpt_name(use_safetensors):
+ return f"last" + (".safetensors" if use_safetensors else ".ckpt")
+
+# endregion
+
+
+def make_bucket_resolutions(max_reso, min_size=256, max_size=1024, divisible=64):
+ max_width, max_height = max_reso
+ max_area = (max_width // divisible) * (max_height // divisible)
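+ # max_area is counted in (divisible x divisible) tiles, e.g. (512, 768) -> 8 * 12 = 96 tiles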
+
+ resos = set()
+
+ size = int(math.sqrt(max_area)) * divisible
+ resos.add((size, size))
+
+ size = min_size
+ while size <= max_size:
+ width = size
+ height = min(max_size, (max_area // (width // divisible)) * divisible)
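+ # largest multiple of divisible whose tile area stays within max_area, capped at max_size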
+ resos.add((width, height))
+ resos.add((height, width))
+
+ # # make additional resos
+ # if width >= height and width - divisible >= min_size:
+ # resos.add((width - divisible, height))
+ # resos.add((height, width - divisible))
+ # if height >= width and height - divisible >= min_size:
+ # resos.add((width, height - divisible))
+ # resos.add((height - divisible, width))
+
+ size += divisible
+
+ resos = list(resos)
+ resos.sort()
+
+ aspect_ratios = [w / h for w, h in resos]
+ return resos, aspect_ratios
+
+
+if __name__ == '__main__':
+ resos, aspect_ratios = make_bucket_resolutions((512, 768))
+ print(len(resos))
+ print(resos)
+ print(aspect_ratios)
+
+ ars = set()
+ for ar in aspect_ratios:
+ if ar in ars:
+ print("error! duplicate ar:", ar)
+ ars.add(ar)
diff --git a/diffuser_fine_tuning/prepare_buckets_latents.py b/diffuser_fine_tuning/prepare_buckets_latents.py
new file mode 100644
index 00000000..25dd73ba
--- /dev/null
+++ b/diffuser_fine_tuning/prepare_buckets_latents.py
@@ -0,0 +1,175 @@
+# This script is licensed under the Apache License 2.0
+# (c) 2022 Kohya S. @kohya_ss
+
+import argparse
+import glob
+import os
+import json
+
+from tqdm import tqdm
+import numpy as np
+from diffusers import AutoencoderKL
+from PIL import Image
+import cv2
+import torch
+from torchvision import transforms
+
+import model_util
+
+DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+IMAGE_TRANSFORMS = transforms.Compose(
+ [
+ transforms.ToTensor(),
+ transforms.Normalize([0.5], [0.5]),
+ ]
+)
+
+
+def get_latents(vae, images, weight_dtype):
+ img_tensors = [IMAGE_TRANSFORMS(image) for image in images]
+ img_tensors = torch.stack(img_tensors)
+ img_tensors = img_tensors.to(DEVICE, weight_dtype)
+ with torch.no_grad():
+ latents = vae.encode(img_tensors).latent_dist.sample().float().to("cpu").numpy()
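+ # the SD VAE downsamples by a factor of 8, so latents have shape (batch, 4, height//8, width//8)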
+ return latents
+
+
+def main(args):
+ image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + glob.glob(os.path.join(args.train_data_dir, "*.png"))
+ print(f"found {len(image_paths)} images.")
+
+ if os.path.exists(args.in_json):
+ print(f"loading existing metadata: {args.in_json}")
+ with open(args.in_json, "rt", encoding='utf-8') as f:
+ metadata = json.load(f)
+ else:
+ print(f"no metadata / メタデータファイルがありません: {args.in_json}")
+ return
+
+ # # check the model-format option
+ # use_stable_diffusion_format = os.path.isfile(args.model_name_or_path)
+
+ # # load the model
+ # if use_stable_diffusion_format:
+ # print("load StableDiffusion checkpoint")
+ # # _, vae, _ = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.model_name_or_path)
+ # else:
+ # print("load Diffusers pretrained models")
+
+ weight_dtype = torch.float32
+ if args.mixed_precision == "fp16":
+ weight_dtype = torch.float16
+ elif args.mixed_precision == "bf16":
+ weight_dtype = torch.bfloat16
+
+ vae = model_util.load_vae(args.model_name_or_path, weight_dtype)
+ vae.eval()
+ vae.to(DEVICE, dtype=weight_dtype)
+
+ # compute the bucket sizes
+ max_reso = tuple([int(t) for t in args.max_resolution.split(',')])
+ assert len(max_reso) == 2, f"illegal resolution (not 'width,height') / 画像サイズに誤りがあります。'幅,高さ'で指定してください: {args.max_resolution}"
+
+ bucket_resos, bucket_aspect_ratios = model_util.make_bucket_resolutions(
+ max_reso, args.min_bucket_reso, args.max_bucket_reso)
+
+ # assign each image to the closest bucket and compute latents
+ bucket_aspect_ratios = np.array(bucket_aspect_ratios)
+ buckets_imgs = [[] for _ in range(len(bucket_resos))]
+ bucket_counts = [0 for _ in range(len(bucket_resos))]
+ img_ar_errors = []
+ for i, image_path in enumerate(tqdm(image_paths)):
+ image_key = os.path.splitext(os.path.basename(image_path))[0]
+ if image_key not in metadata:
+ metadata[image_key] = {}
+
+ image = Image.open(image_path)
+ if image.mode != 'RGB':
+ image = image.convert("RGB")
+
+ aspect_ratio = image.width / image.height
+ ar_errors = bucket_aspect_ratios - aspect_ratio
+ bucket_id = np.abs(ar_errors).argmin()
+ reso = bucket_resos[bucket_id]
+ ar_error = ar_errors[bucket_id]
+ img_ar_errors.append(abs(ar_error))
+
+ # decide the resize target: scale so one side matches the bucket, then trim the other
+ if ar_error <= 0: # the image is wider than the bucket -> match the height
+ scale = reso[1] / image.height
+ else:
+ scale = reso[0] / image.width
+
+ resized_size = (int(image.width * scale + .5), int(image.height * scale + .5))
+
+ # print(image.width, image.height, bucket_id, bucket_resos[bucket_id], ar_errors[bucket_id], resized_size,
+ # bucket_resos[bucket_id][0] - resized_size[0], bucket_resos[bucket_id][1] - resized_size[1])
+
+ assert resized_size[0] == reso[0] or resized_size[1] == reso[
+ 1], f"internal error, resized size not match: {reso}, {resized_size}, {image.width}, {image.height}"
+ assert resized_size[0] >= reso[0] and resized_size[1] >= reso[
+ 1], f"internal error, resized size too small: {reso}, {resized_size}, {image.width}, {image.height}"
+
+ # resize and trim the image
+ # PIL has no INTER_AREA, so use cv2 instead
+ image = np.array(image)
+ image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA)
+ if resized_size[0] > reso[0]:
+ trim_size = resized_size[0] - reso[0]
+ image = image[:, trim_size//2:trim_size//2 + reso[0]]
+ elif resized_size[1] > reso[1]:
+ trim_size = resized_size[1] - reso[1]
+ image = image[trim_size//2:trim_size//2 + reso[1]]
+ assert image.shape[0] == reso[1] and image.shape[1] == reso[0], f"internal error, illegal trimmed size: {image.shape}, {reso}"
+
+ # # debug
+ # cv2.imwrite(f"r:\\test\\img_{i:05d}.jpg", image[:, :, ::-1])
+
+ # add to this bucket's batch
+ buckets_imgs[bucket_id].append((image_key, reso, image))
+ bucket_counts[bucket_id] += 1
+ metadata[image_key]['train_resolution'] = reso
+
+ # run inference on any batch that is full (or on all remaining images at the end)
+ is_last = i == len(image_paths) - 1
+ for j in range(len(buckets_imgs)):
+ bucket = buckets_imgs[j]
+ if (is_last and len(bucket) > 0) or len(bucket) >= args.batch_size:
+ latents = get_latents(vae, [img for _, _, img in bucket], weight_dtype)
+
+ for (image_key, reso, _), latent in zip(bucket, latents):
+ np.savez(os.path.join(args.train_data_dir, image_key), latent)
+
+ bucket.clear()
+
+ for i, (reso, count) in enumerate(zip(bucket_resos, bucket_counts)):
+ print(f"bucket {i} {reso}: {count}")
+ img_ar_errors = np.array(img_ar_errors)
+ print(f"mean ar error: {np.mean(img_ar_errors)}")
+
+ # write out the metadata and finish
+ print(f"writing metadata: {args.out_json}")
+ with open(args.out_json, "wt", encoding='utf-8') as f:
+ json.dump(metadata, f, indent=2)
+ print("done!")
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
+ parser.add_argument("in_json", type=str, help="metadata file to input / 読み込むメタデータファイル")
+ parser.add_argument("out_json", type=str, help="metadata file to output / メタデータファイル書き出し先")
+ parser.add_argument("model_name_or_path", type=str, help="model name or path to encode latents / latentを取得するためのモデル")
+ parser.add_argument("--v2", action='store_true',
+ help='load Stable Diffusion v2.0 model / Stable Diffusion 2.0のモデルを読み込む')
+ parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
+ parser.add_argument("--max_resolution", type=str, default="512,512",
+ help="max resolution in fine tuning (width,height) / fine tuning時の最大画像サイズ 「幅,高さ」(使用メモリ量に関係します)")
+ parser.add_argument("--min_bucket_reso", type=int, default=256, help="minimum resolution for buckets / bucketの最小解像度")
+ parser.add_argument("--max_bucket_reso", type=int, default=1024, help="maximum resolution for buckets / bucketの最小解像度")
+ parser.add_argument("--mixed_precision", type=str, default="no",
+ choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度")
+
+ args = parser.parse_args()
+ main(args)
diff --git a/diffuser_fine_tuning/requirements.txt b/diffuser_fine_tuning/requirements.txt
new file mode 100644
index 00000000..38f3852b
--- /dev/null
+++ b/diffuser_fine_tuning/requirements.txt
@@ -0,0 +1,8 @@
+accelerate==0.14.0
+transformers>=4.21.0
+ftfy
+albumentations
+opencv-python
+einops
+pytorch_lightning
+safetensors
diff --git a/diffuser_fine_tuning/tag_images_by_wd14_tagger.py b/diffuser_fine_tuning/tag_images_by_wd14_tagger.py
new file mode 100644
index 00000000..66d3a34e
--- /dev/null
+++ b/diffuser_fine_tuning/tag_images_by_wd14_tagger.py
@@ -0,0 +1,107 @@
+# This script is licensed under the Apache License 2.0
+# (c) 2022 Kohya S. @kohya_ss
+
+import argparse
+import csv
+import glob
+import os
+import json
+
+from PIL import Image
+from tqdm import tqdm
+import numpy as np
+from tensorflow.keras.models import load_model
+from Utils import dbimutils
+
+
+# from wd14 tagger
+IMAGE_SIZE = 448
+
+
+def main(args):
+ image_paths = glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + \
+ glob.glob(os.path.join(args.train_data_dir, "*.png")) + glob.glob(os.path.join(args.train_data_dir, "*.webp"))
+ print(f"found {len(image_paths)} images.")
+
+ print("loading model and labels")
+ model = load_model(args.model)
+
+ # label_names = pd.read_csv("2022_0000_0899_6549/selected_tags.csv")
+ # read the CSV manually to avoid adding pandas as a dependency
+ with open(args.tag_csv, "r", encoding="utf-8") as f:
+ reader = csv.reader(f)
+ l = [row for row in reader]
+ header = l[0] # tag_id,name,category,count
+ rows = l[1:]
+ assert header[0] == 'tag_id' and header[1] == 'name' and header[2] == 'category', f"unexpected csv format: {header}"
+
+ tags = [row[1] for row in rows[1:] if row[2] == '0'] # keep only category 0, i.e. general tags
+
+ # run inference
+ def run_batch(path_imgs):
+ imgs = np.array([im for _, im in path_imgs])
+
+ probs = model(imgs, training=False)
+ probs = probs.numpy()
+
+ for (image_path, _), prob in zip(path_imgs, probs):
+ # the first 4 outputs are ratings, so skip them
+ # # First 4 labels are actually ratings: pick one with argmax
+ # ratings_names = label_names[:4]
+ # rating_index = ratings_names["probs"].argmax()
+ # found_rating = ratings_names[rating_index: rating_index + 1][["name", "probs"]]
+
+ # Everything else is tags: pick any where prediction confidence > threshold
+ tag_text = ""
+ for i, p in enumerate(prob[4:]): # numpy would be faster, but a plain loop is fine for this many tags
+ if p >= args.thresh:
+ tag_text += ", " + tags[i]
+
+ if len(tag_text) > 0:
+ tag_text = tag_text[2:] # strip the leading ", "
+
+ with open(os.path.splitext(image_path)[0] + args.caption_extension, "wt", encoding='utf-8') as f:
+ f.write(tag_text + '\n')
+ if args.debug:
+ print(image_path, tag_text)
+
+ b_imgs = []
+ for image_path in tqdm(image_paths):
+ img = dbimutils.smart_imread(image_path)
+ img = dbimutils.smart_24bit(img)
+ img = dbimutils.make_square(img, IMAGE_SIZE)
+ img = dbimutils.smart_resize(img, IMAGE_SIZE)
+ img = img.astype(np.float32)
+ b_imgs.append((image_path, img))
+
+ if len(b_imgs) >= args.batch_size:
+ run_batch(b_imgs)
+ b_imgs.clear()
+ if len(b_imgs) > 0:
+ run_batch(b_imgs)
+
+ print("done!")
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument("train_data_dir", type=str, help="directory for train images / 学習画像データのディレクトリ")
+ parser.add_argument("--model", type=str, default="networks/ViTB16_11_03_2022_07h05m53s",
+ help="model path to load / 読み込むモデルファイル")
+ parser.add_argument("--tag_csv", type=str, default="2022_0000_0899_6549/selected_tags.csv",
+ help="csv file for tags / タグ一覧のCSVファイル")
+ parser.add_argument("--thresh", type=float, default=0.35, help="threshold of confidence to add a tag / タグを追加するか判定する閾値")
+ parser.add_argument("--batch_size", type=int, default=1, help="batch size in inference / 推論時のバッチサイズ")
+ parser.add_argument("--caption_extention", type=str, default=None,
+ help="extension of caption file (for backward compatibility) / 出力されるキャプションファイルの拡張子(スペルミスしていたのを残してあります)")
+ parser.add_argument("--caption_extension", type=str, default=".txt", help="extension of caption file / 出力されるキャプションファイルの拡張子")
+ parser.add_argument("--debug", action="store_true", help="debug mode")
+
+ args = parser.parse_args()
+
+ # restore the misspelled option for backward compatibility
+ if args.caption_extention is not None:
+ args.caption_extension = args.caption_extention
+
+ main(args)
diff --git a/kohya-dreambooth.ipynb b/kohya-dreambooth.ipynb
index 1927258c..edb72ecb 100644
--- a/kohya-dreambooth.ipynb
+++ b/kohya-dreambooth.ipynb
@@ -24,13 +24,13 @@
"colab_type": "text"
},
"source": [
- ""
+ ""
]
},
{
"cell_type": "markdown",
"source": [
- "#Kohya Dreambooth V13 - VRAM 12GB"
+ "#Kohya Dreambooth V15 - VRAM 12GB"
],
"metadata": {
"id": "slgjeYgd6pWp"
@@ -39,11 +39,10 @@
{
"cell_type": "markdown",
"source": [
- "Adapted to Google Colab based on [Kohya Guide](https://note.com/kohya_ss/n/nbf7ce8d80f29#c9d7ee61-5779-4436-b4e6-9053741c46bb)
\n",
+ "Adapted to Google Colab based on [Kohya Guide](https://note.com/kohya_ss/n/nee3ed1649fb6)
\n",
"Adapted again from [bmaltais's Kohya Archive](https://github.com/bmaltais/kohya_ss)
\n",
"Adapted to Google Colab by [Linaqruf](https://github.com/Linaqruf)
\n",
- "Inference cell adapted from [ShivamShrirao's Dreambooth](https://colab.research.google.com/github/ShivamShrirao/diffusers/blob/main/examples/dreambooth/DreamBooth_Stable_Diffusion.ipynb)
\n",
- "You can find latest notebook update [here](https://github.com/Linaqruf/kohya-trainer/blob/main/kohya-dreambooth-beta.ipynb)\n",
+ "You can find latest notebook update [here](https://github.com/Linaqruf/kohya-trainer/blob/main/kohya-dreambooth.ipynb)\n",
"\n",
"\n"
],
@@ -70,8 +69,24 @@
"outputs": [],
"source": [
"#@title Clone Kohya Trainer\n",
+ "#@markdown Clone the Kohya Trainer repository from GitHub and check for updates\n",
+ "\n",
"%cd /content/\n",
- "!git clone https://github.com/Linaqruf/kohya-trainer"
+ "\n",
+ "import os\n",
+ "\n",
+ "def clone_kohya_trainer():\n",
+ " # Check if the directory already exists\n",
+ " if os.path.isdir('/content/kohya-trainer'):\n",
+ " %cd /content/kohya-trainer\n",
+ " print(\"This folder already exists, will do a !git pull instead\\n\")\n",
+ " !git pull\n",
+ " else:\n",
+ " !git clone https://github.com/Linaqruf/kohya-trainer\n",
+ " \n",
+ "\n",
+ "# Clone or update the Kohya Trainer repository\n",
+ "clone_kohya_trainer()"
]
},
{
@@ -80,40 +95,30 @@
"#@title Installing Dependencies\n",
"%cd /content/kohya-trainer\n",
"\n",
- "Install_Python_3_9_6 = False #@param{'type':'boolean'}\n",
- "\n",
- "if Install_Python_3_9_6 == True:\n",
- " #install python 3.9\n",
- " !sudo apt-get update -y\n",
- " !sudo apt-get install python3.9\n",
- "\n",
- " #change alternatives\n",
- " !sudo update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.8 1\n",
- " !sudo update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.9 2\n",
- "\n",
- " #check python version\n",
- " !python --version\n",
- " #3.9.6\n",
- " !sudo apt-get install python3.9-distutils && wget https://bootstrap.pypa.io/get-pip.py && python get-pip.py\n",
+ "def install_dependencies():\n",
+ " #@markdown Install required Python packages\n",
+ " !pip install --upgrade -r script/requirements.txt\n",
+ " !pip install -U gallery-dl\n",
+ " !pip install tensorflow\n",
+ " !pip install huggingface_hub\n",
"\n",
- "!wget -q https://github.com/ShivamShrirao/diffusers/raw/main/scripts/convert_diffusers_to_original_stable_diffusion.py\n",
+ " # Install xformers\n",
+ " !pip install -U -I --no-deps https://github.com/camenduru/stable-diffusion-webui-colab/releases/download/0.0.15/xformers-0.0.15.dev0+189828c.d20221207-cp38-cp38-linux_x86_64.whl\n",
"\n",
- "if os.path.isfile('/content/kohya-trainer/convert_diffusers_to_original_stable_diffusion.py'):\n",
- " pass\n",
- "else:\n",
+ " #Install Anime Face Detector\n",
+ " !pip install openmim\n",
+ " !mim install mmcv-full\n",
+ " !mim install mmdet\n",
+ " !mim install mmpose\n",
+ " !pip install anime-face-detector\n",
+ " !pip install --upgrade numpy\n",
+ " \n",
+ "# Install convert_diffusers_to_original_stable_diffusion.py script\n",
+ "if not os.path.isfile('/content/kohya-trainer/convert_diffusers_to_original_stable_diffusion.py'):\n",
" !wget -q https://github.com/ShivamShrirao/diffusers/raw/main/scripts/convert_diffusers_to_original_stable_diffusion.py\n",
"\n",
- "!pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113\n",
- "!pip install -r script/requirements.txt\n",
- "!pip install -U gallery-dl\n",
- "!pip install tensorflow\n",
- "!pip install --upgrade numpy\n",
- "\n",
- "#install xformers\n",
- "if Install_Python_3_9_6 == True:\n",
- " %pip install -q https://github.com/daswer123/stable-diffusion-colab/raw/main/xformers%20prebuild/T4/python39/xformers-0.0.14.dev0-cp39-cp39-linux_x86_64.whl\n",
- "else:\n",
- " %pip install -q https://github.com/TheLastBen/fast-stable-diffusion/raw/main/precompiled/T4/xformers-0.0.13.dev0-py3-none-any.whl\n"
+ "# Install dependencies\n",
+ "install_dependencies()"
],
"metadata": {
"id": "WNn0g1pnHfk5",
@@ -183,31 +188,82 @@
{
"cell_type": "code",
"source": [
- "%cd /content\n",
+ "#@title Create train and reg folder based on description above\n",
"\n",
+ "# Import the os and shutil modules\n",
"import os\n",
+ "import shutil\n",
+ "\n",
+ "# Change the current working directory to /content\n",
+ "%cd /content\n",
"\n",
- "if os.path.isdir('/content/dreambooth'):\n",
+ "# Define the dreambooth_directory variable\n",
+ "dreambooth_directory = \"/content/dreambooth\"\n",
+ "\n",
+ "# Check if the dreambooth directory already exists\n",
+ "if os.path.isdir(dreambooth_directory):\n",
+ " # If the directory exists, do nothing\n",
" pass\n",
"else:\n",
- " !mkdir /content/dreambooth\n",
+ " # If the directory does not exist, create it\n",
+ " os.mkdir(dreambooth_directory)\n",
"\n",
- "#@title Create train and reg folder based on description above\n",
- "#@markdown #Create reg folder\n",
+ "#@markdown ### Define the reg_folder variable\n",
"reg_count = 1 #@param {type: \"integer\"}\n",
"reg_class =\"kasakai_hikaru\" #@param {type: \"string\"}\n",
"reg_folder = str(reg_count) + \"_\" + reg_class\n",
"\n",
- "#@markdown #Create train folder\n",
+ "# Define the reg_directory variable\n",
+ "reg_directory = f\"{dreambooth_directory}/reg_{reg_class}\"\n",
+ "\n",
+ "# Check if the reg directory already exists\n",
+ "if os.path.isdir(reg_directory):\n",
+ " # If the directory exists, do nothing\n",
+ " pass\n",
+ "else:\n",
+ " # If the directory does not exist, create it\n",
+ " os.mkdir(reg_directory)\n",
+ "\n",
+ "# Define the reg_folder_directory variable\n",
+ "reg_folder_directory = f\"{reg_directory}/{reg_folder}\"\n",
+ "\n",
+ "# Check if the reg_folder directory already exists\n",
+ "if os.path.isdir(reg_folder_directory):\n",
+ " # If the directory exists, do nothing\n",
+ " pass\n",
+ "else:\n",
+ " # If the directory does not exist, create it\n",
+ " os.mkdir(reg_folder_directory)\n",
+ "\n",
+ "#@markdown ### Define the train_folder variable\n",
"train_count = 3300 #@param {type: \"integer\"}\n",
"train_token = \"sls\" #@param {type: \"string\"}\n",
"train_class = \"kasakai_hikaru\" #@param {type: \"string\"}\n",
"train_folder = str(train_count) + \"_\" + train_token + \"_\" + train_class\n",
"\n",
- "!mkdir \"/content/dreambooth/reg_{reg_class}\"\n",
- "!mkdir \"/content/dreambooth/reg_{reg_class}/{reg_folder}\"\n",
- "!mkdir \"/content/dreambooth/train_{train_class}\"\n",
- "!mkdir \"/content/dreambooth/train_{train_class}/{train_folder}\"\n"
+ "# Define the train_directory variable\n",
+ "train_directory = f\"{dreambooth_directory}/train_{train_class}\"\n",
+ "\n",
+ "# Check if the train directory already exists\n",
+ "if os.path.isdir(train_directory):\n",
+ " # If the directory exists, do nothing\n",
+ " pass\n",
+ "else:\n",
+ " # If the directory does not exist, create it\n",
+ " os.mkdir(train_directory)\n",
+ "\n",
+ "# Define the train_folder_directory variable\n",
+ "train_folder_directory = f\"{train_directory}/{train_folder}\"\n",
+ "\n",
+ "# Check if the train_folder directory already exists\n",
+ "if os.path.isdir(train_folder_directory):\n",
+ " # If the directory exists, do nothing\n",
+ " pass\n",
+ "else:\n",
+ " # If the directory does not exist, create it\n",
+ " os.mkdir(train_folder_directory)\n",
+ "\n",
+ " \n"
],
"metadata": {
"id": "-CVfXAJMSqRi",
@@ -230,24 +286,32 @@
"source": [
"#@title Prepare Regularization Images\n",
"#@markdown Download regularization images provided by community\n",
- "category = \"waifu-regularization-3.3k\" #@param [\"\", \"waifu-regularization-3.3k\", \"husbando-regularization-3.5k\"]\n",
- "#@markdown Or you can use the file manager on the left panel to upload (drag and drop) to `reg_images` folder (it uploads faster)\n",
+ "\n",
+ "import os\n",
+ "import shutil\n",
+ "\n",
+ "# Function to download and unzip regularization images\n",
"def reg_images(url, name):\n",
" user_token = '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'\n",
" user_header = f\"\\\"Authorization: Bearer {user_token}\\\"\"\n",
+ "\n",
+ " # Use wget to download the zip file\n",
" !wget -c --header={user_header} \"{url}\" -O /content/dreambooth/reg_{reg_class}/{reg_folder}/{name}.zip\n",
"\n",
+ " # Unzip the downloaded file using shutil\n",
+ " shutil.unpack_archive(os.path.join('/content/dreambooth/reg_{reg_class}/{reg_folder}', f'{name}.zip'), os.path.join('/content/dreambooth/reg_{reg_class}/{reg_folder}'))\n",
+ "\n",
+ " # Remove the zip file after extracting\n",
+ " os.remove(os.path.join('/content/dreambooth/reg_{reg_class}/{reg_folder}', f'{name}.zip'))\n",
+ "\n",
+ "category = \"waifu-regularization-3.3k\" #@param [\"\", \"waifu-regularization-3.3k\", \"husbando-regularization-3.5k\"]\n",
+ "#@markdown Or you can use the file manager on the left panel to upload (drag and drop) to `reg_images` folder (it uploads faster)\n",
+ "\n",
"if category != \"\":\n",
" if category == \"waifu-regularization-3.3k\":\n",
" reg_images(\"https://huggingface.co/datasets/waifu-research-department/regularization/resolve/main/waifu-regularization-3.3k.zip\", \"waifu-regularization-3.3k\")\n",
- " !unzip /content/dreambooth/reg_{reg_class}/{reg_folder}/waifu-regularization-3.3k.zip -d /content/dreambooth/reg_{reg_class}/{reg_folder}\n",
- " !rm /content/dreambooth/reg_{reg_class}/{reg_folder}/waifu-regularization-3.3k.zip\n",
" else:\n",
- " reg_images(\"https://huggingface.co/datasets/waifu-research-department/regularization/resolve/main/husbando-regularization-3.5k.zip\", \"husbando-regularization-3.5k\")\n",
- " !unzip /content/dreambooth/reg_{reg_class}/{reg_folder}/husbando-regularization-3.5k.zip -d /content/dreambooth/reg_{reg_class}/{reg_folder}\n",
- " !rm /content/dreambooth/reg_{reg_class}/{reg_folder}/husbando-regularization-3.5k.zip\n",
- " \n",
- "\n"
+ " reg_images(\"https://huggingface.co/datasets/waifu-research-department/regularization/resolve/main/husbando-regularization-3.5k.zip\", \"husbando-regularization-3.5k\")\n"
],
"metadata": {
"cellView": "form",
@@ -259,36 +323,29 @@
{
"cell_type": "code",
"source": [
- "#@title Prepare Train Images\n",
- "#@markdown **How this work?**\n",
- "\n",
- "#@markdown By using **gallery-dl** we can scrap or bulk download images on Internet, on this notebook we will scrap images from Danbooru using tag1 and tag2 as target scraping.\n",
"#@title Booru Scraper\n",
+ "#@markdown Use gallery-dl to scrape images from a booru site using the specified tags\n",
+ "\n",
"%cd /content\n",
"\n",
- "tag = \"kasakai_hikaru\" #@param {type: \"string\"}\n",
+ "# Set configuration options\n",
+ "booru = \"Danbooru\" #@param [\"\", \"Danbooru\", \"Gelbooru\"]\n",
+ "tag1 = \"hito_komoru\" #@param {type: \"string\"}\n",
"tag2 = \"\" #@param {type: \"string\"}\n",
"\n",
- "booru = \"Gelbooru\" #@param [\"\", \"Danbooru\", \"Gelbooru\"]\n",
- "\n",
+ "# Construct the search query\n",
"if tag2 != \"\":\n",
- " tag = tag + \"+\" + tag2\n",
+ " tags = tag1 + \"+\" + tag2\n",
"else:\n",
- " tag = tag\n",
+ " tags = tag1\n",
"\n",
- "output_dir = \"/content/dreambooth/train_\"+ train_class +\"/\"+ train_folder\n",
- "\n",
- "if booru == \"Danbooru\":\n",
- " !gallery-dl \"https://danbooru.donmai.us/posts?tags={tag}\" -D \"{output_dir}\"\n",
- "elif booru == \"Gelbooru\":\n",
- " !gallery-dl \"https://gelbooru.com/index.php?page=post&s=list&tags={tag}\" -D \"{output_dir}\"\n",
+ "# Scrape images from the specified booru site using the given tags\n",
+ "if booru.lower() == \"danbooru\":\n",
+ " !gallery-dl \"https://danbooru.donmai.us/posts?tags={tags}\" -D {train_folder_directory}\n",
+ "elif booru.lower() == \"gelbooru\":\n",
+ " !gallery-dl \"https://gelbooru.com/index.php?page=post&s=list&tags={tags}\" -D {train_folder_directory}\n",
"else:\n",
- " pass\n",
- "\n",
- "#@markdown Or you can use the file manager on the left panel to upload (drag and drop) to `train_images` folder. \n",
- "\n",
- "#@markdown The output directory will be on `/content/dreambooth/reg_{reg_class}/{reg_folder}`. We also will use this folder as target folder for training next step.\n",
- "\n"
+ " print(f\"Unknown booru site: {booru}\")\n"
],
"metadata": {
"id": "Kt1GzntK_apb",
@@ -354,31 +411,6 @@
"id": "7BBiH3bkg88d"
}
},
- {
- "cell_type": "code",
- "source": [
- "#@title Install Anime Face Detector\n",
- "%cd /content\n",
- "if Install_Python_3_9_6 == True:\n",
- " !sudo apt-get install python3.9-dev\n",
- "else:\n",
- " pass\n",
- "\n",
- "# installation\n",
- "!pip install openmim\n",
- "!mim install mmcv-full\n",
- "!mim install mmdet\n",
- "!mim install mmpose\n",
- "!pip install anime-face-detector\n",
- "!pip install --upgrade numpy\n"
- ],
- "metadata": {
- "cellView": "form",
- "id": "AO_K4umgOlIR"
- },
- "execution_count": null,
- "outputs": []
- },
{
"cell_type": "code",
"source": [
@@ -386,18 +418,17 @@
"%cd /content/kohya-trainer\n",
"import shutil\n",
"\n",
- "train_data =\"/content/dreambooth/train_kasakai_hikaru/3300_sls_kasakai_hikaru\" #@param {'type':'string'}\n",
"tmp = \"/content/dreambooth/tmp\"\n",
"\n",
"if os.path.isdir(tmp):\n",
" !rm -rf {tmp}\n",
- " shutil.move (train_data, tmp)\n",
- " !mkdir {train_data}\n",
+ " shutil.move (train_folder_directory, tmp)\n",
+ " !mkdir {train_folder_directory}\n",
"else:\n",
- " shutil.move (train_data, tmp)\n",
- " !mkdir {train_data}\n",
+ " shutil.move (train_folder_directory, tmp)\n",
+ " !mkdir {train_folder_directory}\n",
"\n",
- "!python script/detect_face_rotate.py --src_dir {tmp} --dst_dir {train_data} --rotate\n",
+ "!python script/detect_face_rotate_v2.py --src_dir {tmp} --dst_dir {train_folder_directory} --rotate\n",
"\n",
"#@markdown Args list:\n",
"#@markdown - `--src_dir` : directory to load images\n",
@@ -415,97 +446,22 @@
"execution_count": null,
"outputs": []
},
- {
- "cell_type": "markdown",
- "source": [
- "#`(NEW)` Waifu Diffusion 1.4 Autotagger"
- ],
- "metadata": {
- "id": "SoPUJaTpTusz"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "#@title Download Weight\n",
- "%cd /content/kohya-trainer/\n",
- "\n",
- "import os\n",
- "import shutil\n",
- "\n",
- "def huggingface_dl(url, weight):\n",
- " user_token = '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'\n",
- " user_header = f\"\\\"Authorization: Bearer {user_token}\\\"\"\n",
- " !wget -c --header={user_header} {url} -O /content/kohya-trainer/wd14tagger-weight/{weight}\n",
- "\n",
- "def download_weight():\n",
- " if os.path.isdir('/content/kohya-trainer/wd14tagger-weight'):\n",
- " pass\n",
- " else:\n",
- " # Move the content of \n",
- " # source to destination \n",
- " !mkdir /content/kohya-trainer/wd14tagger-weight/\n",
- "\n",
- " huggingface_dl(\"https://huggingface.co/Linaqruf/personal_backup/resolve/main/wd14tagger-weight/wd14Tagger.zip\", \"wd14Tagger.zip\")\n",
- " \n",
- " !unzip /content/kohya-trainer/wd14tagger-weight/wd14Tagger.zip -d /content/kohya-trainer/wd14tagger-weight\n",
- "\n",
- " # Destination path \n",
- " destination = '/content/kohya-trainer/wd14tagger-weight'\n",
- "\n",
- " if os.path.isfile('/content/kohya-trainer/wd14tagger-weight/tag_images_by_wd14_tagger.py'):\n",
- " pass\n",
- " else:\n",
- " # Move the content of \n",
- " # source to destination \n",
- " shutil.move(\"script/tag_images_by_wd14_tagger.py\", destination) \n",
- "\n",
- "download_weight()"
- ],
- "metadata": {
- "id": "WDSlAEHzT2Im",
- "cellView": "form"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "code",
- "source": [
- "#@title Start Autotagger\n",
- "%cd /content/kohya-trainer/wd14tagger-weight\n",
- "\n",
- "!python tag_images_by_wd14_tagger.py --batch_size 4 {train_data}\n",
- "\n",
- "#@markdown Args list:\n",
- "#@markdown - `--train_data_dir` : directory for training images\n",
- "#@markdown - `--model` : model path to load\n",
- "#@markdown - `--tag_csv` : csv file for tag\n",
- "#@markdown - `--thresh` : threshold of confidence to add a tag\n",
- "#@markdown - `--batch_size` : batch size in inference\n",
- "#@markdown - `--model` : model path to load\n",
- "#@markdown - `--caption_extension` : extension of caption file\n",
- "#@markdown - `--debug` : debug mode\n"
- ],
- "metadata": {
- "id": "hibZK5NPTjZQ",
- "cellView": "form"
- },
- "execution_count": null,
- "outputs": []
- },
{
"cell_type": "code",
"source": [
"#@title Install Pre-trained Model \n",
"%cd /content/kohya-trainer\n",
- "!mkdir checkpoint\n",
+ "import os\n",
+ "\n",
+ "# Check if directory exists\n",
+ "if not os.path.exists('checkpoint'):\n",
+ " # Create directory if it doesn't exist\n",
+ " os.makedirs('checkpoint')\n",
"\n",
"#@title Install Pre-trained Model \n",
"\n",
"installModels=[]\n",
"\n",
- "\n",
"#@markdown ### Available Model\n",
"#@markdown Select one of available pretrained model to download:\n",
"modelUrl = [\"\", \\\n",
@@ -524,44 +480,52 @@
" \"Anything-V3.0-pruned-fp32\", \\\n",
" \"Anything-V3.0-pruned\", \\\n",
" \"Stable-Diffusion-v1-4\", \\\n",
- " \"Stable-Diffusion-v1-5-pruned-emaonly\" \\\n",
+ " \"Stable-Diffusion-v1-5-pruned-emaonly\", \\\n",
" \"Waifu-Diffusion-v1-3-fp32\"]\n",
- "modelName = \"Anything-V3.0-pruned-fp32\" #@param [\"\", \"Animefull-final-pruned\", \"Animesfw-final-pruned\", \"Anything-V3.0-pruned-fp16\", \"Anything-V3.0-pruned-fp32\", \"Anything-V3.0-pruned\", \"Stable-Diffusion-v1-4\", \"Stable-Diffusion-v1-5-pruned-emaonly\", \"Waifu-Diffusion-v1-3-fp32\"]\n",
+ "modelName = \"Animefull-final-pruned\" #@param [\"\", \"Animefull-final-pruned\", \"Animesfw-final-pruned\", \"Anything-V3.0-pruned-fp16\", \"Anything-V3.0-pruned-fp32\", \"Anything-V3.0-pruned\", \"Stable-Diffusion-v1-4\", \"Stable-Diffusion-v1-5-pruned-emaonly\", \"Waifu-Diffusion-v1-3-fp32\"]\n",
"\n",
"#@markdown ### Custom model\n",
"#@markdown The model URL should be a direct download link.\n",
"customName = \"\" #@param {'type': 'string'}\n",
"customUrl = \"\"#@param {'type': 'string'}\n",
"\n",
- "if customName == \"\" or customUrl == \"\":\n",
- " pass\n",
- "else:\n",
+ "# Check if user has specified a custom model\n",
+ "if customName != \"\" and customUrl != \"\":\n",
+ " # Add custom model to list of models to install\n",
" installModels.append((customName, customUrl))\n",
"\n",
+ "# Check if user has selected a model\n",
"if modelName != \"\":\n",
- " # Map model to URL\n",
+ " # Map selected model to URL\n",
" installModels.append((modelName, modelUrl[modelList.index(modelName)]))\n",
"\n",
"def install_aria():\n",
+ " # Install aria2 if it is not already installed\n",
" if not os.path.exists('/usr/bin/aria2c'):\n",
" !apt install -y -qq aria2\n",
"\n",
"def install(checkpoint_name, url):\n",
" if url.startswith(\"https://drive.google.com\"):\n",
+ " # Use gdown to download file from Google Drive\n",
" !gdown --fuzzy -O \"/content/kohya-trainer/checkpoint/{checkpoint_name}.ckpt\" \"{url}\"\n",
" elif url.startswith(\"magnet:?\"):\n",
" install_aria()\n",
+ " # Use aria2c to download file from magnet link\n",
" !aria2c --summary-interval=10 -c -x 10 -k 1M -s 10 -o /content/kohya-trainer/checkpoint/{checkpoint_name}.ckpt \"{url}\"\n",
" else:\n",
" user_token = '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'\n",
" user_header = f\"\\\"Authorization: Bearer {user_token}\\\"\"\n",
+ " # Use wget to download file from URL\n",
" !wget -c --header={user_header} \"{url}\" -O /content/kohya-trainer/checkpoint/{checkpoint_name}.ckpt\n",
"\n",
"def install_checkpoint():\n",
+ " # Iterate through list of models to install\n",
" for model in installModels:\n",
+ " # Call install function for each model\n",
" install(model[0], model[1])\n",
- "install_checkpoint()\n",
- "\n"
+ "\n",
+ "# Call install_checkpoint function to download all models in the list\n",
+ "install_checkpoint()\n"
],
"metadata": {
"id": "SoucgZQ6jgPQ",
@@ -588,9 +552,12 @@
"\n",
"diffuser_0_7_2 = True #@param {'type':'boolean'}\n",
"\n",
- "if diffuser_0_7_2 == True :\n",
+ "# Check if user wants to downgrade diffusers\n",
+ "if diffuser_0_7_2:\n",
+ " # Install diffusers 0.7.2\n",
" !pip install diffusers[torch]==0.7.2\n",
"else:\n",
+ " # Install latest version of diffusers\n",
" !pip install diffusers[torch]==0.9.0"
],
"metadata": {
@@ -617,8 +584,8 @@
"save_precision = \"fp16\" #@param [\"float\", \"fp16\", \"bf16\"] {allow-input: false}\n",
"save_every_n_epochs = 10 #@param {'type':'integer'}\n",
"\n",
- "%cd /content/kohya-trainer\n",
- "!accelerate launch --num_cpu_threads_per_process {num_cpu_threads_per_process} /content/kohya-trainer/train_db_fixed/train_db_fixed_v13.py \\\n",
+ "%cd /content/kohya-trainer/train_db_fixed\n",
+ "!accelerate launch --num_cpu_threads_per_process {num_cpu_threads_per_process} train_db_fixed.py \\\n",
" --pretrained_model_name_or_path={pre_trained_model_path} \\\n",
" --train_data_dir={train_data_dir} \\\n",
" --reg_data_dir={reg_data_dir} \\\n",
@@ -650,34 +617,9 @@
{
"cell_type": "code",
"source": [
- "#@title Convert trained model to ckpt\n",
- "\n",
- "#@markdown This cell will convert output weight to checkpoint file so it can be used in Web UI like Auto1111's\n",
- "WEIGHTS_DIR = \"/content/dreambooth/last\" #@param {'type':'string'}\n",
- "#@markdown Run conversion.\n",
- "ckpt_path = WEIGHTS_DIR + \"/model.ckpt\"\n",
- "\n",
- "half_arg = \"\"\n",
- "#@markdown Whether to convert to fp16, takes half the space (2GB).\n",
- "fp16 = True #@param {type: \"boolean\"}\n",
- "if fp16:\n",
- " half_arg = \"--half\"\n",
- "!python convert_diffusers_to_original_stable_diffusion.py --model_path $WEIGHTS_DIR --checkpoint_path $ckpt_path $half_arg\n",
+ "#@title Model Pruner (Optional)\n",
"\n",
- "print(f\"[*] Converted ckpt saved at {ckpt_path}\")"
- ],
- "metadata": {
- "cellView": "form",
- "id": "rM5o1gUu97yc"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "code",
- "source": [
- "#@markdown ```\n",
- "#@markdown @lopho\n",
+ "#@markdown ```python\n",
"#@markdown usage: prune.py [-h] [-p] [-e] [-c] [-a] input output\n",
"#@markdown \n",
"#@markdown Prune a stable diffusion checkpoint\n",
@@ -693,24 +635,32 @@
"#@markdown -c, --no-clip strip CLIP weights\n",
"#@markdown -a, --no-vae strip VAE weights\n",
"#@markdown ```\n",
- "#@title .ckpt Model Pruner\n",
+ "\n",
"#@markdown Do you want to Prune a model?\n",
"%cd /content/ \n",
"\n",
- "prune = True #@param {'type':'boolean'}\n",
+ "# Use a more descriptive variable name\n",
+ "should_prune = False #@param {'type':'boolean'}\n",
"\n",
- "model_src = \"/content/kohya-trainer/fine_tuned/last.ckpt\" #@param {'type' : 'string'}\n",
- "model_dst = \"/content/kohya-trainer/fine_tuned/last-pruned.ckpt\" #@param {'type' : 'string'}\n",
+ "# Use a more descriptive variable name\n",
+ "source_model_path = \"/content/dreambooth/last.ckpt\" #@param {'type' : 'string'}\n",
"\n",
- "if prune == True:\n",
+ "# Use a more descriptive variable name\n",
+ "pruned_model_path = \"/content/dreambooth/last-pruned.ckpt\" #@param {'type' : 'string'}\n",
+ "\n",
+ "if should_prune:\n",
" import os\n",
" if os.path.isfile('/content/prune.py'):\n",
" pass\n",
" else:\n",
+ " # Add a comment to explain what the code is doing\n",
+ " # Download the pruning script if it doesn't already exist\n",
" !wget https://raw.githubusercontent.com/lopho/stable-diffusion-prune/main/prune.py\n",
"\n",
"\n",
- "!python3 prune.py -p {model_src} {model_dst}\n"
+ "# Add a comment to explain what the code is doing\n",
+ "# Run the pruning script\n",
+ "!python3 prune.py {source_model_path} {pruned_model_path}"
],
"metadata": {
"cellView": "form",
@@ -722,201 +672,144 @@
{
"cell_type": "code",
"source": [
- "#@title Inference\n",
- "import torch\n",
- "from torch import autocast\n",
- "from diffusers import StableDiffusionPipeline, DDIMScheduler\n",
- "from IPython.display import display\n",
- "\n",
- "model_path = WEIGHTS_DIR # If you want to use previously trained model saved in gdrive, replace this with the full path of model in gdrive\n",
- "\n",
- "scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule=\"scaled_linear\", clip_sample=False, set_alpha_to_one=False)\n",
- "pipe = StableDiffusionPipeline.from_pretrained(model_path, scheduler=scheduler, safety_checker=None, torch_dtype=torch.float16).to(\"cuda\")\n",
- "\n",
- "g_cuda = None\n",
- "\n",
- "#@markdown Can set random seed here for reproducibility.\n",
- "g_cuda = torch.Generator(device='cuda')\n",
- "seed = 123123 #@param {type:\"number\"}\n",
- "g_cuda.manual_seed(seed)\n",
- "\n",
- "#@title Run for generating images.\n",
- "\n",
- "prompt = \"hatsune miku\" #@param {type:\"string\"}\n",
- "negative_prompt = \"\" #@param {type:\"string\"}\n",
- "num_samples = 1 #@param {type:\"number\"}\n",
- "guidance_scale = 7.5 #@param {type:\"number\"}\n",
- "num_inference_steps = 50 #@param {type:\"number\"}\n",
- "height = 512 #@param {type:\"number\"}\n",
- "width = 512 #@param {type:\"number\"}\n",
- "\n",
- "with autocast(\"cuda\"), torch.inference_mode():\n",
- " images = pipe(\n",
- " prompt,\n",
- " height=height,\n",
- " width=width,\n",
- " negative_prompt=negative_prompt,\n",
- " num_images_per_prompt=num_samples,\n",
- " num_inference_steps=num_inference_steps,\n",
- " guidance_scale=guidance_scale,\n",
- " generator=g_cuda\n",
- " ).images\n",
- "\n",
- "for img in images:\n",
- " display(img)"
- ],
- "metadata": {
- "cellView": "form",
- "id": "cKn2ARpLAI0J"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "source": [
- "#Miscellaneous"
- ],
- "metadata": {
- "id": "vqfgyL-thgdw"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "#@title Mount to Google Drive\n",
- "mount_drive= False #@param {'type':'boolean'}\n",
+ "#@title Convert diffuser model to ckpt (Optional)\n",
"\n",
- "if mount_drive== True:\n",
- " from google.colab import drive\n",
- " drive.mount('/content/drive')"
- ],
- "metadata": {
- "id": "OuRqOSp2eU6t",
- "cellView": "form"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "source": [
- "#Huggingface_hub Integration"
- ],
- "metadata": {
- "id": "QtVP2le8PL2T"
- }
- },
- {
- "cell_type": "markdown",
- "source": [
- "##Instruction:\n",
- "0. Of course you need a Huggingface Account first\n",
- "1. Create huggingface token, go to `Profile > Access Tokens > New Token > Create a new access token` with the `Write` role.\n",
- "2. All cells below are checked `opt-out` by default so you need to uncheck it if you want to running the cells."
- ],
- "metadata": {
- "id": "tbKgmh_AO5NG"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "#@title Login to Huggingface hub\n",
- "#@markdown Opt-out this cell when run all\n",
- "opt_out= True #@param {'type':'boolean'}\n",
+ "#@markdown If you're using diffuser weight, this cell will convert output weight to checkpoint file so it can be used in Web UI like Auto1111's\n",
"\n",
- "#@markdown Prepare your Huggingface token\n",
+ "# Use a more descriptive variable name\n",
+ "diffuser_weights_dir = \"/content/drive/MyDrive/fine_tuned/last\" #@param {'type':'string'}\n",
"\n",
- "saved_token= \"save-your-write-token-here\" #@param {'type': 'string'}\n",
+ "# Use a more descriptive variable name\n",
+ "use_fp16 = False #@param {type: \"boolean\"}\n",
"\n",
- "if opt_out == False:\n",
- " !pip install huggingface_hub\n",
+ "# Add a comment to explain what the code is doing\n",
+ "# Convert the diffuser weights to a checkpoint file\n",
+ "ckpt_path = diffuser_weights_dir + \"/model.ckpt\"\n",
"\n",
- " from huggingface_hub import notebook_login\n",
- " notebook_login()\n",
- "\n"
+ "# Use a more descriptive variable name\n",
+ "half_precision_arg = \"\"\n",
+ "if use_fp16:\n",
+ " # Use a more descriptive variable name\n",
+ " half_precision_arg = \"--half\"\n",
+ "\n",
+ "# Add a comment to explain what the code is doing\n",
+ "# Run the conversion script\n",
+ "!python convert_diffusers_to_original_stable_diffusion.py --model_path $diffuser_weights_dir --checkpoint_path $ckpt_path $half_precision_arg\n",
+ "\n",
+ "# Use string formatting and a more descriptive variable name\n",
+ "print(f\"[*] Converted checkpoint saved at {ckpt_path}\")"
],
"metadata": {
- "id": "Da7awoqAPJ3a",
- "cellView": "form"
+ "cellView": "form",
+ "id": "rM5o1gUu97yc"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
- "source": [
- "##Commit trained model to Huggingface"
- ],
"metadata": {
"id": "jypUkLWc48R_"
- }
+ },
+ "source": [
+ "## Commit trained model to Huggingface"
+ ]
},
{
"cell_type": "markdown",
- "source": [
- "###Instruction:\n",
- "0. Create huggingface repository for model\n",
- "1. Clone your model to this colab session\n",
- "2. Move these necessary file to your repository to save your trained model to huggingface\n",
- "\n",
- ">in `content/kohya-trainer/fine-tuned`\n",
- "- File `epoch-nnnnn.ckpt` and/or\n",
- "- File `last.ckpt`, \n",
- "\n",
- "4. Commit your model to huggingface"
- ],
"metadata": {
"id": "TvZgRSmKVSRw"
- }
+ },
+ "source": [
+ "### To Commit models:\n",
+ "1. Create a huggingface repository for your model.\n",
+ "2. Clone your model to this Colab session.\n",
+ "3. Move the necessary files to your repository to save your trained model to huggingface. These files are located in `fine-tuned` folder:\n",
+ " - `epoch-nnnnn.ckpt` and/or\n",
+ " - `last.ckpt`\n",
+ "4. Commit your model to huggingface.\n",
+ "\n",
+ "### To Commit datasets:\n",
+ "1. Create a huggingface repository for your datasets.\n",
+ "2. Clone your datasets to this Colab session.\n",
+ "3. Move the necessary files to your repository so that you can resume training without rebuilding your dataset with this notebook.\n",
+ " - The `train_folder` folder.\n",
+ "4. Commit your datasets to huggingface.\n",
+ "\n"
+ ]
},
{
"cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "cellView": "form",
+ "id": "182Law9oUiYN"
+ },
+ "outputs": [],
"source": [
- "#@title Clone Model\n",
+ "#@title Clone Model or Datasets\n",
"\n",
"#@markdown Opt-out this cell when run all\n",
- "opt_out= True #@param {'type':'boolean'}\n",
+ "opt_out = True #@param {'type':'boolean'}\n",
+ "\n",
+ "#@markdown Type of item to clone (model or dataset)\n",
+ "type_of_item = \"model\" #@param [\"model\", \"dataset\"]\n",
+ "\n",
+ "#@markdown Install or uninstall git lfs\n",
+ "install_git_lfs = False #@param {'type':'boolean'}\n",
"\n",
"if opt_out == False:\n",
" %cd /content\n",
- " Repository_url = \"https://huggingface.co/Linaqruf/alphanime-diffusion\" #@param {'type': 'string'}\n",
+ " username = \"your-huggingface-username\" #@param {'type': 'string'}\n",
+ " model_repo = \"your-huggingface-model-repo\" #@param {'type': 'string'}\n",
+ " datasets_repo = \"your-huggingface-datasets-repo\" #@param {'type': 'string'}\n",
+ " \n",
+ " if type_of_item == \"model\":\n",
+ " Repository_url = f\"https://huggingface.co/{username}/{model_repo}\"\n",
+ " elif type_of_item == \"dataset\":\n",
+ " Repository_url = f\"https://huggingface.co/datasets/{username}/{datasets_repo}\"\n",
+ "\n",
+ " if install_git_lfs:\n",
+ " !git lfs install\n",
+ " else:\n",
+ " !git lfs uninstall\n",
+ "\n",
" !git clone {Repository_url}\n",
"else:\n",
" pass\n"
- ],
- "metadata": {
- "id": "182Law9oUiYN",
- "cellView": "form"
- },
- "execution_count": null,
- "outputs": []
+ ]
},
{
"cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "cellView": "form",
+ "id": "87wG7QIZbtZE"
+ },
+ "outputs": [],
"source": [
- "#@title Commit to Huggingface\n",
+ "#@title Commit Model or Datasets to Huggingface\n",
+ "\n",
"#@markdown Opt-out this cell when run all\n",
- "opt_out= True #@param {'type':'boolean'}\n",
+ "opt_out = True #@param {'type':'boolean'}\n",
+ "\n",
+ "#@markdown Type of item to commit (model or dataset)\n",
+ "type_of_item = \"model\" #@param [\"model\", \"dataset\"]\n",
"\n",
"if opt_out == False:\n",
" %cd /content\n",
- " #@markdown Go to your model path\n",
- " model_path= \"alphanime-diffusion\" #@param {'type': 'string'}\n",
+ " #@markdown Go to your model or dataset path\n",
+ " item_path = \"your-cloned-model-or-datasets-repo\" #@param {'type': 'string'}\n",
"\n",
- " #@markdown Your path look like /content/**model_path**\n",
- " #@markdown ___\n",
" #@markdown #Git Commit\n",
"\n",
" #@markdown Set **git commit identity**\n",
- "\n",
- " email= \"your-email\" #@param {'type': 'string'}\n",
- " name= \"your-name\" #@param {'type': 'string'}\n",
+ " email = \"your-email\" #@param {'type': 'string'}\n",
+ " name = \"your-username\" #@param {'type': 'string'}\n",
" #@markdown Set **commit message**\n",
- " commit_m= \"this is commit message\" #@param {'type': 'string'}\n",
+ " commit_m = \"feat: upload 6 epochs model\" #@param {'type': 'string'}\n",
"\n",
- " %cd \"/content/{model_path}\"\n",
+ " %cd {item_path}\n",
" !git lfs install\n",
" !huggingface-cli lfs-enable-largefiles .\n",
" !git add .\n",
@@ -928,13 +821,7 @@
"\n",
"else:\n",
" pass"
- ],
- "metadata": {
- "id": "87wG7QIZbtZE",
- "cellView": "form"
- },
- "execution_count": null,
- "outputs": []
+ ]
}
]
}
\ No newline at end of file
diff --git a/kohya-trainer-deepdanbooru.ipynb b/kohya-trainer-deepdanbooru.ipynb
index 9c4eb6e5..f2a0627f 100644
--- a/kohya-trainer-deepdanbooru.ipynb
+++ b/kohya-trainer-deepdanbooru.ipynb
@@ -1,21 +1,4 @@
{
- "nbformat": 4,
- "nbformat_minor": 0,
- "metadata": {
- "colab": {
- "provenance": [],
- "include_colab_link": true
- },
- "kernelspec": {
- "name": "python3",
- "display_name": "Python 3"
- },
- "language_info": {
- "name": "python"
- },
- "gpuClass": "standard",
- "accelerator": "GPU"
- },
"cells": [
{
"cell_type": "markdown",
@@ -24,159 +7,196 @@
"colab_type": "text"
},
"source": [
- ""
+ ""
]
},
{
"cell_type": "markdown",
- "source": [
- "#Kohya Trainer V4 - VRAM 12GB\n",
- "##But with DeepDanbooru instead of Waifu Diffusion 1.4 Tagger"
- ],
"metadata": {
"id": "slgjeYgd6pWp"
- }
+ },
+ "source": [
+ "# Kohya Trainer V6 - VRAM 12GB - DeepDanbooru Ver.\n",
+ "### The Best Way for People Without Good GPUs to Fine-Tune the Stable Diffusion Model"
+ ]
},
{
"cell_type": "markdown",
- "source": [
- "Adapted to Google Colab based on [Kohya Guide](https://note.com/kohya_ss/n/nbf7ce8d80f29#c9d7ee61-5779-4436-b4e6-9053741c46bb)
\n",
- "Adapted to Google Colab by [Linaqruf](https://github.com/Linaqruf)
\n",
- "You can find latest notebook update [here](https://github.com/Linaqruf/kohya-trainer/blob/main/kohya-trainer.ipynb)\n",
- "\n",
- "\n"
- ],
"metadata": {
"id": "gPgBR3KM6E-Z"
- }
+ },
+ "source": [
+ "This notebook has been adapted for use in Google Colab based on the [Kohya Guide](https://note.com/kohya_ss/n/nbf7ce8d80f29#c9d7ee61-5779-4436-b4e6-9053741c46bb). \n",
+ "This notebook was adapted by [Linaqruf](https://github.com/Linaqruf)\n",
+ "You can find the latest update to the notebook [here](https://github.com/Linaqruf/kohya-trainer/blob/main/kohya-trainer.ipynb).\n"
+ ]
},
{
"cell_type": "markdown",
- "source": [
- "#Install Kohya Trainer"
- ],
"metadata": {
"id": "tTVqCAgSmie4"
- }
+ },
+ "source": [
+ "# Install Kohya Trainer"
+ ]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
- "id": "_u3q60di584x",
- "cellView": "form"
+ "cellView": "form",
+ "id": "_u3q60di584x"
},
"outputs": [],
"source": [
"#@title Clone Kohya Trainer\n",
- "#@markdown Run this cell everytime you want to `!git pull` to get a lot of new optimizations and updates.\n",
+ "#@markdown Clone the Kohya Trainer repository from GitHub and check for updates\n",
+ "\n",
"%cd /content/\n",
"\n",
"import os\n",
"\n",
- "if os.path.isdir('/content/kohya-trainer'):\n",
- " %cd /content/kohya-trainer\n",
- " print(\"This folder already exists, will do a !git pull instead\\n\")\n",
- " !git pull\n",
- " \n",
- "else:\n",
- " !git clone https://github.com/Linaqruf/kohya-trainer"
+ "def clone_kohya_trainer():\n",
+ " # Check if the directory already exists\n",
+ " if os.path.isdir('/content/kohya-trainer'):\n",
+ " %cd /content/kohya-trainer\n",
+ " print(\"This folder already exists, will do a !git pull instead\\n\")\n",
+ " !git pull\n",
+ " else:\n",
+ " !git clone https://github.com/Linaqruf/kohya-trainer\n",
+ " \n",
+ "\n",
+ "# Clone or update the Kohya Trainer repository\n",
+ "clone_kohya_trainer()"
]
},
{
"cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "cellView": "form",
+ "id": "nj8fNQZNESyT"
+ },
+ "outputs": [],
"source": [
"#@title Install Diffuser Fine Tuning\n",
+ "\n",
+ "# Change the current working directory to \"/content/kohya-trainer\".\n",
"%cd /content/kohya-trainer\n",
"\n",
+ "# Import `shutil` and `os` modules.\n",
"import shutil\n",
"import os\n",
"\n",
- "customVersion = []\n",
- "versionDir = [\"/content/kohya-trainer/diffuser_fine_tuning/diffusers_fine_tuning_v1.zip\", \\\n",
- " \"/content/kohya-trainer/diffuser_fine_tuning/diffusers_fine_tuning_v2.zip\", \\\n",
- " \"/content/kohya-trainer/diffuser_fine_tuning/diffusers_fine_tuning_v3.zip\", \\\n",
- " \"/content/kohya-trainer/diffuser_fine_tuning/diffusers_fine_tuning_v4.zip\"]\n",
- "versionList = [\"diffusers_fine_tuning_v1\", \\\n",
- " \"diffusers_fine_tuning_v2\", \\\n",
+ "# Initialize an empty list `custom_versions`.\n",
+ "custom_versions = []\n",
+ "\n",
+ "# Initialize a list `version_urls` containing URLs of different versions of the `diffusers_fine_tuning` file.\n",
+ "version_urls = [\"\",\\\n",
+ " \"https://github.com/Linaqruf/kohya-trainer/releases/download/v6/diffusers_fine_tuning_v6.zip\", \\\n",
+ " \"https://github.com/Linaqruf/kohya-trainer/releases/download/v5/diffusers_fine_tuning_v5.zip\", \\\n",
+ " \"https://github.com/Linaqruf/kohya-trainer/releases/download/v4/diffusers_fine_tuning_v4.zip\", \\\n",
+ " \"https://github.com/Linaqruf/kohya-trainer/releases/download/v3/diffusers_fine_tuning_v3.zip\", \\\n",
+ " \"https://github.com/Linaqruf/kohya-trainer/releases/download/v2/diffusers_fine_tuning_v2.zip\", \\\n",
+ " \"https://github.com/Linaqruf/kohya-trainer/releases/download/v1/diffusers_fine_tuning_v1.zip\"]\n",
+ "\n",
+ "# Initialize a list `version_names` containing names of different versions of the `diffusers_fine_tuning` file.\n",
+ "version_names = [\"latest_version\", \\\n",
+ " \"diffusers_fine_tuning_v6\", \\\n",
+ " \"diffusers_fine_tuning_v5\", \\\n",
+ " \"diffusers_fine_tuning_v4\", \\\n",
" \"diffusers_fine_tuning_v3\", \\\n",
- " \"diffusers_fine_tuning_v4\"]\n",
- "version = \"diffusers_fine_tuning_v3\" #@param [\"diffusers_fine_tuning_v1\",\"diffusers_fine_tuning_v2\",\"diffusers_fine_tuning_v3\",\"diffusers_fine_tuning_v4\"]\n",
- "\n",
- "customVersion.append((versionDir[versionList.index(version)]))\n",
- "\n",
- "for zip in customVersion:\n",
- " (zip[0])\n",
- "\n",
- "zip = \"\".join(zip)\n",
- "\n",
- "def unzip_function(dir):\n",
- " !unzip {dir} -d /content/kohya-trainer/\n",
- "\n",
- "def unzip_version():\n",
- " unzip_function(zip)\n",
+ " \"diffusers_fine_tuning_v2\", \\\n",
+ " \"diffusers_fine_tuning_v1\"]\n",
+ "\n",
+ "# Initialize a variable `selected_version` to the selected version of the `diffusers_fine_tuning` file.\n",
+ "selected_version = \"latest_version\" #@param [\"latest_version\", \"diffusers_fine_tuning_v6\", \"diffusers_fine_tuning_v5\", \"diffusers_fine_tuning_v4\", \"diffusers_fine_tuning_v3\", \"diffusers_fine_tuning_v2\", \"diffusers_fine_tuning_v1\"]\n",
+ "\n",
+ "# Append a tuple to `custom_versions`, containing `selected_version` and the corresponding item\n",
+ "# in `version_urls`.\n",
+ "custom_versions.append((selected_version, version_urls[version_names.index(selected_version)]))\n",
+ "\n",
+ "# Define `download` function to download a file from the given URL and save it with\n",
+ "# the given name.\n",
+ "def download(name, url):\n",
+ " !wget -c \"{url}\" -O /content/{name}.zip\n",
+ "\n",
+ "# Define `unzip` function to unzip a file with the given name to a specified\n",
+ "# directory.\n",
+ "def unzip(name):\n",
+ " !unzip /content/{name}.zip -d /content/kohya-trainer/diffuser_fine_tuning\n",
+ "\n",
+ "# Define `download_version` function to download and unzip a file from `custom_versions`,\n",
+ "# unless `selected_version` is \"latest_version\".\n",
+ "def download_version():\n",
+ " if selected_version != \"latest_version\":\n",
+ " for zip in custom_versions:\n",
+ " download(zip[0], zip[1])\n",
+ "\n",
+ " # Rename the existing `diffuser_fine_tuning` directory to the `tmp` directory and delete any existing `tmp` directory.\n",
+ " if os.path.exists(\"/content/kohya-trainer/tmp\"):\n",
+ " shutil.rmtree(\"/content/kohya-trainer/tmp\")\n",
+ " os.rename(\"/content/kohya-trainer/diffuser_fine_tuning\", \"/content/kohya-trainer/tmp\")\n",
+ "\n",
+ " # Create a new empty `diffuser_fine_tuning` directory.\n",
+ " os.makedirs(\"/content/kohya-trainer/diffuser_fine_tuning\")\n",
+ " \n",
+ " # Unzip the downloaded file to the new `diffuser_fine_tuning` directory.\n",
+ " unzip(zip[0])\n",
+ " \n",
+ " # Delete the downloaded and unzipped file.\n",
+ " os.remove(\"/content/{}.zip\".format(zip[0]))\n",
+ " \n",
+ " # Inform the user that the existing `diffuser_fine_tuning` directory has been renamed to the `tmp` directory\n",
+ " # and a new empty `diffuser_fine_tuning` directory has been created.\n",
+ " print(\"Renamed existing 'diffuser_fine_tuning' directory to 'tmp' directory and created new empty 'diffuser_fine_tuning' directory.\")\n",
+ " else:\n",
+ " # Do nothing if `selected_version` is \"latest_version\".\n",
+ " pass\n",
"\n",
- "unzip_version()\n",
- "# if version == \"diffusers_fine_tuning_v1\":\n",
- "# !unzip /content/kohya-trainer/diffuser_fine_tuning/diffusers_fine_tuning_v1.zip -d /content/kohya-trainer/"
- ],
- "metadata": {
- "cellView": "form",
- "id": "nj8fNQZNESyT"
- },
- "execution_count": null,
- "outputs": []
+ "# Call `download_version` function.\n",
+ "download_version()"
+ ]
},
{
"cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "cellView": "form",
+ "id": "WNn0g1pnHfk5"
+ },
+ "outputs": [],
"source": [
"#@title Installing Dependencies\n",
"%cd /content/kohya-trainer\n",
"\n",
- "Install_Python_3_9_6 = False #@param{'type':'boolean'}\n",
- "\n",
- "if Install_Python_3_9_6 == True:\n",
- " #install python 3.9\n",
- " !sudo apt-get update -y\n",
- " !sudo apt-get install python3.9\n",
+ "def install_dependencies():\n",
+ " #@markdown Install required Python packages\n",
+ " !pip install --upgrade -r script/requirements.txt\n",
+ " !pip install -U gallery-dl\n",
+ " !pip install tensorflow\n",
+ " !pip install huggingface_hub\n",
"\n",
- " #change alternatives\n",
- " !sudo update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.8 1\n",
- " !sudo update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.9 2\n",
+ " # Install xformers\n",
+ " !pip install -U -I --no-deps https://github.com/camenduru/stable-diffusion-webui-colab/releases/download/0.0.15/xformers-0.0.15.dev0+189828c.d20221207-cp38-cp38-linux_x86_64.whl\n",
"\n",
- " #check python version\n",
- " !python --version\n",
- " #3.9.6\n",
- " !sudo apt-get install python3.9-distutils && wget https://bootstrap.pypa.io/get-pip.py && python get-pip.py\n",
"\n",
- "!wget -q https://github.com/ShivamShrirao/diffusers/raw/main/scripts/convert_diffusers_to_original_stable_diffusion.py\n",
- "\n",
- "if os.path.isfile('/content/kohya-trainer/convert_diffusers_to_original_stable_diffusion.py'):\n",
- " pass\n",
- "else:\n",
+ "# Install convert_diffusers_to_original_stable_diffusion.py script\n",
+ "if not os.path.isfile('/content/kohya-trainer/convert_diffusers_to_original_stable_diffusion.py'):\n",
" !wget -q https://github.com/ShivamShrirao/diffusers/raw/main/scripts/convert_diffusers_to_original_stable_diffusion.py\n",
"\n",
- "!pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113\n",
- "!pip install --upgrade -r script/requirements.txt\n",
- "!pip install -U gallery-dl\n",
- "!pip install tensorflow\n",
- "!pip install accelerate==0.14.0\n",
- "\n",
- "#install xformers\n",
- "if Install_Python_3_9_6 == True:\n",
- " !pip install -U -I --no-deps https://github.com/daswer123/stable-diffusion-colab/raw/main/xformers%20prebuild/T4/python39/xformers-0.0.14.dev0-cp39-cp39-linux_x86_64.whl\n",
- "else:\n",
- " !pip install -U -I --no-deps https://github.com/TheLastBen/fast-stable-diffusion/raw/main/precompiled/T4/xformers-0.0.13.dev0-py3-none-any.whl\n"
- ],
- "metadata": {
- "id": "WNn0g1pnHfk5",
- "cellView": "form"
- },
- "execution_count": null,
- "outputs": []
+ "# Install dependencies\n",
+ "install_dependencies()"
+ ]
},
{
"cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "cellView": "form",
+ "id": "VZOXwDv3utpx"
+ },
+ "outputs": [],
"source": [
"#@title Set config for `!Accelerate`\n",
"#@markdown #Hint\n",
@@ -190,132 +210,208 @@
"%cd /content/kohya-trainer\n",
"\n",
"!accelerate config"
- ],
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "M0fzmhtywk_u"
+ },
+ "source": [
+ "# Prepare Cloud Storage (Huggingface/GDrive)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
"metadata": {
"cellView": "form",
- "id": "VZOXwDv3utpx"
+ "id": "cwIJdhEcwk_u"
},
+ "outputs": [],
+ "source": [
+ "#@title Login to Huggingface hub\n",
+ "\n",
+ "#@markdown ## Instructions:\n",
+ "#@markdown 1. Of course, you need a Huggingface account first.\n",
+ "#@markdown 2. To create a huggingface token, go to `Profile > Access Tokens > New Token > Create a new access token` with the `Write` role.\n",
+ "#@markdown 3. By default, all cells below are marked as `opt-out`, so you need to uncheck them if you want to run the cells.\n",
+ "\n",
+ "%cd /content/kohya-trainer\n",
+ "\n",
+ "from huggingface_hub import login\n",
+ "login()\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "code",
"execution_count": null,
- "outputs": []
+ "metadata": {
+ "cellView": "form",
+ "id": "jVgHUUK_wk_v"
+ },
+ "outputs": [],
+ "source": [
+ "#@title Mount Google Drive\n",
+ "\n",
+ "from google.colab import drive\n",
+ "\n",
+ "mount_drive = True #@param {'type':'boolean'}\n",
+ "\n",
+ "if mount_drive:\n",
+ " drive.mount('/content/drive')"
+ ]
},
{
"cell_type": "markdown",
+ "metadata": {
+ "id": "En9UUwGNMRMM"
+ },
"source": [
- "#Collecting datasets\n",
- "You can either upload your datasets to this notebook or use image scraper below to bulk download images from danbooru.\n",
+ "# Collecting datasets\n",
"\n",
- "If you want to use your own datasets, make sure to put them in a folder titled `train_data` in `content/kohya-trainer`\n",
+ "You can either upload your datasets to this notebook or use the image scraper below to bulk download images from Danbooru.\n",
"\n",
- "This is to make the training process easier because the folder that will be used for training is in `content/kohya-trainer/train-data`."
+ "If you want to use your own datasets, you can upload to colab `local files`.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "#@title Define Train Data\n",
+ "#@markdown Define where your train data will be located. This cell will also create a folder based on your input. \n",
+ "#@markdown This folder will be used as the target folder for scraping, tagging, bucketing, and training in the next cell.\n",
+ "\n",
+ "import os\n",
+ "\n",
+ "train_data_dir = \"/content/kohya-trainer/train_data\" #@param {'type' : 'string'}\n",
+ "\n",
+ "if not os.path.exists(train_data_dir):\n",
+ " os.makedirs(train_data_dir)\n",
+ "else:\n",
+ " print(f\"{train_data_dir} already exists\\n\")\n",
+ "\n",
+ "print(f\"Your train data directory : {train_data_dir}\")\n"
],
"metadata": {
- "id": "En9UUwGNMRMM"
- }
+ "cellView": "form",
+ "id": "nXNk0NOwzWw4"
+ },
+ "execution_count": null,
+ "outputs": []
},
{
"cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "cellView": "form",
+ "id": "Kt1GzntK_apb"
+ },
+ "outputs": [],
"source": [
"#@title Booru Scraper\n",
- "#@markdown **How this work?**\n",
+ "#@markdown Use gallery-dl to scrape images from a booru site using the specified tags\n",
"\n",
- "#@markdown By using **gallery-dl** we can scrap or bulk download images on Internet, on this notebook we will scrap images from popular booru sites using tag1 and tag2 as target scraping.\n",
"%cd /content\n",
"\n",
+ "# Set configuration options\n",
"booru = \"Danbooru\" #@param [\"\", \"Danbooru\", \"Gelbooru\"]\n",
- "tag = \"herigaru_(fvgyvr000)\" #@param {type: \"string\"}\n",
+ "tag1 = \"hito_komoru\" #@param {type: \"string\"}\n",
"tag2 = \"\" #@param {type: \"string\"}\n",
"\n",
+ "# Construct the search query\n",
"if tag2 != \"\":\n",
- " tag = tag + \"+\" + tag2\n",
+ " tags = tag1 + \"+\" + tag2\n",
"else:\n",
- " tag = tag\n",
- "\n",
- "output_dir = \"/content/kohya-trainer/train_data\"\n",
+ " tags = tag1\n",
"\n",
- "if booru == \"Danbooru\":\n",
- " !gallery-dl \"https://danbooru.donmai.us/posts?tags={tag}\" -D {output_dir}\n",
- "elif booru == \"Gelbooru\":\n",
- " !gallery-dl \"https://gelbooru.com/index.php?page=post&s=list&tags={tag}\" -D {output_dir}\n",
+ "# Scrape images from the specified booru site using the given tags\n",
+ "if booru.lower() == \"danbooru\":\n",
+ " !gallery-dl \"https://danbooru.donmai.us/posts?tags={tags}\" -D {train_data_dir}\n",
+ "elif booru.lower() == \"gelbooru\":\n",
+ " !gallery-dl \"https://gelbooru.com/index.php?page=post&s=list&tags={tags}\" -D {train_data_dir}\n",
"else:\n",
- " pass\n",
- "\n",
- "\n",
- "#@markdown The output directory will be on /content/kohya-trainer/train_data. We also will use this folder as target folder for training next step.\n"
- ],
- "metadata": {
- "id": "Kt1GzntK_apb",
- "cellView": "form"
- },
- "execution_count": null,
- "outputs": []
+ " print(f\"Unknown booru site: {booru}\")\n"
+ ]
},
{
"cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "cellView": "form",
+ "id": "Jz2emq6vWnPu"
+ },
+ "outputs": [],
"source": [
"#@title Datasets cleaner\n",
- "#@markdown This will delete unnecessary file and unsupported media like `.bin`, `.mp4`, `.webm`, and `.gif`\n",
- "import os\n",
+ "#@markdown This will delete unnecessary files and unsupported media like `.mp4`, `.webm`, and `.gif`\n",
"\n",
- "dir_name = \"/content/kohya-trainer/train_data\" #@param {'type' : 'string'}\n",
- "test = os.listdir(dir_name)\n",
+ "%cd /content\n",
"\n",
- "for item in test:\n",
- " if item.endswith(\".mp4\"):\n",
- " os.remove(os.path.join(dir_name, item))\n",
+ "import os\n",
+ "test = os.listdir(train_data_dir)\n",
"\n",
- "for item in test:\n",
- " if item.endswith(\".webm\"):\n",
- " os.remove(os.path.join(dir_name, item))\n",
+ "# List of supported file types\n",
+ "supported_types = [\".jpg\", \".jpeg\", \".png\"]\n",
"\n",
+ "# Iterate over all files in the directory\n",
"for item in test:\n",
- " if item.endswith(\".gif\"):\n",
- " os.remove(os.path.join(dir_name, item))\n",
- " \n",
- "for item in test:\n",
- " if item.endswith(\".webp\"):\n",
- " os.remove(os.path.join(dir_name, item))\n",
- "\n"
- ],
- "metadata": {
- "cellView": "form",
- "id": "Jz2emq6vWnPu"
- },
- "execution_count": null,
- "outputs": []
+ " # Extract the file extension from the file name\n",
+ " file_ext = os.path.splitext(item)[1]\n",
+ " # If the file extension is not in the list of supported types, delete the file\n",
+ " if file_ext not in supported_types:\n",
+ " # Print a message indicating the name of the file being deleted\n",
+ " print(f\"Deleting file {item} from {train_data_dir}\")\n",
+ " # Delete the file\n",
+ " os.remove(os.path.join(train_data_dir, item))\n"
+ ]
},
{
"cell_type": "markdown",
+ "metadata": {
+ "id": "SoPUJaTpTusz"
+ },
"source": [
- "#DeepDanbooru3 for Autotagger\n",
- "We will skip BLIP Captioning section and only used DeepDanbooru for Autotagging.\n",
+ "# DeepDanbooru3 for Autotagger\n",
+ "We will skip the BLIP Captioning section and only use DeepDanbooru for Autotagging.\n",
"\n",
- "If you still want to use BLIP, please refer to the original article [here](https://note.com/kohya_ss/n/nbf7ce8d80f29#c9d7ee61-5779-4436-b4e6-9053741c46bb)"
- ],
- "metadata": {
- "id": "cSbB9CeqMwbF"
- }
+ "If you still want to use BLIP, please refer to the original article [here](https://note.com/kohya_ss/n/nbf7ce8d80f29#c9d7ee61-5779-4436-b4e6-9053741c46bb)\n"
+ ]
},
{
"cell_type": "code",
"source": [
"#@title Install DeepDanbooru\n",
"%cd /content/kohya-trainer\n",
+ "\n",
+ "# This code installs the DeepDanbooru library\n",
+ "\n",
"import os\n",
+ "import shutil\n",
"\n",
+ "# Check if the directory already exists\n",
"if os.path.isdir('/content/kohya-trainer/deepdanbooru'):\n",
" print(\"This folder already exists, will do a !git pull instead\\n\")\n",
" !git pull\n",
" \n",
"else:\n",
+ " # Create the directory\n",
+ " os.mkdir('/content/kohya-trainer/deepdanbooru')\n",
+ "\n",
+ " # Clone the repository into the directory\n",
" !git clone https://github.com/KichangKim/DeepDanbooru deepdanbooru\n",
"\n",
+ "# Change the current working directory to the DeepDanbooru directory\n",
"%cd /content/kohya-trainer/deepdanbooru\n",
+ "\n",
+ "# Install the required libraries\n",
"!pip install -r requirements.txt\n",
- "!pip install .\n"
+ "!pip install ."
],
"metadata": {
- "id": "AsLO2-REM8Yd",
- "cellView": "form"
+ "cellView": "form",
+ "id": "03ycaH-RFASL"
},
"execution_count": null,
"outputs": []
@@ -323,98 +419,140 @@
{
"cell_type": "code",
"source": [
- "#@title Install DeepDanbooru3 Model Weight\n",
+ "#@title Install DeepDanbooru3 Weight\n",
"%cd /content/kohya-trainer/deepdanbooru\n",
+ "\n",
+ "# This code installs the DeepDanbooru3 model weight and unzips it to the specified directory.\n",
+ "\n",
"import os\n",
"import shutil\n",
"\n",
- "if os.path.isdir('/content/kohya-trainer/deepdanbooru/deepdanbooruv3'):\n",
- " pass\n",
- "else:\n",
- " !mkdir deepdanbooruv3\n",
+ "# Construct the path to the deepdanbooruv3 directory\n",
+ "deepdanbooruv3_dir = os.path.join('/content/kohya-trainer/deepdanbooru', 'deepdanbooruv3')\n",
"\n",
- "if os.path.isfile('/content/kohya-trainer/deepdanbooru/deepdanbooruv3/deepdanbooruv3.zip'):\n",
- " pass\n",
- "else:\n",
- " !wget -c https://github.com/KichangKim/deepdanbooru/releases/download/v3-20211112-sgd-e28/deepdanbooru-v3-20211112-sgd-e28.zip -O deepdanbooruv3.zip\n",
+ "# Check if the directory already exists\n",
+ "if os.path.isdir(deepdanbooruv3_dir):\n",
+ " # If the directory exists, remove it\n",
+ " shutil.rmtree(deepdanbooruv3_dir)\n",
+ "\n",
+ "# Create the directory\n",
+ "try:\n",
+ " os.mkdir(deepdanbooruv3_dir)\n",
+ "except OSError as e:\n",
+ " print(f'Error: Unable to create the deepdanbooruv3 directory ({e})')\n",
+ " return\n",
"\n",
+ "# Construct the path to the deepdanbooruv3.zip file\n",
+ "deepdanbooruv3_zip = os.path.join(deepdanbooruv3_dir, 'deepdanbooruv3.zip')\n",
+ "\n",
+ "# Download the zip file to the deepdanbooruv3 directory\n",
+ "with open(deepdanbooruv3_zip, 'wb') as f:\n",
+ " !wget -c https://github.com/KichangKim/deepdanbooru/releases/download/v3-20211112-sgd-e28/deepdanbooru-v3-20211112-sgd-e28.zip -O $f\n",
+ "\n",
+ "# Unzip the zip file directly to the deepdanbooruv3 directory\n",
+ "try:\n",
+ " shutil.unpack_archive(deepdanbooruv3_zip, deepdanbooruv3_dir)\n",
+ "except (shutil.ReadError, shutil.Error) as e:\n",
+ " print(f'Error: Unable to unzip the deepdanbooruv3.zip file ({e})')\n",
+ " return\n",
+ "\n",
+ "# Change the current working directory to the deepdanbooruv3 directory\n",
"%cd /content/kohya-trainer/deepdanbooru/deepdanbooruv3\n",
- "!unzip deepdanbooruv3.zip \n",
- "!rm -rf deepdanbooruv3.zip\n"
+ "\n",
+ "# Remove the deepdanbooruv3.zip file\n",
+ "os.remove(deepdanbooruv3_zip)\n"
],
"metadata": {
- "id": "p8Y1SWWwUO26",
- "cellView": "form"
+ "cellView": "form",
+ "id": "NTrETBw5QEYE"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "cellView": "form",
+ "id": "hibZK5NPTjZQ"
+ },
+ "outputs": [],
"source": [
- "#@title Start Autotagging\n",
+ "#@title Start Autotagger\n",
+ "\n",
"%cd /content/kohya-trainer/deepdanbooru/deepdanbooruv3\n",
- "!deepdanbooru evaluate /content/kohya-trainer/train_data \\\n",
+ "!deepdanbooru evaluate {train_data_dir} \\\n",
" --project-path /content/kohya-trainer/deepdanbooru/deepdanbooruv3 \\\n",
" --allow-folder \\\n",
- " --save-txt"
- ],
- "metadata": {
- "cellView": "form",
- "id": "NwxPfDeI3h2_"
- },
- "execution_count": null,
- "outputs": []
+ " --save-txt\n",
+ "\n",
+ "\n"
+ ]
},
{
"cell_type": "code",
- "source": [
- "#@title Create Metadata.json\n",
- "%cd /content/kohya-trainer\n",
- "!python merge_dd_tags_to_metadata.py train_data meta_cap_dd.json"
- ],
- "metadata": {
- "id": "hz2Cmlf2ay9w",
- "cellView": "form"
- },
"execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "code",
- "source": [
- "#@title Tag cleaner\n",
- "%cd /content/kohya-trainer\n",
- "!python clean_captions_and_tags.py train_data meta_cap_dd.json meta_clean.json"
- ],
"metadata": {
"cellView": "form",
- "id": "Isog9VbN5Le3"
+ "id": "hz2Cmlf2ay9w"
},
- "execution_count": null,
- "outputs": []
+ "outputs": [],
+ "source": [
+ "#@title Create Metadata.json\n",
+ "\n",
+ "\n",
+ "# Change the working directory\n",
+ "%cd /content/kohya-trainer/diffuser_fine_tuning\n",
+ "\n",
+ "#@markdown ### Command-line Arguments\n",
+ "#@markdown The following command-line arguments are available:\n",
+ "#@markdown - `train_data_dir` : directory for training images\n",
+ "#@markdown - `out_json` : model path to load\n",
+ "#@markdown - `--in_json` : metadata file to input\n",
+ "#@markdown - `--debug` : debug mode\n",
+ "\n",
+ "#@markdown ### Define Parameter :\n",
+ "out_json = \"/content/kohya-trainer/meta_cap_dd.json\" #@param {'type':'string'}\n",
+ "\n",
+ "# Create the metadata file\n",
+ "!python merge_dd_tags_to_metadata.py \\\n",
+ " {train_data_dir} \\\n",
+ " {out_json}\n",
+ "\n",
+ "\n"
+ ]
},
{
"cell_type": "markdown",
- "source": [
- "#Prepare Training"
- ],
"metadata": {
"id": "3gob9_OwTlwh"
- }
+ },
+ "source": [
+ "# Prepare Training"
+ ]
},
{
"cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "cellView": "form",
+ "id": "SoucgZQ6jgPQ"
+ },
+ "outputs": [],
"source": [
"#@title Install Pre-trained Model \n",
"%cd /content/kohya-trainer\n",
- "!mkdir checkpoint\n",
+ "import os\n",
+ "\n",
+ "# Check if directory exists\n",
+ "if not os.path.exists('checkpoint'):\n",
+ " # Create directory if it doesn't exist\n",
+ " os.makedirs('checkpoint')\n",
"\n",
"#@title Install Pre-trained Model \n",
"\n",
"installModels=[]\n",
"\n",
- "\n",
"#@markdown ### Available Model\n",
"#@markdown Select one of available pretrained model to download:\n",
"modelUrl = [\"\", \\\n",
@@ -433,7 +571,7 @@
" \"Anything-V3.0-pruned-fp32\", \\\n",
" \"Anything-V3.0-pruned\", \\\n",
" \"Stable-Diffusion-v1-4\", \\\n",
- " \"Stable-Diffusion-v1-5-pruned-emaonly\" \\\n",
+ " \"Stable-Diffusion-v1-5-pruned-emaonly\", \\\n",
" \"Waifu-Diffusion-v1-3-fp32\"]\n",
"modelName = \"Animefull-final-pruned\" #@param [\"\", \"Animefull-final-pruned\", \"Animesfw-final-pruned\", \"Anything-V3.0-pruned-fp16\", \"Anything-V3.0-pruned-fp32\", \"Anything-V3.0-pruned\", \"Stable-Diffusion-v1-4\", \"Stable-Diffusion-v1-5-pruned-emaonly\", \"Waifu-Diffusion-v1-3-fp32\"]\n",
"\n",
@@ -442,187 +580,235 @@
"customName = \"\" #@param {'type': 'string'}\n",
"customUrl = \"\"#@param {'type': 'string'}\n",
"\n",
- "if customName == \"\" or customUrl == \"\":\n",
- " pass\n",
- "else:\n",
+ "# Check if user has specified a custom model\n",
+ "if customName != \"\" and customUrl != \"\":\n",
+ " # Add custom model to list of models to install\n",
" installModels.append((customName, customUrl))\n",
"\n",
+ "# Check if user has selected a model\n",
"if modelName != \"\":\n",
- " # Map model to URL\n",
+ " # Map selected model to URL\n",
" installModels.append((modelName, modelUrl[modelList.index(modelName)]))\n",
"\n",
"def install_aria():\n",
+ " # Install aria2 if it is not already installed\n",
" if not os.path.exists('/usr/bin/aria2c'):\n",
" !apt install -y -qq aria2\n",
"\n",
"def install(checkpoint_name, url):\n",
" if url.startswith(\"https://drive.google.com\"):\n",
+ " # Use gdown to download file from Google Drive\n",
" !gdown --fuzzy -O \"/content/kohya-trainer/checkpoint/{checkpoint_name}.ckpt\" \"{url}\"\n",
" elif url.startswith(\"magnet:?\"):\n",
" install_aria()\n",
+ " # Use aria2c to download file from magnet link\n",
" !aria2c --summary-interval=10 -c -x 10 -k 1M -s 10 -o /content/kohya-trainer/checkpoint/{checkpoint_name}.ckpt \"{url}\"\n",
" else:\n",
" user_token = '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'\n",
" user_header = f\"\\\"Authorization: Bearer {user_token}\\\"\"\n",
+ " # Use wget to download file from URL\n",
" !wget -c --header={user_header} \"{url}\" -O /content/kohya-trainer/checkpoint/{checkpoint_name}.ckpt\n",
"\n",
"def install_checkpoint():\n",
+ " # Iterate through list of models to install\n",
" for model in installModels:\n",
+ " # Call install function for each model\n",
" install(model[0], model[1])\n",
- "install_checkpoint()\n",
- "\n"
- ],
- "metadata": {
- "id": "SoucgZQ6jgPQ",
- "cellView": "form"
- },
- "execution_count": null,
- "outputs": []
+ "\n",
+ "# Call install_checkpoint function to download all models in the list\n",
+ "install_checkpoint()\n"
+ ]
},
{
"cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "cellView": "form",
+ "id": "IQwpRDVIbDB9"
+ },
+ "outputs": [],
"source": [
"#@title Emergency downgrade\n",
"#@markdown Tick this if you are facing issues on the cell below, such as high ram usage or cells not running\n",
"\n",
"diffuser_0_7_2 = True #@param {'type':'boolean'}\n",
"\n",
- "if diffuser_0_7_2 == True :\n",
+ "# Check if user wants to downgrade diffusers\n",
+ "if diffuser_0_7_2:\n",
+ " # Install diffusers 0.7.2\n",
" !pip install diffusers[torch]==0.7.2\n",
"else:\n",
+ " # Install latest version of diffusers\n",
" !pip install diffusers[torch]==0.9.0"
- ],
- "metadata": {
- "cellView": "form",
- "id": "IQwpRDVIbDB9"
- },
- "execution_count": null,
- "outputs": []
+ ]
},
{
"cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "cellView": "form",
+ "id": "hhgatqF3leHJ"
+ },
+ "outputs": [],
"source": [
"#@title Aspect Ratio Bucketing\n",
- "%cd /content/kohya-trainer\n",
"\n",
+ "# Change working directory\n",
+ "%cd /content/kohya-trainer/diffuser_fine_tuning\n",
+ "\n",
+ "#@markdown ### Command-line Arguments\n",
+ "#@markdown The following command-line arguments are available:\n",
+ "#@markdown * `train_data_dir`: directory for train images.\n",
+ "#@markdown * `in_json`: metadata file to input.\n",
+ "#@markdown * `out_json`: metadata file to output.\n",
+ "#@markdown * `model_name_or_path`: model name or path to encode latents.\n",
+ "#@markdown * `--v2`: load Stable Diffusion v2.0 model.\n",
+ "#@markdown * `--batch_size`: batch size in inference.\n",
+ "#@markdown * `--max_resolution`: max resolution in fine tuning (width,height).\n",
+ "#@markdown * `--min_bucket_reso`: minimum resolution for buckets.\n",
+ "#@markdown * `--max_bucket_reso`: maximum resolution for buckets.\n",
+ "#@markdown * `--mixed_precision`: use mixed precision.\n",
+ "\n",
+ "#@markdown ### Define parameters\n",
+ "in_json = \"/content/kohya-trainer/meta_cap_dd.json\" #@param {'type' : 'string'} \n",
+ "out_json = \"/content/kohya-trainer/meta_lat.json\" #@param {'type' : 'string'} \n",
"model_dir = \"/content/kohya-trainer/checkpoint/Anything-V3.0-pruned.ckpt\" #@param {'type' : 'string'} \n",
"batch_size = 4 #@param {'type':'integer'}\n",
"max_resolution = \"512,512\" #@param [\"512,512\", \"768,768\"] {allow-input: false}\n",
"mixed_precision = \"no\" #@param [\"no\", \"fp16\", \"bf16\"] {allow-input: false}\n",
"\n",
- "!python prepare_buckets_latents.py train_data meta_cap_dd.json meta_lat.json {model_dir} \\\n",
+ "# Run script to prepare buckets and latents\n",
+ "!python prepare_buckets_latents.py \\\n",
+ " {train_data_dir} \\\n",
+ " {in_json} \\\n",
+ " {out_json} \\\n",
+ " {model_dir} \\\n",
" --batch_size {batch_size} \\\n",
" --max_resolution {max_resolution} \\\n",
" --mixed_precision {mixed_precision}\n",
"\n",
+ "\n",
+ "\n",
" "
- ],
- "metadata": {
- "id": "hhgatqF3leHJ",
- "cellView": "form"
- },
- "execution_count": null,
- "outputs": []
+ ]
},
{
"cell_type": "markdown",
+ "metadata": {
+ "id": "yHNbl3O_NSS0"
+ },
"source": [
"# Start Training\n",
"\n"
- ],
- "metadata": {
- "id": "yHNbl3O_NSS0"
- }
+ ]
},
{
"cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "cellView": "form",
+ "id": "X_Rd3Eh07xlA"
+ },
+ "outputs": [],
"source": [
"#@title Training begin\n",
"num_cpu_threads_per_process = 8 #@param {'type':'integer'}\n",
- "model_path =\"/content/herigaru/herigaru5k-pruned.ckpt\" #@param {'type':'string'}\n",
+ "pre_trained_model_path =\"/content/kohya-trainer/checkpoint/Animefull-final-pruned.ckpt\" #@param {'type':'string'}\n",
+ "meta_lat_json_dir = \"/content/kohya-trainer/meta_lat.json\" #@param {'type':'string'}\n",
+ "train_data_dir = \"/content/kohya-trainer/train_data\" #@param {'type':'string'}\n",
"output_dir =\"/content/kohya-trainer/fine_tuned\" #@param {'type':'string'}\n",
+ "# resume_path = \"/content/kohya-trainer/last-state\" #@param {'type':'string'}\n",
"train_batch_size = 1 #@param {type: \"slider\", min: 1, max: 10}\n",
"learning_rate =\"2e-6\" #@param {'type':'string'}\n",
"max_token_length = \"225\" #@param [\"150\", \"225\"] {allow-input: false}\n",
"clip_skip = 2 #@param {type: \"slider\", min: 1, max: 10}\n",
"mixed_precision = \"fp16\" #@param [\"fp16\", \"bf16\"] {allow-input: false}\n",
"max_train_steps = 5000 #@param {'type':'integer'}\n",
- "# save_precision = \"fp16\" #@param [\"float\", \"fp16\", \"bf16\"] {allow-input: false}\n",
- "save_every_n_epochs = 100 #@param {'type':'integer'}\n",
+ "save_precision = \"fp16\" #@param [\"float\", \"fp16\", \"bf16\"] {allow-input: false}\n",
+ "save_every_n_epochs = 50 #@param {'type':'integer'}\n",
"gradient_accumulation_steps = 1 #@param {type: \"slider\", min: 1, max: 10}\n",
- "dataset_repeats = 1 #@param {'type':'integer'}\n",
- " \n",
- "%cd /content/kohya-trainer\n",
+ "\n",
+ "%cd /content/kohya-trainer/diffuser_fine_tuning\n",
"!accelerate launch --num_cpu_threads_per_process {num_cpu_threads_per_process} fine_tune.py \\\n",
- " --pretrained_model_name_or_path={model_path} \\\n",
- " --in_json meta_lat.json \\\n",
- " --train_data_dir=train_data \\\n",
+ " --pretrained_model_name_or_path={pre_trained_model_path} \\\n",
+ " --in_json {meta_lat_json_dir} \\\n",
+ " --train_data_dir={train_data_dir} \\\n",
" --output_dir={output_dir} \\\n",
" --shuffle_caption \\\n",
- " --logging_dir=logs \\\n",
" --train_batch_size={train_batch_size} \\\n",
" --learning_rate={learning_rate} \\\n",
+ " --logging_dir=logs \\\n",
" --max_token_length={max_token_length} \\\n",
" --clip_skip={clip_skip} \\\n",
" --mixed_precision={mixed_precision} \\\n",
- " --max_train_steps={max_train_steps} \\\n",
+ " --max_train_steps={max_train_steps} \\\n",
" --use_8bit_adam \\\n",
" --xformers \\\n",
" --gradient_checkpointing \\\n",
- " --save_every_n_epochs={save_every_n_epochs} \\\n",
" --save_state \\\n",
" --gradient_accumulation_steps {gradient_accumulation_steps} \\\n",
- " --dataset_repeats {dataset_repeats} \n",
- " # --save_precision={save_precision} \n",
- " # --resume /content/kohya-trainer/checkpoint/last-state\n"
- ],
- "metadata": {
- "id": "X_Rd3Eh07xlA",
- "cellView": "form"
- },
- "execution_count": null,
- "outputs": []
+ " --save_precision={save_precision}\n",
+ " # --resume {resume_path} \\\n"
+ ]
},
{
"cell_type": "markdown",
- "source": [
- "#Miscellaneous"
- ],
"metadata": {
"id": "vqfgyL-thgdw"
- }
+ },
+ "source": [
+ "# Miscellaneous"
+ ]
},
{
"cell_type": "code",
- "source": [
- "#@title Convert diffuser model to ckpt\n",
- "\n",
- "#@markdown If you're using diffuser weight, this cell will convert output weight to checkpoint file so it can be used in Web UI like Auto1111's\n",
- "WEIGHTS_DIR = \"/content/drive/MyDrive/fine_tuned/last\" #@param {'type':'string'}\n",
- "#@markdown Run conversion.\n",
- "ckpt_path = WEIGHTS_DIR + \"/model.ckpt\"\n",
- "\n",
- "half_arg = \"\"\n",
- "#@markdown Whether to convert to fp16, takes half the space (2GB).\n",
- "fp16 = False #@param {type: \"boolean\"}\n",
- "if fp16:\n",
- " half_arg = \"--half\"\n",
- "!python convert_diffusers_to_original_stable_diffusion.py --model_path $WEIGHTS_DIR --checkpoint_path $ckpt_path $half_arg\n",
- "\n",
- "print(f\"[*] Converted ckpt saved at {ckpt_path}\")"
- ],
+ "execution_count": null,
"metadata": {
"cellView": "form",
"id": "nOhJCs3BeR_Q"
},
- "execution_count": null,
- "outputs": []
+ "outputs": [],
+ "source": [
+ "#@title Convert diffuser model to ckpt (Optional)\n",
+ "\n",
+ "#@markdown If you're using diffuser weight, this cell will convert output weight to checkpoint file so it can be used in Web UI like Auto1111's\n",
+ "\n",
+ "# Use a more descriptive variable name\n",
+ "diffuser_weights_dir = \"/content/drive/MyDrive/fine_tuned/last\" #@param {'type':'string'}\n",
+ "\n",
+ "# Use a more descriptive variable name\n",
+ "use_fp16 = False #@param {type: \"boolean\"}\n",
+ "\n",
+ "# Add a comment to explain what the code is doing\n",
+ "# Convert the diffuser weights to a checkpoint file\n",
+ "ckpt_path = diffuser_weights_dir + \"/model.ckpt\"\n",
+ "\n",
+ "# Use a more descriptive variable name\n",
+ "half_precision_arg = \"\"\n",
+ "if use_fp16:\n",
+ " # Use a more descriptive variable name\n",
+ " half_precision_arg = \"--half\"\n",
+ "\n",
+ "# Add a comment to explain what the code is doing\n",
+ "# Run the conversion script\n",
+ "!python convert_diffusers_to_original_stable_diffusion.py --model_path $diffuser_weights_dir --checkpoint_path $ckpt_path $half_precision_arg\n",
+ "\n",
+ "# Use string formatting and a more descriptive variable name\n",
+ "print(f\"[*] Converted checkpoint saved at {ckpt_path}\")"
+ ]
},
{
"cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "cellView": "form",
+ "id": "LUOG7BzQVLKp"
+ },
+ "outputs": [],
"source": [
- "#@markdown ```\n",
- "#@markdown @lopho\n",
+ "#@title Model Pruner (Optional)\n",
+ "\n",
+ "#@markdown ```python\n",
"#@markdown usage: prune.py [-h] [-p] [-e] [-c] [-a] input output\n",
"#@markdown \n",
"#@markdown Prune a stable diffusion checkpoint\n",
@@ -638,267 +824,140 @@
"#@markdown -c, --no-clip strip CLIP weights\n",
"#@markdown -a, --no-vae strip VAE weights\n",
"#@markdown ```\n",
- "#@title Model Pruner\n",
+ "\n",
"#@markdown Do you want to Prune a model?\n",
"%cd /content/ \n",
"\n",
- "prune = True #@param {'type':'boolean'}\n",
+ "# Use a more descriptive variable name\n",
+ "should_prune = False #@param {'type':'boolean'}\n",
+ "\n",
+ "# Use a more descriptive variable name\n",
+ "source_model_path = \"/content/kohya-trainer/fine_tuned/last.ckpt\" #@param {'type' : 'string'}\n",
"\n",
- "model_src = \"/content/herigaru/herigaru5k.ckpt\" #@param {'type' : 'string'}\n",
- "model_dst = \"/content/herigaru/herigaru5k-pruned.ckpt\" #@param {'type' : 'string'}\n",
+ "# Use a more descriptive variable name\n",
+ "pruned_model_path = \"/content/kohya-trainer/fine_tuned/last-pruned.ckpt\" #@param {'type' : 'string'}\n",
"\n",
- "if prune == True:\n",
+ "if should_prune:\n",
" import os\n",
" if os.path.isfile('/content/prune.py'):\n",
" pass\n",
" else:\n",
+ " # Add a comment to explain what the code is doing\n",
+ " # Download the pruning script if it doesn't already exist\n",
" !wget https://raw.githubusercontent.com/lopho/stable-diffusion-prune/main/prune.py\n",
"\n",
"\n",
- "!python3 prune.py {model_src} {model_dst}\n"
- ],
- "metadata": {
- "id": "LUOG7BzQVLKp",
- "cellView": "form"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "code",
- "source": [
- "#@title Mount to Google Drive\n",
- "mount_drive= True #@param {'type':'boolean'}\n",
- "\n",
- "if mount_drive== True:\n",
- " from google.colab import drive\n",
- " drive.mount('/content/drive')"
- ],
- "metadata": {
- "id": "OuRqOSp2eU6t",
- "cellView": "form"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "source": [
- "#Huggingface_hub Integration"
- ],
- "metadata": {
- "id": "QtVP2le8PL2T"
- }
+ "# Add a comment to explain what the code is doing\n",
+ "# Run the pruning script\n",
+ "!python3 prune.py {source_model_path} {pruned_model_path}"
+ ]
},
{
"cell_type": "markdown",
- "source": [
- "##Instruction:\n",
- "0. Of course you need a Huggingface Account first\n",
- "1. Create huggingface token, go to `Profile > Access Tokens > New Token > Create a new access token` with the `Write` role.\n",
- "2. All cells below are checked `opt-out` by default so you need to uncheck it if you want to running the cells."
- ],
"metadata": {
- "id": "tbKgmh_AO5NG"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "#@title Login to Huggingface hub\n",
- "#@markdown Opt-out this cell when run all\n",
- "%cd /content/kohya-trainer\n",
- "from IPython.core.display import HTML\n",
- "\n",
- "opt_out= False #@param {'type':'boolean'}\n",
- "\n",
- "#@markdown Prepare your Huggingface token\n",
- "\n",
- "saved_token= \"save-your-write-token-here\" #@param {'type': 'string'}\n",
- "\n",
- "if opt_out == False:\n",
- " !pip install huggingface_hub\n",
- " \n",
- " from huggingface_hub import login\n",
- " login()\n",
- "\n"
- ],
- "metadata": {
- "id": "Da7awoqAPJ3a",
- "cellView": "form"
+ "id": "jypUkLWc48R_"
},
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
"source": [
- "##Commit trained model to Huggingface"
- ],
- "metadata": {
- "id": "jypUkLWc48R_"
- }
+ "## Commit trained model to Huggingface"
+ ]
},
{
"cell_type": "markdown",
- "source": [
- "###Instruction:\n",
- "0. Create huggingface repository for model\n",
- "1. Clone your model to this colab session\n",
- "2. Move these necessary file to your repository to save your trained model to huggingface\n",
- "\n",
- ">in `content/kohya-trainer/fine-tuned`\n",
- "- File `epoch-nnnnn.ckpt` and/or\n",
- "- File `last.ckpt`, \n",
- "\n",
- "4. Commit your model to huggingface"
- ],
"metadata": {
"id": "TvZgRSmKVSRw"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "#@title Clone Model\n",
- "\n",
- "#@markdown Opt-out this cell when run all\n",
- "opt_out= False #@param {'type':'boolean'}\n",
- "\n",
- "if opt_out == False:\n",
- " %cd /content\n",
- " Repository_url = \"https://huggingface.co/Linaqruf/herigaru\" #@param {'type': 'string'}\n",
- " !git clone {Repository_url}\n",
- "else:\n",
- " pass\n"
- ],
- "metadata": {
- "id": "182Law9oUiYN",
- "cellView": "form"
},
- "execution_count": null,
- "outputs": []
+ "source": [
+ "### To Commit models:\n",
+ "1. Create a huggingface repository for your model.\n",
+ "2. Clone your model to this Colab session.\n",
+ "3. Move the necessary files to your repository to save your trained model to huggingface. These files are located in `fine-tuned` folder:\n",
+ " - `epoch-nnnnn.ckpt` and/or\n",
+ " - `last.ckpt`\n",
+ "4. Commit your model to huggingface.\n",
+ "\n",
+ "### To Commit datasets:\n",
+ "1. Create a huggingface repository for your datasets.\n",
+ "2. Clone your datasets to this Colab session.\n",
+ "3. Move the necessary files to your repository so that you can resume training without rebuilding your dataset with this notebook:\n",
+ " - The `train_data` folder.\n",
+ " - The `meta_lat.json` file.\n",
+ " - The `last-state` folder.\n",
+ "4. Commit your datasets to huggingface.\n",
+ "\n"
+ ]
},
{
"cell_type": "code",
- "source": [
- "#@title Commit to Huggingface\n",
- "#@markdown Opt-out this cell when run all\n",
- "opt_out= False #@param {'type':'boolean'}\n",
- "\n",
- "if opt_out == False:\n",
- " %cd /content\n",
- " #@markdown Go to your model path\n",
- " model_path= \"herigaru\" #@param {'type': 'string'}\n",
- "\n",
- " #@markdown Your path look like /content/**model_path**\n",
- " #@markdown ___\n",
- " #@markdown #Git Commit\n",
- "\n",
- " #@markdown Set **git commit identity**\n",
- "\n",
- " email= \"furqanil.taqwa@gmail.com\" #@param {'type': 'string'}\n",
- " name= \"Linaqruf\" #@param {'type': 'string'}\n",
- " #@markdown Set **commit message**\n",
- " commit_m= \"upload newly finetuned model 5k\" #@param {'type': 'string'}\n",
- "\n",
- " %cd \"/content/{model_path}\"\n",
- " !git lfs install\n",
- " !huggingface-cli lfs-enable-largefiles .\n",
- " !git add .\n",
- " !git lfs help smudge\n",
- " !git config --global user.email \"{email}\"\n",
- " !git config --global user.name \"{name}\"\n",
- " !git commit -m \"{commit_m}\"\n",
- " !git push\n",
- "\n",
- "else:\n",
- " pass"
- ],
- "metadata": {
- "id": "87wG7QIZbtZE",
- "cellView": "form"
- },
"execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "source": [
- "##Commit dataset to huggingface"
- ],
"metadata": {
- "id": "olP2yaK3OKcr"
- }
- },
- {
- "cell_type": "markdown",
+ "cellView": "form",
+ "id": "182Law9oUiYN"
+ },
+ "outputs": [],
"source": [
- "###Instruction:\n",
- "0. Create huggingface repository for datasets\n",
- "1. Clone your datasets to this colab session\n",
- "2. Move these necessary file to your repository so that you can do resume training next time without rebuild your dataset with this notebook\n",
+ "#@title Clone Model or Datasets\n",
"\n",
- ">in `content/kohya-trainer`\n",
- "- Folder `train_data`\n",
- "- File `meta_cap_dd.json`\n",
- "- File `meta_lat.json`\n",
+ "#@markdown Opt-out this cell when run all\n",
+ "opt_out = True #@param {'type':'boolean'}\n",
"\n",
- ">in `content/kohya-trainer/fine-tuned`\n",
- "- Folder `last-state`\n",
+ "#@markdown Type of item to clone (model or dataset)\n",
+ "type_of_item = \"model\" #@param [\"model\", \"dataset\"]\n",
"\n",
- "4. Commit your datasets to huggingface"
- ],
- "metadata": {
- "id": "jiSb0z2CVtc_"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "#@title Clone Dataset\n",
- "#@markdown Opt-out this cell when run all\n",
- "opt_out= False #@param {'type':'boolean'}\n",
+ "#@markdown Install or uninstall git lfs\n",
+ "install_git_lfs = False #@param {'type':'boolean'}\n",
"\n",
"if opt_out == False:\n",
" %cd /content\n",
- " Repository_url = \"https://huggingface.co/datasets/Linaqruf/herigaru-tag\" #@param {'type': 'string'}\n",
+ " username = \"your-huggingface-username\" #@param {'type': 'string'}\n",
+ " model_repo = \"your-huggingface-model-repo\" #@param {'type': 'string'}\n",
+ " datasets_repo = \"your-huggingface-datasets-repo\" #@param {'type': 'string'}\n",
+ " \n",
+ " if type_of_item == \"model\":\n",
+ " Repository_url = f\"https://huggingface.co/{username}/{model_repo}\"\n",
+ " elif type_of_item == \"dataset\":\n",
+ " Repository_url = f\"https://huggingface.co/datasets/{username}/{datasets_repo}\"\n",
+ "\n",
+ " if install_git_lfs:\n",
+ " !git lfs install\n",
+ " else:\n",
+ " !git lfs uninstall\n",
+ "\n",
" !git clone {Repository_url}\n",
"else:\n",
" pass\n"
- ],
- "metadata": {
- "id": "QhL6UgqDOURK",
- "cellView": "form"
- },
- "execution_count": null,
- "outputs": []
+ ]
},
{
"cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "cellView": "form",
+ "id": "87wG7QIZbtZE"
+ },
+ "outputs": [],
"source": [
- "#@title Commit to Huggingface\n",
+ "#@title Commit Model or Datasets to Huggingface\n",
+ "\n",
"#@markdown Opt-out this cell when run all\n",
+ "opt_out = True #@param {'type':'boolean'}\n",
"\n",
- "opt_out= False #@param {'type':'boolean'}\n",
+ "#@markdown Type of item to commit (model or dataset)\n",
+ "type_of_item = \"model\" #@param [\"model\", \"dataset\"]\n",
"\n",
"if opt_out == False:\n",
" %cd /content\n",
- " #@markdown Go to your model path\n",
- " dataset_path= \"herigaru-tag\" #@param {'type': 'string'}\n",
+ " #@markdown Go to your model or dataset path\n",
+ " item_path = \"your-cloned-model-or-datasets-repo\" #@param {'type': 'string'}\n",
"\n",
- " #@markdown Your path look like /content/**dataset_path**\n",
- " #@markdown ___\n",
" #@markdown #Git Commit\n",
"\n",
" #@markdown Set **git commit identity**\n",
- "\n",
- " email= \"furqanil.taqwa@gmail.com\" #@param {'type': 'string'}\n",
- " name= \"Linaqruf\" #@param {'type': 'string'}\n",
+ " email = \"your-email\" #@param {'type': 'string'}\n",
+ " name = \"your-username\" #@param {'type': 'string'}\n",
" #@markdown Set **commit message**\n",
- " commit_m= \"upload 5k new datasets\" #@param {'type': 'string'}\n",
+ " commit_m = \"feat: upload 6 epochs model\" #@param {'type': 'string'}\n",
"\n",
- " %cd \"/content/{dataset_path}\"\n",
+ " %cd {item_path}\n",
" !git lfs install\n",
" !huggingface-cli lfs-enable-largefiles .\n",
" !git add .\n",
@@ -910,13 +969,24 @@
"\n",
"else:\n",
" pass"
- ],
- "metadata": {
- "id": "abHLg4I0Os5T",
- "cellView": "form"
- },
- "execution_count": null,
- "outputs": []
+ ]
}
- ]
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "provenance": [],
+ "include_colab_link": true
+ },
+ "gpuClass": "standard",
+ "kernelspec": {
+ "display_name": "Python 3",
+ "name": "python3"
+ },
+ "language_info": {
+ "name": "python"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
}
\ No newline at end of file
diff --git a/kohya-trainer-resume.ipynb b/kohya-trainer-resume.ipynb
index 507486e1..174aee4e 100644
--- a/kohya-trainer-resume.ipynb
+++ b/kohya-trainer-resume.ipynb
@@ -24,14 +24,14 @@
"colab_type": "text"
},
"source": [
- ""
+ ""
]
},
{
"cell_type": "markdown",
"source": [
- "#Kohya Trainer V4 - VRAM 12GB [FOR RESUME TRAINING]\n",
- "###Notebook for resuming your latest training using [main notebook](https://colab.research.google.com/github/Linaqruf/kohya-trainer/blob/main/kohya-trainer.ipynb)"
+ "# Kohya Trainer V6 - VRAM 12GB [FOR RESUME TRAINING]\n",
+ "### Notebook for resuming your latest training using [main notebook](https://colab.research.google.com/github/Linaqruf/kohya-trainer/blob/main/kohya-trainer.ipynb)"
],
"metadata": {
"id": "slgjeYgd6pWp"
@@ -40,11 +40,9 @@
{
"cell_type": "markdown",
"source": [
- "Adapted to Google Colab based on [Kohya Guide](https://note.com/kohya_ss/n/nbf7ce8d80f29#c9d7ee61-5779-4436-b4e6-9053741c46bb)
\n",
- "Adapted to Google Colab by [Linaqruf](https://github.com/Linaqruf)
\n",
- "You can find latest notebook update [here](https://github.com/Linaqruf/kohya-trainer/blob/main/kohya-trainer.ipynb)\n",
- "\n",
- "\n"
+ "This notebook has been adapted for use in Google Colab based on the [Kohya Guide](https://note.com/kohya_ss/n/nbf7ce8d80f29#c9d7ee61-5779-4436-b4e6-9053741c46bb). \n",
+ "This notebook was adapted by [Linaqruf](https://github.com/Linaqruf)\n",
+ "You can find the latest update to the notebook [here](https://github.com/Linaqruf/kohya-trainer/blob/main/kohya-trainer-resume.ipynb).\n"
],
"metadata": {
"id": "gPgBR3KM6E-Z"
@@ -69,56 +67,106 @@
"outputs": [],
"source": [
"#@title Clone Kohya Trainer\n",
- "#@markdown Run this cell everytime you want to `!git pull` to get a lot of new optimizations and updates.\n",
+ "#@markdown Clone the Kohya Trainer repository from GitHub and check for updates\n",
+ "\n",
"%cd /content/\n",
"\n",
"import os\n",
"\n",
- "if os.path.isdir('/content/kohya-trainer'):\n",
- " %cd /content/kohya-trainer\n",
- " print(\"This folder already exists, will do a !git pull instead\\n\")\n",
- " !git pull\n",
- " \n",
- "else:\n",
- " !git clone https://github.com/Linaqruf/kohya-trainer"
+ "def clone_kohya_trainer():\n",
+ " # Check if the directory already exists\n",
+ " if os.path.isdir('/content/kohya-trainer'):\n",
+ " %cd /content/kohya-trainer\n",
+ " print(\"This folder already exists, will do a !git pull instead\\n\")\n",
+ " !git pull\n",
+ " else:\n",
+ " !git clone https://github.com/Linaqruf/kohya-trainer\n",
+ " \n",
+ "\n",
+ "# Clone or update the Kohya Trainer repository\n",
+ "clone_kohya_trainer()"
]
},
{
"cell_type": "code",
"source": [
"#@title Install Diffuser Fine Tuning\n",
+ "\n",
+ "# Change the current working directory to \"/content/kohya-trainer\".\n",
"%cd /content/kohya-trainer\n",
"\n",
+ "# Import `shutil` and `os` modules.\n",
"import shutil\n",
"import os\n",
"\n",
- "customVersion = []\n",
- "versionDir = [\"/content/kohya-trainer/diffuser_fine_tuning/diffusers_fine_tuning_v1.zip\", \\\n",
- " \"/content/kohya-trainer/diffuser_fine_tuning/diffusers_fine_tuning_v2.zip\", \\\n",
- " \"/content/kohya-trainer/diffuser_fine_tuning/diffusers_fine_tuning_v3.zip\", \\\n",
- " \"/content/kohya-trainer/diffuser_fine_tuning/diffusers_fine_tuning_v4.zip\"]\n",
- "versionList = [\"diffusers_fine_tuning_v1\", \\\n",
- " \"diffusers_fine_tuning_v2\", \\\n",
+ "# Initialize an empty list `custom_versions`.\n",
+ "custom_versions = []\n",
+ "\n",
+ "# Initialize a list `version_urls` containing URLs of different versions of the `diffusers_fine_tuning` file.\n",
+ "version_urls = [\"\",\\\n",
+ " \"https://github.com/Linaqruf/kohya-trainer/releases/download/v6/diffusers_fine_tuning_v6.zip\", \\\n",
+ " \"https://github.com/Linaqruf/kohya-trainer/releases/download/v5/diffusers_fine_tuning_v5.zip\", \\\n",
+ " \"https://github.com/Linaqruf/kohya-trainer/releases/download/v4/diffusers_fine_tuning_v4.zip\", \\\n",
+ " \"https://github.com/Linaqruf/kohya-trainer/releases/download/v3/diffusers_fine_tuning_v3.zip\", \\\n",
+ " \"https://github.com/Linaqruf/kohya-trainer/releases/download/v2/diffusers_fine_tuning_v2.zip\", \\\n",
+ " \"https://github.com/Linaqruf/kohya-trainer/releases/download/v1/diffusers_fine_tuning_v1.zip\"]\n",
+ "\n",
+ "# Initialize a list `version_names` containing names of different versions of the `diffusers_fine_tuning` file.\n",
+ "version_names = [\"latest_version\", \\\n",
+ " \"diffusers_fine_tuning_v6\", \\\n",
+ " \"diffusers_fine_tuning_v5\", \\\n",
+ " \"diffusers_fine_tuning_v4\", \\\n",
" \"diffusers_fine_tuning_v3\", \\\n",
- " \"diffusers_fine_tuning_v4\"]\n",
- "version = \"diffusers_fine_tuning_v3\" #@param [\"diffusers_fine_tuning_v1\",\"diffusers_fine_tuning_v2\",\"diffusers_fine_tuning_v3\",\"diffusers_fine_tuning_v4\"]\n",
- "\n",
- "customVersion.append((versionDir[versionList.index(version)]))\n",
- "\n",
- "for zip in customVersion:\n",
- " (zip[0])\n",
- "\n",
- "zip = \"\".join(zip)\n",
- "\n",
- "def unzip_function(dir):\n",
- " !unzip {dir} -d /content/kohya-trainer/\n",
- "\n",
- "def unzip_version():\n",
- " unzip_function(zip)\n",
+ " \"diffusers_fine_tuning_v2\", \\\n",
+ " \"diffusers_fine_tuning_v1\"]\n",
+ "\n",
+ "# Initialize a variable `selected_version` to the selected version of the `diffusers_fine_tuning` file.\n",
+ "selected_version = \"latest_version\" #@param [\"latest_version\", \"diffusers_fine_tuning_v6\", \"diffusers_fine_tuning_v5\", \"diffusers_fine_tuning_v4\", \"diffusers_fine_tuning_v3\", \"diffusers_fine_tuning_v2\", \"diffusers_fine_tuning_v1\"]\n",
+ "\n",
+ "# Append a tuple to `custom_versions`, containing `selected_version` and the corresponding item\n",
+ "# in `version_urls`.\n",
+ "custom_versions.append((selected_version, version_urls[version_names.index(selected_version)]))\n",
+ "\n",
+ "# Define `download` function to download a file from the given URL and save it with\n",
+ "# the given name.\n",
+ "def download(name, url):\n",
+ " !wget -c \"{url}\" -O /content/{name}.zip\n",
+ "\n",
+ "# Define `unzip` function to unzip a file with the given name to a specified\n",
+ "# directory.\n",
+ "def unzip(name):\n",
+ " !unzip /content/{name}.zip -d /content/kohya-trainer/diffuser_fine_tuning\n",
+ "\n",
+ "# Define `download_version` function to download and unzip a file from `custom_versions`,\n",
+ "# unless `selected_version` is \"latest_version\".\n",
+ "def download_version():\n",
+ " if selected_version != \"latest_version\":\n",
+ " for zip in custom_versions:\n",
+ " download(zip[0], zip[1])\n",
+ "\n",
+ " # Rename the existing `diffuser_fine_tuning` directory to the `tmp` directory and delete any existing `tmp` directory.\n",
+ " if os.path.exists(\"/content/kohya-trainer/tmp\"):\n",
+ " shutil.rmtree(\"/content/kohya-trainer/tmp\")\n",
+ " os.rename(\"/content/kohya-trainer/diffuser_fine_tuning\", \"/content/kohya-trainer/tmp\")\n",
+ "\n",
+ " # Create a new empty `diffuser_fine_tuning` directory.\n",
+ " os.makedirs(\"/content/kohya-trainer/diffuser_fine_tuning\")\n",
+ " \n",
+ " # Unzip the downloaded file to the new `diffuser_fine_tuning` directory.\n",
+ " unzip(zip[0])\n",
+ " \n",
+ " # Delete the downloaded and unzipped file.\n",
+ " os.remove(\"/content/{}.zip\".format(zip[0]))\n",
+ " \n",
+ " # Inform the user that the existing `diffuser_fine_tuning` directory has been renamed to the `tmp` directory\n",
+ " # and a new empty `diffuser_fine_tuning` directory has been created.\n",
+ " print(\"Renamed existing 'diffuser_fine_tuning' directory to 'tmp' directory and created new empty 'diffuser_fine_tuning' directory.\")\n",
+ " else:\n",
+ " # Do nothing if `selected_version` is \"latest_version\".\n",
+ " pass\n",
"\n",
- "unzip_version()\n",
- "# if version == \"diffusers_fine_tuning_v1\":\n",
- "# !unzip /content/kohya-trainer/diffuser_fine_tuning/diffusers_fine_tuning_v1.zip -d /content/kohya-trainer/"
+ "# Call `download_version` function.\n",
+ "download_version()"
],
"metadata": {
"cellView": "form",
@@ -133,40 +181,23 @@
"#@title Installing Dependencies\n",
"%cd /content/kohya-trainer\n",
"\n",
- "Install_Python_3_9_6 = False #@param{'type':'boolean'}\n",
+ "def install_dependencies():\n",
+ " #@markdown Install required Python packages\n",
+ " !pip install --upgrade -r script/requirements.txt\n",
+ " !pip install -U gallery-dl\n",
+ " !pip install tensorflow\n",
+ " !pip install huggingface_hub\n",
"\n",
- "if Install_Python_3_9_6 == True:\n",
- " #install python 3.9\n",
- " !sudo apt-get update -y\n",
- " !sudo apt-get install python3.9\n",
+ " # Install xformers\n",
+ " !pip install -U -I --no-deps https://github.com/camenduru/stable-diffusion-webui-colab/releases/download/0.0.15/xformers-0.0.15.dev0+189828c.d20221207-cp38-cp38-linux_x86_64.whl\n",
"\n",
- " #change alternatives\n",
- " !sudo update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.8 1\n",
- " !sudo update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.9 2\n",
"\n",
- " #check python version\n",
- " !python --version\n",
- " #3.9.6\n",
- " !sudo apt-get install python3.9-distutils && wget https://bootstrap.pypa.io/get-pip.py && python get-pip.py\n",
- "\n",
- "!wget -q https://github.com/ShivamShrirao/diffusers/raw/main/scripts/convert_diffusers_to_original_stable_diffusion.py\n",
- "\n",
- "if os.path.isfile('/content/kohya-trainer/convert_diffusers_to_original_stable_diffusion.py'):\n",
- " pass\n",
- "else:\n",
+ "# Install convert_diffusers_to_original_stable_diffusion.py script\n",
+ "if not os.path.isfile('/content/kohya-trainer/convert_diffusers_to_original_stable_diffusion.py'):\n",
" !wget -q https://github.com/ShivamShrirao/diffusers/raw/main/scripts/convert_diffusers_to_original_stable_diffusion.py\n",
"\n",
- "!pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113\n",
- "!pip install --upgrade -r script/requirements.txt\n",
- "!pip install -U gallery-dl\n",
- "!pip install tensorflow\n",
- "!pip install accelerate==0.14.0\n",
- "\n",
- "#install xformers\n",
- "if Install_Python_3_9_6 == True:\n",
- " !pip install -U -I --no-deps https://github.com/daswer123/stable-diffusion-colab/raw/main/xformers%20prebuild/T4/python39/xformers-0.0.14.dev0-cp39-cp39-linux_x86_64.whl\n",
- "else:\n",
- " !pip install -U -I --no-deps https://github.com/TheLastBen/fast-stable-diffusion/raw/main/precompiled/T4/xformers-0.0.13.dev0-py3-none-any.whl\n"
+ "# Install dependencies\n",
+ "install_dependencies()"
],
"metadata": {
"id": "WNn0g1pnHfk5",
@@ -223,31 +254,55 @@
"cell_type": "code",
"source": [
"#@title Login to Huggingface hub\n",
- "#@markdown #Instruction:\n",
- "#@markdown 0. Of course you need a Huggingface Account first\n",
- "#@markdown 1. Create huggingface token, go to `Profile > Access Tokens > New Token > Create a new access token` with the `Write` role.\n",
- "#@markdown 2. All cells below are checked `opt-out` by default so you need to uncheck it if you want to running the cells.\n",
"\n",
- "#@markdown Opt-out this cell when run all\n",
+ "#@markdown ## Instructions:\n",
+ "#@markdown 1. Of course, you need a Huggingface account first.\n",
+ "#@markdown 2. To create a huggingface token, go to `Profile > Access Tokens > New Token > Create a new access token` with the `Write` role.\n",
+ "#@markdown 3. By default, all cells below are marked as `opt-out`, so you need to uncheck them if you want to run the cells.\n",
+ "\n",
"%cd /content/kohya-trainer\n",
- "from IPython.core.display import HTML\n",
"\n",
- "opt_out= False #@param {'type':'boolean'}\n",
+ "from huggingface_hub import login\n",
+ "login()\n",
+ "\n"
+ ],
+ "metadata": {
+ "cellView": "form",
+ "id": "zf7ZJ4f1KXiz"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "#@title Clone Datasets Repo From Huggingface\n",
"\n",
- "#@markdown Prepare your Huggingface token\n",
+ "#@markdown Opt-out this cell when run all\n",
+ "opt_out = True #@param {'type':'boolean'}\n",
"\n",
- "saved_token= \"save-your-write-token-here\" #@param {'type': 'string'}\n",
+ "#@markdown Install or uninstall git lfs\n",
+ "install_git_lfs = True #@param {'type':'boolean'}\n",
"\n",
"if opt_out == False:\n",
- " !pip install huggingface_hub\n",
- " \n",
- " from huggingface_hub import login\n",
- " login()\n",
- "\n"
+ " %cd /content\n",
+ " username = \"your-huggingface-username\" #@param {'type': 'string'}\n",
+ " datasets_repo = \"your-huggingface-datasets-repo\" #@param {'type': 'string'}\n",
+ " \n",
+ " Repository_url = f\"https://huggingface.co/datasets/{username}/{datasets_repo}\"\n",
+ "\n",
+ " if install_git_lfs:\n",
+ " !git lfs install\n",
+ " else:\n",
+ " !git lfs uninstall\n",
+ "\n",
+ " !git clone {Repository_url}\n",
+ "else:\n",
+ " pass\n"
],
"metadata": {
"cellView": "form",
- "id": "zf7ZJ4f1KXiz"
+ "id": "JBP1EEItSegB"
},
"execution_count": null,
"outputs": []
@@ -255,17 +310,18 @@
{
"cell_type": "code",
"source": [
- "#@title Clone Datasets with Git LFS installed\n",
- "%cd /content\n",
+ "#@title Mount Google Drive\n",
"\n",
- "dataset_url = \"https://huggingface.co/datasets/Linaqruf/herigaru-tag\" #@param {'type': 'string'}\n",
- "!git lfs install\n",
- "!git clone {dataset_url}\n",
- "\n"
+ "from google.colab import drive\n",
+ "\n",
+ "mount_drive = True #@param {'type':'boolean'}\n",
+ "\n",
+ "if mount_drive:\n",
+ " drive.mount('/content/drive')"
],
"metadata": {
"cellView": "form",
- "id": "JBP1EEItSegB"
+ "id": "s1KgygUBm8XL"
},
"execution_count": null,
"outputs": []
@@ -284,13 +340,17 @@
"source": [
"#@title Install Pre-trained Model \n",
"%cd /content/kohya-trainer\n",
- "!mkdir checkpoint\n",
+ "import os\n",
+ "\n",
+ "# Check if directory exists\n",
+ "if not os.path.exists('checkpoint'):\n",
+ " # Create directory if it doesn't exist\n",
+ " os.makedirs('checkpoint')\n",
"\n",
"#@title Install Pre-trained Model \n",
"\n",
"installModels=[]\n",
"\n",
- "\n",
"#@markdown ### Available Model\n",
"#@markdown Select one of available pretrained model to download:\n",
"modelUrl = [\"\", \\\n",
@@ -309,44 +369,52 @@
" \"Anything-V3.0-pruned-fp32\", \\\n",
" \"Anything-V3.0-pruned\", \\\n",
" \"Stable-Diffusion-v1-4\", \\\n",
- " \"Stable-Diffusion-v1-5-pruned-emaonly\" \\\n",
+ " \"Stable-Diffusion-v1-5-pruned-emaonly\", \\\n",
" \"Waifu-Diffusion-v1-3-fp32\"]\n",
- "modelName = \"\" #@param [\"\", \"Animefull-final-pruned\", \"Animesfw-final-pruned\", \"Anything-V3.0-pruned-fp16\", \"Anything-V3.0-pruned-fp32\", \"Anything-V3.0-pruned\", \"Stable-Diffusion-v1-4\", \"Stable-Diffusion-v1-5-pruned-emaonly\", \"Waifu-Diffusion-v1-3-fp32\"]\n",
+ "modelName = \"Animefull-final-pruned\" #@param [\"\", \"Animefull-final-pruned\", \"Animesfw-final-pruned\", \"Anything-V3.0-pruned-fp16\", \"Anything-V3.0-pruned-fp32\", \"Anything-V3.0-pruned\", \"Stable-Diffusion-v1-4\", \"Stable-Diffusion-v1-5-pruned-emaonly\", \"Waifu-Diffusion-v1-3-fp32\"]\n",
"\n",
"#@markdown ### Custom model\n",
"#@markdown The model URL should be a direct download link.\n",
- "customName = \"herigaru5k\" #@param {'type': 'string'}\n",
- "customUrl = \"https://huggingface.co/Linaqruf/herigaru/resolve/main/herigaru5k.ckpt\"#@param {'type': 'string'}\n",
+ "customName = \"\" #@param {'type': 'string'}\n",
+ "customUrl = \"\"#@param {'type': 'string'}\n",
"\n",
- "if customName == \"\" or customUrl == \"\":\n",
- " pass\n",
- "else:\n",
+ "# Check if user has specified a custom model\n",
+ "if customName != \"\" and customUrl != \"\":\n",
+ " # Add custom model to list of models to install\n",
" installModels.append((customName, customUrl))\n",
"\n",
+ "# Check if user has selected a model\n",
"if modelName != \"\":\n",
- " # Map model to URL\n",
+ " # Map selected model to URL\n",
" installModels.append((modelName, modelUrl[modelList.index(modelName)]))\n",
"\n",
"def install_aria():\n",
+ " # Install aria2 if it is not already installed\n",
" if not os.path.exists('/usr/bin/aria2c'):\n",
" !apt install -y -qq aria2\n",
"\n",
"def install(checkpoint_name, url):\n",
" if url.startswith(\"https://drive.google.com\"):\n",
+ " # Use gdown to download file from Google Drive\n",
" !gdown --fuzzy -O \"/content/kohya-trainer/checkpoint/{checkpoint_name}.ckpt\" \"{url}\"\n",
" elif url.startswith(\"magnet:?\"):\n",
" install_aria()\n",
+ " # Use aria2c to download file from magnet link\n",
" !aria2c --summary-interval=10 -c -x 10 -k 1M -s 10 -o /content/kohya-trainer/checkpoint/{checkpoint_name}.ckpt \"{url}\"\n",
" else:\n",
" user_token = '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'\n",
" user_header = f\"\\\"Authorization: Bearer {user_token}\\\"\"\n",
+ " # Use wget to download file from URL\n",
" !wget -c --header={user_header} \"{url}\" -O /content/kohya-trainer/checkpoint/{checkpoint_name}.ckpt\n",
"\n",
"def install_checkpoint():\n",
+ " # Iterate through list of models to install\n",
" for model in installModels:\n",
+ " # Call install function for each model\n",
" install(model[0], model[1])\n",
- "install_checkpoint()\n",
- "\n"
+ "\n",
+ "# Call install_checkpoint function to download all models in the list\n",
+ "install_checkpoint()\n"
],
"metadata": {
"id": "SoucgZQ6jgPQ",
@@ -363,9 +431,12 @@
"\n",
"diffuser_0_7_2 = True #@param {'type':'boolean'}\n",
"\n",
- "if diffuser_0_7_2 == True :\n",
+ "# Check if user wants to downgrade diffusers\n",
+ "if diffuser_0_7_2:\n",
+ " # Install diffusers 0.7.2\n",
" !pip install diffusers[torch]==0.7.2\n",
"else:\n",
+ " # Install latest version of diffusers\n",
" !pip install diffusers[torch]==0.9.0"
],
"metadata": {
@@ -390,11 +461,11 @@
"source": [
"#@title Training begin\n",
"num_cpu_threads_per_process = 8 #@param {'type':'integer'}\n",
- "pre_trained_model_path =\"/content/herigaru/herigaru10k.ckpt\" #@param {'type':'string'}\n",
- "meta_lat_json_dir = \"/content/herigaru-tag/meta_lat.json\" #@param {'type':'string'}\n",
- "train_data_dir = \"/content/herigaru-tag/train_data\" #@param {'type':'string'}\n",
+ "pre_trained_model_path =\"/content/kohya-trainer/checkpoint/Animefull-final-pruned.ckpt\" #@param {'type':'string'}\n",
+ "meta_lat_json_dir = \"/content/kohya-trainer/meta_lat.json\" #@param {'type':'string'}\n",
+ "train_data_dir = \"/content/kohya-trainer/train_data\" #@param {'type':'string'}\n",
"output_dir =\"/content/kohya-trainer/fine_tuned\" #@param {'type':'string'}\n",
- "resume_path = \"/content/herigaru-tag/last-state\" #@param {'type':'string'}\n",
+ "resume_path = \"/content/kohya-trainer/last-state\" #@param {'type':'string'}\n",
"train_batch_size = 1 #@param {type: \"slider\", min: 1, max: 10}\n",
"learning_rate =\"2e-6\" #@param {'type':'string'}\n",
"max_token_length = \"225\" #@param [\"150\", \"225\"] {allow-input: false}\n",
@@ -405,8 +476,7 @@
"save_every_n_epochs = 50 #@param {'type':'integer'}\n",
"gradient_accumulation_steps = 1 #@param {type: \"slider\", min: 1, max: 10}\n",
"\n",
- "\n",
- "%cd /content/kohya-trainer\n",
+ "%cd /content/kohya-trainer/diffuser_fine_tuning\n",
"!accelerate launch --num_cpu_threads_per_process {num_cpu_threads_per_process} fine_tune.py \\\n",
" --pretrained_model_name_or_path={pre_trained_model_path} \\\n",
" --in_json {meta_lat_json_dir} \\\n",
@@ -419,14 +489,14 @@
" --max_token_length={max_token_length} \\\n",
" --clip_skip={clip_skip} \\\n",
" --mixed_precision={mixed_precision} \\\n",
- " --max_train_steps={max_train_steps} \\\n",
+ " --max_train_steps={max_train_steps} \\\n",
" --use_8bit_adam \\\n",
" --xformers \\\n",
" --gradient_checkpointing \\\n",
" --save_state \\\n",
- " --gradient_accumulation_steps {gradient_accumulation_steps}\n",
- " # --save_precision={save_precision}\n",
- " # --resume {resume_path} \\\n"
+ " --gradient_accumulation_steps {gradient_accumulation_steps} \\\n",
+ " --save_precision={save_precision} \\\n",
+ " --resume {resume_path} \n"
],
"metadata": {
"id": "X_Rd3Eh07xlA",
@@ -447,21 +517,32 @@
{
"cell_type": "code",
"source": [
- "#@title Convert diffuser model to ckpt\n",
+ "#@title Convert diffuser model to ckpt (Optional)\n",
"\n",
"#@markdown If you're using diffuser weight, this cell will convert output weight to checkpoint file so it can be used in Web UI like Auto1111's\n",
- "WEIGHTS_DIR = \"/content/drive/MyDrive/fine_tuned/last\" #@param {'type':'string'}\n",
- "#@markdown Run conversion.\n",
- "ckpt_path = WEIGHTS_DIR + \"/model.ckpt\"\n",
- "\n",
- "half_arg = \"\"\n",
- "#@markdown Whether to convert to fp16, takes half the space (2GB).\n",
- "fp16 = False #@param {type: \"boolean\"}\n",
- "if fp16:\n",
- " half_arg = \"--half\"\n",
- "!python convert_diffusers_to_original_stable_diffusion.py --model_path $WEIGHTS_DIR --checkpoint_path $ckpt_path $half_arg\n",
- "\n",
- "print(f\"[*] Converted ckpt saved at {ckpt_path}\")"
+ "\n",
+ "# Use a more descriptive variable name\n",
+ "diffuser_weights_dir = \"/content/drive/MyDrive/fine_tuned/last\" #@param {'type':'string'}\n",
+ "\n",
+ "# Use a more descriptive variable name\n",
+ "use_fp16 = False #@param {type: \"boolean\"}\n",
+ "\n",
+ "# Add a comment to explain what the code is doing\n",
+ "# Convert the diffuser weights to a checkpoint file\n",
+ "ckpt_path = diffuser_weights_dir + \"/model.ckpt\"\n",
+ "\n",
+ "# Use a more descriptive variable name\n",
+ "half_precision_arg = \"\"\n",
+ "if use_fp16:\n",
+ " # Use a more descriptive variable name\n",
+ " half_precision_arg = \"--half\"\n",
+ "\n",
+ "# Add a comment to explain what the code is doing\n",
+ "# Run the conversion script\n",
+ "!python convert_diffusers_to_original_stable_diffusion.py --model_path $diffuser_weights_dir --checkpoint_path $ckpt_path $half_precision_arg\n",
+ "\n",
+ "# Use string formatting and a more descriptive variable name\n",
+ "print(f\"[*] Converted checkpoint saved at {ckpt_path}\")"
],
"metadata": {
"cellView": "form",
@@ -473,8 +554,9 @@
{
"cell_type": "code",
"source": [
- "#@markdown ```\n",
- "#@markdown @lopho\n",
+ "#@title Model Pruner (Optional)\n",
+ "\n",
+ "#@markdown ```python\n",
"#@markdown usage: prune.py [-h] [-p] [-e] [-c] [-a] input output\n",
"#@markdown \n",
"#@markdown Prune a stable diffusion checkpoint\n",
@@ -490,24 +572,32 @@
"#@markdown -c, --no-clip strip CLIP weights\n",
"#@markdown -a, --no-vae strip VAE weights\n",
"#@markdown ```\n",
- "#@title Model Pruner\n",
+ "\n",
"#@markdown Do you want to Prune a model?\n",
"%cd /content/ \n",
"\n",
- "prune = True #@param {'type':'boolean'}\n",
+ "# Use a more descriptive variable name\n",
+ "should_prune = False #@param {'type':'boolean'}\n",
"\n",
- "model_src = \"/content/kohya-trainer/checkpoint/herigaru5k.ckpt\" #@param {'type' : 'string'}\n",
- "model_dst = \"/content/kohya-trainer/checkpoint/herigaru5k.ckpt\" #@param {'type' : 'string'}\n",
+ "# Use a more descriptive variable name\n",
+ "source_model_path = \"/content/kohya-trainer/fine_tuned/last.ckpt\" #@param {'type' : 'string'}\n",
"\n",
- "if prune == True:\n",
+ "# Use a more descriptive variable name\n",
+ "pruned_model_path = \"/content/kohya-trainer/fine_tuned/last-pruned.ckpt\" #@param {'type' : 'string'}\n",
+ "\n",
+ "if should_prune:\n",
" import os\n",
" if os.path.isfile('/content/prune.py'):\n",
" pass\n",
" else:\n",
+ " # Add a comment to explain what the code is doing\n",
+ " # Download the pruning script if it doesn't already exist\n",
" !wget https://raw.githubusercontent.com/lopho/stable-diffusion-prune/main/prune.py\n",
"\n",
"\n",
- "!python3 prune.py {model_src} {model_dst}\n"
+ "# Add a comment to explain what the code is doing\n",
+ "# Run the pruning script\n",
+ "!python3 prune.py {source_model_path} {pruned_model_path}"
],
"metadata": {
"id": "LUOG7BzQVLKp",
@@ -516,23 +606,6 @@
"execution_count": null,
"outputs": []
},
- {
- "cell_type": "code",
- "source": [
- "#@title Mount to Google Drive\n",
- "mount_drive= True #@param {'type':'boolean'}\n",
- "\n",
- "if mount_drive== True:\n",
- " from google.colab import drive\n",
- " drive.mount('/content/drive')"
- ],
- "metadata": {
- "id": "OuRqOSp2eU6t",
- "cellView": "form"
- },
- "execution_count": null,
- "outputs": []
- },
{
"cell_type": "code",
"source": [
@@ -550,175 +623,110 @@
},
{
"cell_type": "markdown",
- "source": [
- "##Commit trained model to Huggingface"
- ],
"metadata": {
"id": "jypUkLWc48R_"
- }
+ },
+ "source": [
+ "## Commit trained model to Huggingface"
+ ]
},
{
"cell_type": "markdown",
- "source": [
- "###Instruction:\n",
- "0. Create huggingface repository for model\n",
- "1. Clone your model to this colab session\n",
- "2. Move these necessary file to your repository to save your trained model to huggingface\n",
- "\n",
- ">in `content/kohya-trainer/fine-tuned`\n",
- "- File `epoch-nnnnn.ckpt` and/or\n",
- "- File `last.ckpt`, \n",
- "\n",
- "4. Commit your model to huggingface"
- ],
"metadata": {
"id": "TvZgRSmKVSRw"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "#@title Clone Model\n",
- "\n",
- "#@markdown Opt-out this cell when run all\n",
- "opt_out= False #@param {'type':'boolean'}\n",
- "\n",
- "if opt_out == False:\n",
- " %cd /content\n",
- " Repository_url = \"https://huggingface.co/Linaqruf/herigaru\" #@param {'type': 'string'}\n",
- " !git lfs uninstall\n",
- " !git clone {Repository_url}\n",
- "else:\n",
- " pass\n"
- ],
- "metadata": {
- "id": "182Law9oUiYN",
- "cellView": "form"
},
- "execution_count": null,
- "outputs": []
+ "source": [
+ "### To Commit models:\n",
+ "1. Create a huggingface repository for your model.\n",
+ "2. Clone your model to this Colab session.\n",
+ "3. Move the necessary files to your repository to save your trained model to huggingface. These files are located in `fine-tuned` folder:\n",
+ " - `epoch-nnnnn.ckpt` and/or\n",
+ " - `last.ckpt`\n",
+ "4. Commit your model to huggingface.\n",
+ "\n",
+ "### To Commit datasets:\n",
+ "1. Create a huggingface repository for your datasets.\n",
+ "2. Clone your datasets to this Colab session.\n",
+ "3. Move the necessary files to your repository so that you can resume training without rebuilding your dataset with this notebook:\n",
+ " - The `train_data` folder.\n",
+ " - The `meta_lat.json` file.\n",
+ " - The `last-state` folder.\n",
+ "4. Commit your datasets to huggingface.\n",
+ "\n"
+ ]
},
{
"cell_type": "code",
- "source": [
- "#@title Commit to Huggingface\n",
- "#@markdown Opt-out this cell when run all\n",
- "opt_out= False #@param {'type':'boolean'}\n",
- "\n",
- "if opt_out == False:\n",
- " %cd /content\n",
- " #@markdown Go to your model path\n",
- " model_path= \"herigaru\" #@param {'type': 'string'}\n",
- "\n",
- " #@markdown Your path look like /content/**model_path**\n",
- " #@markdown ___\n",
- " #@markdown #Git Commit\n",
- "\n",
- " #@markdown Set **git commit identity**\n",
- "\n",
- " email= \"furqanil.taqwa@gmail.com\" #@param {'type': 'string'}\n",
- " name= \"Linaqruf\" #@param {'type': 'string'}\n",
- " #@markdown Set **commit message**\n",
- " commit_m= \"upload 15k model\" #@param {'type': 'string'}\n",
- "\n",
- " %cd \"/content/{model_path}\"\n",
- " !git lfs install\n",
- " !huggingface-cli lfs-enable-largefiles .\n",
- " !git add .\n",
- " !git lfs help smudge\n",
- " !git config --global user.email \"{email}\"\n",
- " !git config --global user.name \"{name}\"\n",
- " !git commit -m \"{commit_m}\"\n",
- " !git push\n",
- "\n",
- "else:\n",
- " pass"
- ],
- "metadata": {
- "id": "87wG7QIZbtZE",
- "cellView": "form"
- },
"execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "source": [
- "##Commit dataset to huggingface"
- ],
"metadata": {
- "id": "olP2yaK3OKcr"
- }
- },
- {
- "cell_type": "markdown",
+ "cellView": "form",
+ "id": "182Law9oUiYN"
+ },
+ "outputs": [],
"source": [
- "###Instruction:\n",
- "0. Create huggingface repository for datasets\n",
- "1. Clone your datasets to this colab session\n",
- "2. Move these necessary file to your repository so that you can do resume training next time without rebuild your dataset with this notebook\n",
+ "#@title Clone Model or Datasets\n",
"\n",
- ">in `content/kohya-trainer`\n",
- "- Folder `train_data`\n",
- "- File `meta_cap_dd.json`\n",
- "- File `meta_lat.json`\n",
+ "#@markdown Opt-out this cell when run all\n",
+ "opt_out = True #@param {'type':'boolean'}\n",
"\n",
- ">in `content/kohya-trainer/fine-tuned`\n",
- "- Folder `last-state`\n",
+ "#@markdown Type of item to clone (model or dataset)\n",
+ "type_of_item = \"model\" #@param [\"model\", \"dataset\"]\n",
"\n",
- "4. Commit your datasets to huggingface"
- ],
- "metadata": {
- "id": "jiSb0z2CVtc_"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "#@title Clone Dataset\n",
- "#@markdown Opt-out this cell when run all\n",
- "opt_out= True #@param {'type':'boolean'}\n",
+ "#@markdown Install or uninstall git lfs\n",
+ "install_git_lfs = False #@param {'type':'boolean'}\n",
"\n",
"if opt_out == False:\n",
" %cd /content\n",
- " Repository_url = \"https://huggingface.co/datasets/Linaqruf/herigaru-tag\" #@param {'type': 'string'}\n",
- " !git lfs uninstall\n",
+ " username = \"your-huggingface-username\" #@param {'type': 'string'}\n",
+ " model_repo = \"your-huggingface-model-repo\" #@param {'type': 'string'}\n",
+ " datasets_repo = \"your-huggingface-datasets-repo\" #@param {'type': 'string'}\n",
+ " \n",
+ " if type_of_item == \"model\":\n",
+ " Repository_url = f\"https://huggingface.co/{username}/{model_repo}\"\n",
+ " elif type_of_item == \"dataset\":\n",
+ " Repository_url = f\"https://huggingface.co/datasets/{username}/{datasets_repo}\"\n",
+ "\n",
+ " if install_git_lfs:\n",
+ " !git lfs install\n",
+ " else:\n",
+ " !git lfs uninstall\n",
+ "\n",
" !git clone {Repository_url}\n",
"else:\n",
" pass\n"
- ],
- "metadata": {
- "id": "QhL6UgqDOURK",
- "cellView": "form"
- },
- "execution_count": null,
- "outputs": []
+ ]
},
{
"cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "cellView": "form",
+ "id": "87wG7QIZbtZE"
+ },
+ "outputs": [],
"source": [
- "#@title Commit to Huggingface\n",
+ "#@title Commit Model or Datasets to Huggingface\n",
+ "\n",
"#@markdown Opt-out this cell when run all\n",
+ "opt_out = True #@param {'type':'boolean'}\n",
"\n",
- "opt_out= False #@param {'type':'boolean'}\n",
+ "#@markdown Type of item to commit (model or dataset)\n",
+ "type_of_item = \"model\" #@param [\"model\", \"dataset\"]\n",
"\n",
"if opt_out == False:\n",
" %cd /content\n",
- " #@markdown Go to your model path\n",
- " dataset_path= \"herigaru-tag\" #@param {'type': 'string'}\n",
+ " #@markdown Go to your model or dataset path\n",
+ " item_path = \"your-cloned-model-or-datasets-repo\" #@param {'type': 'string'}\n",
"\n",
- " #@markdown Your path look like /content/**dataset_path**\n",
- " #@markdown ___\n",
" #@markdown #Git Commit\n",
"\n",
" #@markdown Set **git commit identity**\n",
- "\n",
- " email= \"furqanil.taqwa@gmail.com\" #@param {'type': 'string'}\n",
- " name= \"Linaqruf\" #@param {'type': 'string'}\n",
+ " email = \"your-email\" #@param {'type': 'string'}\n",
+ " name = \"your-username\" #@param {'type': 'string'}\n",
" #@markdown Set **commit message**\n",
- " commit_m= \"upload 15k state\" #@param {'type': 'string'}\n",
+ " commit_m = \"feat: upload 6 epochs model\" #@param {'type': 'string'}\n",
"\n",
- " %cd \"/content/{dataset_path}\"\n",
+ " %cd {item_path}\n",
" !git lfs install\n",
" !huggingface-cli lfs-enable-largefiles .\n",
" !git add .\n",
@@ -730,13 +738,7 @@
"\n",
"else:\n",
" pass"
- ],
- "metadata": {
- "id": "abHLg4I0Os5T",
- "cellView": "form"
- },
- "execution_count": null,
- "outputs": []
+ ]
}
]
}
\ No newline at end of file
diff --git a/kohya-trainer.ipynb b/kohya-trainer.ipynb
index 8ee9fff9..49e11ed0 100644
--- a/kohya-trainer.ipynb
+++ b/kohya-trainer.ipynb
@@ -1,21 +1,4 @@
{
- "nbformat": 4,
- "nbformat_minor": 0,
- "metadata": {
- "colab": {
- "provenance": [],
- "include_colab_link": true
- },
- "kernelspec": {
- "name": "python3",
- "display_name": "Python 3"
- },
- "language_info": {
- "name": "python"
- },
- "gpuClass": "standard",
- "accelerator": "GPU"
- },
"cells": [
{
"cell_type": "markdown",
@@ -24,184 +7,196 @@
"colab_type": "text"
},
"source": [
- ""
+ ""
]
},
{
"cell_type": "markdown",
- "source": [
- "#Kohya Trainer V4 - VRAM 12GB\n",
- "###Best way to fine-tune Stable Diffusion model for peeps who didn't have good GPU"
- ],
"metadata": {
"id": "slgjeYgd6pWp"
- }
+ },
+ "source": [
+ "# Kohya Trainer V6 - VRAM 12GB\n",
+ "### The Best Way for People Without Good GPUs to Fine-Tune the Stable Diffusion Model"
+ ]
},
{
"cell_type": "markdown",
- "source": [
- "Adapted to Google Colab based on [Kohya Guide](https://note.com/kohya_ss/n/nbf7ce8d80f29#c9d7ee61-5779-4436-b4e6-9053741c46bb)
\n",
- "Adapted to Google Colab by [Linaqruf](https://github.com/Linaqruf)
\n",
- "You can find latest notebook update [here](https://github.com/Linaqruf/kohya-trainer/blob/main/kohya-trainer.ipynb)\n",
- "\n",
- "\n"
- ],
"metadata": {
"id": "gPgBR3KM6E-Z"
- }
- },
- {
- "cell_type": "markdown",
+ },
"source": [
- "##Deprecated cell:\n",
- "- Move trained model to cloned repository\n",
- "- Move datasets to cloned repository\n",
- "\n",
- "##What's Changes?\n",
- "- Moved description to [README.md](https://github.com/Linaqruf/kohya-trainer/)\n",
- "- Update `xformers` precompiled wheel to `0.0.14.dev0-cp38` \n",
- "- Update `Diffusers v0.9.0`\n",
- "- Update fine-tuning script to V4, now support SD 2.0 fine-tuning and load diffuser model as pre-trained model\n",
- "- You can choose which script version you want to use\n",
- "- Added option to install `Python 3.9.6`\n",
- "- `gallery-dl` now support Gelbooru scraping\n",
- "- Added datasets cleaner cell to automatically remove unnecessary extension in `train_data` folder\n",
- "- Added emergency downgrade cell to `Diffusers v0.7.2` if you're facing issue like high ram usage, note that you can't do SD2.0 training in `v0.7.2`\n",
- "- Added option to convert `diffuser` model to `ckpt`\n",
- "- Changed model pruner script with `prune.py` by [lopho](https://github.com/lopho/stable-diffusion-prune)\n",
- "\n"
- ],
- "metadata": {
- "id": "xB9gJk4ywuao"
- }
+ "This notebook has been adapted for use in Google Colab based on the [Kohya Guide](https://note.com/kohya_ss/n/nbf7ce8d80f29#c9d7ee61-5779-4436-b4e6-9053741c46bb). \n",
+ "This notebook was adapted by [Linaqruf](https://github.com/Linaqruf)\n",
+ "You can find the latest update to the notebook [here](https://github.com/Linaqruf/kohya-trainer/blob/main/kohya-trainer.ipynb).\n"
+ ]
},
{
"cell_type": "markdown",
- "source": [
- "#Install Kohya Trainer"
- ],
"metadata": {
"id": "tTVqCAgSmie4"
- }
+ },
+ "source": [
+ "# Install Kohya Trainer"
+ ]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
- "id": "_u3q60di584x",
- "cellView": "form"
+ "cellView": "form",
+ "id": "_u3q60di584x"
},
"outputs": [],
"source": [
"#@title Clone Kohya Trainer\n",
- "#@markdown Run this cell everytime you want to `!git pull` to get a lot of new optimizations and updates.\n",
+ "#@markdown Clone the Kohya Trainer repository from GitHub and check for updates\n",
+ "\n",
"%cd /content/\n",
"\n",
"import os\n",
"\n",
- "if os.path.isdir('/content/kohya-trainer'):\n",
- " %cd /content/kohya-trainer\n",
- " print(\"This folder already exists, will do a !git pull instead\\n\")\n",
- " !git pull\n",
- " \n",
- "else:\n",
- " !git clone https://github.com/Linaqruf/kohya-trainer"
+ "def clone_kohya_trainer():\n",
+ " # Check if the directory already exists\n",
+ " if os.path.isdir('/content/kohya-trainer'):\n",
+ " %cd /content/kohya-trainer\n",
+ " print(\"This folder already exists, will do a !git pull instead\\n\")\n",
+ " !git pull\n",
+ " else:\n",
+ " !git clone https://github.com/Linaqruf/kohya-trainer\n",
+ " \n",
+ "\n",
+ "# Clone or update the Kohya Trainer repository\n",
+ "clone_kohya_trainer()"
]
},
{
"cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "cellView": "form",
+ "id": "nj8fNQZNESyT"
+ },
+ "outputs": [],
"source": [
"#@title Install Diffuser Fine Tuning\n",
+ "\n",
+ "# Change the current working directory to \"/content/kohya-trainer\".\n",
"%cd /content/kohya-trainer\n",
"\n",
+ "# Import `shutil` and `os` modules.\n",
"import shutil\n",
"import os\n",
"\n",
- "customVersion = []\n",
- "versionDir = [\"/content/kohya-trainer/diffuser_fine_tuning/diffusers_fine_tuning_v1.zip\", \\\n",
- " \"/content/kohya-trainer/diffuser_fine_tuning/diffusers_fine_tuning_v2.zip\", \\\n",
- " \"/content/kohya-trainer/diffuser_fine_tuning/diffusers_fine_tuning_v3.zip\", \\\n",
- " \"/content/kohya-trainer/diffuser_fine_tuning/diffusers_fine_tuning_v4.zip\"]\n",
- "versionList = [\"diffusers_fine_tuning_v1\", \\\n",
- " \"diffusers_fine_tuning_v2\", \\\n",
+ "# Initialize an empty list `custom_versions`.\n",
+ "custom_versions = []\n",
+ "\n",
+ "# Initialize a list `version_urls` containing URLs of different versions of the `diffusers_fine_tuning` file.\n",
+ "version_urls = [\"\",\\\n",
+ " \"https://github.com/Linaqruf/kohya-trainer/releases/download/v6/diffusers_fine_tuning_v6.zip\", \\\n",
+ " \"https://github.com/Linaqruf/kohya-trainer/releases/download/v5/diffusers_fine_tuning_v5.zip\", \\\n",
+ " \"https://github.com/Linaqruf/kohya-trainer/releases/download/v4/diffusers_fine_tuning_v4.zip\", \\\n",
+ " \"https://github.com/Linaqruf/kohya-trainer/releases/download/v3/diffusers_fine_tuning_v3.zip\", \\\n",
+ " \"https://github.com/Linaqruf/kohya-trainer/releases/download/v2/diffusers_fine_tuning_v2.zip\", \\\n",
+ " \"https://github.com/Linaqruf/kohya-trainer/releases/download/v1/diffusers_fine_tuning_v1.zip\"]\n",
+ "\n",
+ "# Initialize a list `version_names` containing names of different versions of the `diffusers_fine_tuning` file.\n",
+ "version_names = [\"latest_version\", \\\n",
+ " \"diffusers_fine_tuning_v6\", \\\n",
+ " \"diffusers_fine_tuning_v5\", \\\n",
+ " \"diffusers_fine_tuning_v4\", \\\n",
" \"diffusers_fine_tuning_v3\", \\\n",
- " \"diffusers_fine_tuning_v4\"]\n",
- "version = \"diffusers_fine_tuning_v4\" #@param [\"diffusers_fine_tuning_v1\",\"diffusers_fine_tuning_v2\",\"diffusers_fine_tuning_v3\",\"diffusers_fine_tuning_v4\"]\n",
- "\n",
- "customVersion.append((versionDir[versionList.index(version)]))\n",
- "\n",
- "for zip in customVersion:\n",
- " (zip[0])\n",
- "\n",
- "zip = \"\".join(zip)\n",
- "\n",
- "def unzip_function(dir):\n",
- " !unzip {dir} -d /content/kohya-trainer/\n",
- "\n",
- "def unzip_version():\n",
- " unzip_function(zip)\n",
+ " \"diffusers_fine_tuning_v2\", \\\n",
+ " \"diffusers_fine_tuning_v1\"]\n",
+ "\n",
+ "# Initialize a variable `selected_version` to the selected version of the `diffusers_fine_tuning` file.\n",
+ "selected_version = \"latest_version\" #@param [\"latest_version\", \"diffusers_fine_tuning_v6\", \"diffusers_fine_tuning_v5\", \"diffusers_fine_tuning_v4\", \"diffusers_fine_tuning_v3\", \"diffusers_fine_tuning_v2\", \"diffusers_fine_tuning_v1\"]\n",
+ "\n",
+ "# Append a tuple to `custom_versions`, containing `selected_version` and the corresponding item\n",
+ "# in `version_urls`.\n",
+ "custom_versions.append((selected_version, version_urls[version_names.index(selected_version)]))\n",
+ "\n",
+ "# Define `download` function to download a file from the given URL and save it with\n",
+ "# the given name.\n",
+ "def download(name, url):\n",
+ " !wget -c \"{url}\" -O /content/{name}.zip\n",
+ "\n",
+ "# Define `unzip` function to unzip a file with the given name to a specified\n",
+ "# directory.\n",
+ "def unzip(name):\n",
+ " !unzip /content/{name}.zip -d /content/kohya-trainer/diffuser_fine_tuning\n",
+ "\n",
+ "# Define `download_version` function to download and unzip a file from `custom_versions`,\n",
+ "# unless `selected_version` is \"latest_version\".\n",
+ "def download_version():\n",
+ " if selected_version != \"latest_version\":\n",
+ " for zip in custom_versions:\n",
+ " download(zip[0], zip[1])\n",
+ "\n",
+ " # Rename the existing `diffuser_fine_tuning` directory to the `tmp` directory and delete any existing `tmp` directory.\n",
+ " if os.path.exists(\"/content/kohya-trainer/tmp\"):\n",
+ " shutil.rmtree(\"/content/kohya-trainer/tmp\")\n",
+ " os.rename(\"/content/kohya-trainer/diffuser_fine_tuning\", \"/content/kohya-trainer/tmp\")\n",
+ "\n",
+ " # Create a new empty `diffuser_fine_tuning` directory.\n",
+ " os.makedirs(\"/content/kohya-trainer/diffuser_fine_tuning\")\n",
+ " \n",
+ " # Unzip the downloaded file to the new `diffuser_fine_tuning` directory.\n",
+ " unzip(zip[0])\n",
+ " \n",
+ " # Delete the downloaded and unzipped file.\n",
+ " os.remove(\"/content/{}.zip\".format(zip[0]))\n",
+ " \n",
+ " # Inform the user that the existing `diffuser_fine_tuning` directory has been renamed to the `tmp` directory\n",
+ " # and a new empty `diffuser_fine_tuning` directory has been created.\n",
+ " print(\"Renamed existing 'diffuser_fine_tuning' directory to 'tmp' directory and created new empty 'diffuser_fine_tuning' directory.\")\n",
+ " else:\n",
+ " # Do nothing if `selected_version` is \"latest_version\".\n",
+ " pass\n",
"\n",
- "unzip_version()\n",
- "# if version == \"diffusers_fine_tuning_v1\":\n",
- "# !unzip /content/kohya-trainer/diffuser_fine_tuning/diffusers_fine_tuning_v1.zip -d /content/kohya-trainer/"
- ],
- "metadata": {
- "cellView": "form",
- "id": "nj8fNQZNESyT"
- },
- "execution_count": null,
- "outputs": []
+ "# Call `download_version` function.\n",
+ "download_version()"
+ ]
},
{
"cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "cellView": "form",
+ "id": "WNn0g1pnHfk5"
+ },
+ "outputs": [],
"source": [
"#@title Installing Dependencies\n",
"%cd /content/kohya-trainer\n",
"\n",
- "Install_Python_3_9_6 = False #@param{'type':'boolean'}\n",
+ "def install_dependencies():\n",
+ " #@markdown Install required Python packages\n",
+ " !pip install --upgrade -r script/requirements.txt\n",
+ " !pip install -U gallery-dl\n",
+ " !pip install tensorflow\n",
+ " !pip install huggingface_hub\n",
"\n",
- "if Install_Python_3_9_6 == True:\n",
- " #install python 3.9\n",
- " !sudo apt-get update -y\n",
- " !sudo apt-get install python3.9\n",
+ " # Install xformers\n",
+ " !pip install -U -I --no-deps https://github.com/camenduru/stable-diffusion-webui-colab/releases/download/0.0.15/xformers-0.0.15.dev0+189828c.d20221207-cp38-cp38-linux_x86_64.whl\n",
"\n",
- " #change alternatives\n",
- " !sudo update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.8 1\n",
- " !sudo update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.9 2\n",
"\n",
- " #check python version\n",
- " !python --version\n",
- " #3.9.6\n",
- " !sudo apt-get install python3.9-distutils && wget https://bootstrap.pypa.io/get-pip.py && python get-pip.py\n",
- "\n",
- "!wget -q https://github.com/ShivamShrirao/diffusers/raw/main/scripts/convert_diffusers_to_original_stable_diffusion.py\n",
- "\n",
- "if os.path.isfile('/content/kohya-trainer/convert_diffusers_to_original_stable_diffusion.py'):\n",
- " pass\n",
- "else:\n",
+ "# Install convert_diffusers_to_original_stable_diffusion.py script\n",
+ "if not os.path.isfile('/content/kohya-trainer/convert_diffusers_to_original_stable_diffusion.py'):\n",
" !wget -q https://github.com/ShivamShrirao/diffusers/raw/main/scripts/convert_diffusers_to_original_stable_diffusion.py\n",
"\n",
- "!pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113\n",
- "!pip install --upgrade -r script/requirements.txt\n",
- "!pip install -U gallery-dl\n",
- "!pip install tensorflow\n",
- "!pip install accelerate==0.14.0\n",
- "\n",
- "#install xformers\n",
- "if Install_Python_3_9_6 == True:\n",
- " !pip install -U -I --no-deps https://github.com/daswer123/stable-diffusion-colab/raw/main/xformers%20prebuild/T4/python39/xformers-0.0.14.dev0-cp39-cp39-linux_x86_64.whl\n",
- "else:\n",
- " !pip install -U -I --no-deps https://github.com/TheLastBen/fast-stable-diffusion/raw/main/precompiled/T4/xformers-0.0.13.dev0-py3-none-any.whl\n"
- ],
- "metadata": {
- "id": "WNn0g1pnHfk5",
- "cellView": "form"
- },
- "execution_count": null,
- "outputs": []
+ "# Install dependencies\n",
+ "install_dependencies()"
+ ]
},
{
"cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "cellView": "form",
+ "id": "VZOXwDv3utpx"
+ },
+ "outputs": [],
"source": [
"#@title Set config for `!Accelerate`\n",
"#@markdown #Hint\n",
@@ -215,113 +210,183 @@
"%cd /content/kohya-trainer\n",
"\n",
"!accelerate config"
- ],
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "M0fzmhtywk_u"
+ },
+ "source": [
+ "# Prepare Cloud Storage (Huggingface/GDrive)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
"metadata": {
"cellView": "form",
- "id": "VZOXwDv3utpx"
+ "id": "cwIJdhEcwk_u"
},
+ "outputs": [],
+ "source": [
+ "#@title Login to Huggingface hub\n",
+ "\n",
+ "#@markdown ## Instructions:\n",
+ "#@markdown 1. Of course, you need a Huggingface account first.\n",
+ "#@markdown 2. To create a huggingface token, go to `Profile > Access Tokens > New Token > Create a new access token` with the `Write` role.\n",
+ "#@markdown 3. By default, all cells below are marked as `opt-out`, so you need to uncheck them if you want to run the cells.\n",
+ "\n",
+ "%cd /content/kohya-trainer\n",
+ "\n",
+ "from huggingface_hub import login\n",
+ "login()\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "code",
"execution_count": null,
- "outputs": []
+ "metadata": {
+ "cellView": "form",
+ "id": "jVgHUUK_wk_v"
+ },
+ "outputs": [],
+ "source": [
+ "#@title Mount Google Drive\n",
+ "\n",
+ "from google.colab import drive\n",
+ "\n",
+ "mount_drive = True #@param {'type':'boolean'}\n",
+ "\n",
+ "if mount_drive:\n",
+ " drive.mount('/content/drive')"
+ ]
},
{
"cell_type": "markdown",
+ "metadata": {
+ "id": "En9UUwGNMRMM"
+ },
+ "source": [
+ "# Collecting datasets\n",
+ "\n",
+ "You can either upload your datasets to this notebook or use the image scraper below to bulk download images from Danbooru.\n",
+ "\n",
+ "If you want to use your own datasets, you can upload to colab `local files`.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
"source": [
- "#Collecting datasets\n",
- "You can either upload your datasets to this notebook or use image scraper below to bulk download images from danbooru.\n",
+ "#@title Define Train Data\n",
+ "#@markdown Define where your train data will be located. This cell will also create a folder based on your input. \n",
+ "#@markdown This folder will be used as the target folder for scraping, tagging, bucketing, and training in the next cell.\n",
+ "\n",
+ "import os\n",
"\n",
- "If you want to use your own datasets, make sure to put them in a folder titled `train_data` in `content/kohya-trainer`\n",
+ "train_data_dir = \"/content/kohya-trainer/train_data\" #@param {'type' : 'string'}\n",
"\n",
- "This is to make the training process easier because the folder that will be used for training is in `content/kohya-trainer/train-data`."
+ "if not os.path.exists(train_data_dir):\n",
+ " os.makedirs(train_data_dir)\n",
+ "else:\n",
+ " print(f\"{train_data_dir} already exists\\n\")\n",
+ "\n",
+ "print(f\"Your train data directory : {train_data_dir}\")\n"
],
"metadata": {
- "id": "En9UUwGNMRMM"
- }
+ "cellView": "form",
+ "id": "nXNk0NOwzWw4"
+ },
+ "execution_count": null,
+ "outputs": []
},
{
"cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "cellView": "form",
+ "id": "Kt1GzntK_apb"
+ },
+ "outputs": [],
"source": [
"#@title Booru Scraper\n",
- "#@markdown **How this work?**\n",
+ "#@markdown Use gallery-dl to scrape images from a booru site using the specified tags\n",
"\n",
- "#@markdown By using **gallery-dl** we can scrap or bulk download images on Internet, on this notebook we will scrap images from popular booru sites using tag1 and tag2 as target scraping.\n",
"%cd /content\n",
"\n",
+ "# Set configuration options\n",
"booru = \"Danbooru\" #@param [\"\", \"Danbooru\", \"Gelbooru\"]\n",
- "tag = \"herigaru_(fvgyvr000)\" #@param {type: \"string\"}\n",
+ "tag1 = \"hito_komoru\" #@param {type: \"string\"}\n",
"tag2 = \"\" #@param {type: \"string\"}\n",
"\n",
+ "# Construct the search query\n",
"if tag2 != \"\":\n",
- " tag = tag + \"+\" + tag2\n",
+ " tags = tag1 + \"+\" + tag2\n",
"else:\n",
- " tag = tag\n",
+ " tags = tag1\n",
"\n",
- "output_dir = \"/content/kohya-trainer/train_data\"\n",
- "\n",
- "if booru == \"Danbooru\":\n",
- " !gallery-dl \"https://danbooru.donmai.us/posts?tags={tag}\" -D {output_dir}\n",
- "elif booru == \"Gelbooru\":\n",
- " !gallery-dl \"https://gelbooru.com/index.php?page=post&s=list&tags={tag}\" -D {output_dir}\n",
+ "# Scrape images from the specified booru site using the given tags\n",
+ "if booru.lower() == \"danbooru\":\n",
+ " !gallery-dl \"https://danbooru.donmai.us/posts?tags={tags}\" -D {train_data_dir}\n",
+ "elif booru.lower() == \"gelbooru\":\n",
+ " !gallery-dl \"https://gelbooru.com/index.php?page=post&s=list&tags={tags}\" -D {train_data_dir}\n",
"else:\n",
- " pass\n",
- "\n",
- "\n",
- "#@markdown The output directory will be on /content/kohya-trainer/train_data. We also will use this folder as target folder for training next step.\n"
- ],
- "metadata": {
- "id": "Kt1GzntK_apb",
- "cellView": "form"
- },
- "execution_count": null,
- "outputs": []
+ " print(f\"Unknown booru site: {booru}\")\n"
+ ]
},
{
"cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "cellView": "form",
+ "id": "Jz2emq6vWnPu"
+ },
+ "outputs": [],
"source": [
"#@title Datasets cleaner\n",
- "#@markdown This will delete unnecessary file and unsupported media like `.bin`, `.mp4`, `.webm`, and `.gif`\n",
- "import os\n",
+ "#@markdown This will delete unnecessary files and unsupported media like `.mp4`, `.webm`, and `.gif`\n",
"\n",
- "dir_name = \"/content/kohya-trainer/train_data\" #@param {'type' : 'string'}\n",
- "test = os.listdir(dir_name)\n",
+ "%cd /content\n",
"\n",
- "for item in test:\n",
- " if item.endswith(\".mp4\"):\n",
- " os.remove(os.path.join(dir_name, item))\n",
+ "import os\n",
+ "test = os.listdir(train_data_dir)\n",
"\n",
- "for item in test:\n",
- " if item.endswith(\".webm\"):\n",
- " os.remove(os.path.join(dir_name, item))\n",
+ "# List of supported file types\n",
+ "supported_types = [\".jpg\", \".jpeg\", \".png\"]\n",
"\n",
+ "# Iterate over all files in the directory\n",
"for item in test:\n",
- " if item.endswith(\".gif\"):\n",
- " os.remove(os.path.join(dir_name, item))\n",
- " \n",
- "for item in test:\n",
- " if item.endswith(\".webp\"):\n",
- " os.remove(os.path.join(dir_name, item))\n",
- "\n"
- ],
- "metadata": {
- "cellView": "form",
- "id": "Jz2emq6vWnPu"
- },
- "execution_count": null,
- "outputs": []
+ " # Extract the file extension from the file name\n",
+ " file_ext = os.path.splitext(item)[1]\n",
+ " # If the file extension is not in the list of supported types, delete the file\n",
+ " if file_ext not in supported_types:\n",
+ " # Print a message indicating the name of the file being deleted\n",
+ " print(f\"Deleting file {item} from {train_data_dir}\")\n",
+ " # Delete the file\n",
+ " os.remove(os.path.join(train_data_dir, item))\n"
+ ]
},
{
"cell_type": "markdown",
- "source": [
- "#`(NEW)` Waifu Diffusion 1.4 Autotagger"
- ],
"metadata": {
"id": "SoPUJaTpTusz"
- }
+ },
+ "source": [
+ "# `(NEW)` Waifu Diffusion 1.4 Autotagger"
+ ]
},
{
"cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "cellView": "form",
+ "id": "WDSlAEHzT2Im"
+ },
+ "outputs": [],
"source": [
"#@title Download Weight\n",
- "%cd /content/kohya-trainer/\n",
+ "%cd /content/kohya-trainer\n",
"\n",
"import os\n",
"import shutil\n",
@@ -332,89 +397,130 @@
" !wget -c --header={user_header} {url} -O /content/kohya-trainer/wd14tagger-weight/{weight}\n",
"\n",
"def download_weight():\n",
- " !mkdir /content/kohya-trainer/wd14tagger-weight/\n",
- " huggingface_dl(\"https://huggingface.co/Linaqruf/personal_backup/resolve/main/wd14tagger-weight/wd14Tagger.zip\", \"wd14Tagger.zip\")\n",
+ " # Remove the weight directory if it exists\n",
+ " weight_dir = '/content/kohya-trainer/wd14tagger-weight/'\n",
+ " if os.path.exists(weight_dir):\n",
+ " shutil.rmtree(weight_dir)\n",
+ "\n",
+ " # Create the weight directory\n",
+ " os.mkdir(weight_dir)\n",
+ "\n",
+ " # Download the weight file from the specified URL\n",
+ " weight_url = \"https://huggingface.co/Linaqruf/personal_backup/resolve/main/wd14tagger-weight/wd14Tagger.zip\"\n",
+ " huggingface_dl(weight_url, \"wd14Tagger.zip\")\n",
" \n",
+ " # Extract the weight file from the zip archive\n",
" !unzip /content/kohya-trainer/wd14tagger-weight/wd14Tagger.zip -d /content/kohya-trainer/wd14tagger-weight\n",
"\n",
- " # Destination path \n",
- " destination = '/content/kohya-trainer/wd14tagger-weight'\n",
+ " # Move the weight file to the weight directory\n",
+ " shutil.move(\"script/tag_images_by_wd14_tagger.py\", weight_dir)\n",
"\n",
- " if os.path.isfile('/content/kohya-trainer/tag_images_by_wd14_tagger.py'):\n",
- " # Move the content of \n",
- " # source to destination \n",
- " shutil.move(\"tag_images_by_wd14_tagger.py\", destination) \n",
- " else:\n",
- " pass\n",
+ " # Delete the zip file after it has been extracted\n",
+ " os.remove('/content/kohya-trainer/wd14tagger-weight/wd14Tagger.zip')\n",
"\n",
- "download_weight()"
- ],
- "metadata": {
- "cellView": "form",
- "id": "WDSlAEHzT2Im"
- },
- "execution_count": null,
- "outputs": []
+ "download_weight()\n"
+ ]
},
{
"cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "cellView": "form",
+ "id": "hibZK5NPTjZQ"
+ },
+ "outputs": [],
"source": [
"#@title Start Autotagger\n",
+ "\n",
+ "# Change the working directory to the weight directory\n",
"%cd /content/kohya-trainer/wd14tagger-weight\n",
- "!python tag_images_by_wd14_tagger.py --batch_size 4 /content/kohya-trainer/train_data\n",
"\n",
- "#@markdown Args list:\n",
- "#@markdown - `--train_data_dir` : directory for training images\n",
+ "#@markdown ### Command-line Arguments\n",
+ "#@markdown The following command-line arguments are available:\n",
+ "#@markdown - `train_data_dir` : directory for training images\n",
"#@markdown - `--model` : model path to load\n",
"#@markdown - `--tag_csv` : csv file for tag\n",
"#@markdown - `--thresh` : threshold of confidence to add a tag\n",
"#@markdown - `--batch_size` : batch size in inference\n",
- "#@markdown - `--model` : model path to load\n",
"#@markdown - `--caption_extension` : extension of caption file\n",
- "#@markdown - `--debug` : debug mode\n"
- ],
- "metadata": {
- "id": "hibZK5NPTjZQ",
- "cellView": "form"
- },
- "execution_count": null,
- "outputs": []
+ "#@markdown - `--debug` : debug mode\n",
+ "\n",
+ "#@markdown ### Define parameter:\n",
+ "batch_size = 4 #@param {'type':'integer'}\n",
+ "caption_extension = \".txt\" #@param [\".txt\",\".caption\"]\n",
+ "\n",
+ "!python tag_images_by_wd14_tagger.py \\\n",
+ " {train_data_dir} \\\n",
+ " --batch_size {batch_size} \\\n",
+ " --caption_extension .txt\n",
+ "\n",
+ "\n"
+ ]
},
{
"cell_type": "code",
- "source": [
- "#@title Create Metadata.json\n",
- "%cd /content/kohya-trainer\n",
- "!python merge_dd_tags_to_metadata.py train_data meta_cap_dd.json"
- ],
+ "execution_count": null,
"metadata": {
- "id": "hz2Cmlf2ay9w",
- "cellView": "form"
+ "cellView": "form",
+ "id": "hz2Cmlf2ay9w"
},
- "execution_count": null,
- "outputs": []
+ "outputs": [],
+ "source": [
+ "#@title Create Metadata.json\n",
+ "\n",
+ "\n",
+ "# Change the working directory\n",
+ "%cd /content/kohya-trainer/diffuser_fine_tuning\n",
+ "\n",
+ "#@markdown ### Command-line Arguments\n",
+ "#@markdown The following command-line arguments are available:\n",
+ "#@markdown - `train_data_dir` : directory for training images\n",
+ "#@markdown - `out_json` : model path to load\n",
+ "#@markdown - `--in_json` : metadata file to input\n",
+ "#@markdown - `--debug` : debug mode\n",
+ "\n",
+ "#@markdown ### Define Parameter :\n",
+ "out_json = \"/content/kohya-trainer/meta_cap_dd.json\" #@param {'type':'string'}\n",
+ "\n",
+ "# Create the metadata file\n",
+ "!python merge_dd_tags_to_metadata.py \\\n",
+ " {train_data_dir} \\\n",
+ " {out_json}\n",
+ "\n",
+ "\n"
+ ]
},
{
"cell_type": "markdown",
- "source": [
- "#Prepare Training"
- ],
"metadata": {
"id": "3gob9_OwTlwh"
- }
+ },
+ "source": [
+ "# Prepare Training"
+ ]
},
{
"cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "cellView": "form",
+ "id": "SoucgZQ6jgPQ"
+ },
+ "outputs": [],
"source": [
"#@title Install Pre-trained Model \n",
"%cd /content/kohya-trainer\n",
- "!mkdir checkpoint\n",
+ "import os\n",
+ "\n",
+ "# Check if directory exists\n",
+ "if not os.path.exists('checkpoint'):\n",
+ " # Create directory if it doesn't exist\n",
+ " os.makedirs('checkpoint')\n",
"\n",
"#@title Install Pre-trained Model \n",
"\n",
"installModels=[]\n",
"\n",
- "\n",
"#@markdown ### Available Model\n",
"#@markdown Select one of available pretrained model to download:\n",
"modelUrl = [\"\", \\\n",
@@ -433,195 +539,244 @@
" \"Anything-V3.0-pruned-fp32\", \\\n",
" \"Anything-V3.0-pruned\", \\\n",
" \"Stable-Diffusion-v1-4\", \\\n",
- " \"Stable-Diffusion-v1-5-pruned-emaonly\" \\\n",
+ " \"Stable-Diffusion-v1-5-pruned-emaonly\", \\\n",
" \"Waifu-Diffusion-v1-3-fp32\"]\n",
- "modelName = \"Anything-V3.0-pruned\" #@param [\"\", \"Animefull-final-pruned\", \"Animesfw-final-pruned\", \"Anything-V3.0-pruned-fp16\", \"Anything-V3.0-pruned-fp32\", \"Anything-V3.0-pruned\", \"Stable-Diffusion-v1-4\", \"Stable-Diffusion-v1-5-pruned-emaonly\", \"Waifu-Diffusion-v1-3-fp32\"]\n",
+ "modelName = \"Animefull-final-pruned\" #@param [\"\", \"Animefull-final-pruned\", \"Animesfw-final-pruned\", \"Anything-V3.0-pruned-fp16\", \"Anything-V3.0-pruned-fp32\", \"Anything-V3.0-pruned\", \"Stable-Diffusion-v1-4\", \"Stable-Diffusion-v1-5-pruned-emaonly\", \"Waifu-Diffusion-v1-3-fp32\"]\n",
"\n",
"#@markdown ### Custom model\n",
"#@markdown The model URL should be a direct download link.\n",
"customName = \"\" #@param {'type': 'string'}\n",
"customUrl = \"\"#@param {'type': 'string'}\n",
"\n",
- "if customName == \"\" or customUrl == \"\":\n",
- " pass\n",
- "else:\n",
+ "# Check if user has specified a custom model\n",
+ "if customName != \"\" and customUrl != \"\":\n",
+ " # Add custom model to list of models to install\n",
" installModels.append((customName, customUrl))\n",
"\n",
+ "# Check if user has selected a model\n",
"if modelName != \"\":\n",
- " # Map model to URL\n",
+ " # Map selected model to URL\n",
" installModels.append((modelName, modelUrl[modelList.index(modelName)]))\n",
"\n",
"def install_aria():\n",
+ " # Install aria2 if it is not already installed\n",
" if not os.path.exists('/usr/bin/aria2c'):\n",
" !apt install -y -qq aria2\n",
"\n",
"def install(checkpoint_name, url):\n",
" if url.startswith(\"https://drive.google.com\"):\n",
+ " # Use gdown to download file from Google Drive\n",
" !gdown --fuzzy -O \"/content/kohya-trainer/checkpoint/{checkpoint_name}.ckpt\" \"{url}\"\n",
" elif url.startswith(\"magnet:?\"):\n",
" install_aria()\n",
+ " # Use aria2c to download file from magnet link\n",
" !aria2c --summary-interval=10 -c -x 10 -k 1M -s 10 -o /content/kohya-trainer/checkpoint/{checkpoint_name}.ckpt \"{url}\"\n",
" else:\n",
" user_token = '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'\n",
" user_header = f\"\\\"Authorization: Bearer {user_token}\\\"\"\n",
+ " # Use wget to download file from URL\n",
" !wget -c --header={user_header} \"{url}\" -O /content/kohya-trainer/checkpoint/{checkpoint_name}.ckpt\n",
"\n",
"def install_checkpoint():\n",
+ " # Iterate through list of models to install\n",
" for model in installModels:\n",
+ " # Call install function for each model\n",
" install(model[0], model[1])\n",
- "install_checkpoint()\n",
- "\n"
- ],
- "metadata": {
- "id": "SoucgZQ6jgPQ",
- "cellView": "form"
- },
- "execution_count": null,
- "outputs": []
+ "\n",
+ "# Call install_checkpoint function to download all models in the list\n",
+ "install_checkpoint()\n"
+ ]
},
{
"cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "cellView": "form",
+ "id": "IQwpRDVIbDB9"
+ },
+ "outputs": [],
"source": [
"#@title Emergency downgrade\n",
"#@markdown Tick this if you are facing issues on the cell below, such as high ram usage or cells not running\n",
"\n",
"diffuser_0_7_2 = True #@param {'type':'boolean'}\n",
"\n",
- "if diffuser_0_7_2 == True :\n",
+ "# Check if user wants to downgrade diffusers\n",
+ "if diffuser_0_7_2:\n",
+ " # Install diffusers 0.7.2\n",
" !pip install diffusers[torch]==0.7.2\n",
"else:\n",
+ " # Install latest version of diffusers\n",
" !pip install diffusers[torch]==0.9.0"
- ],
- "metadata": {
- "cellView": "form",
- "id": "IQwpRDVIbDB9"
- },
- "execution_count": null,
- "outputs": []
+ ]
},
{
"cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "cellView": "form",
+ "id": "hhgatqF3leHJ"
+ },
+ "outputs": [],
"source": [
"#@title Aspect Ratio Bucketing\n",
- "%cd /content/kohya-trainer\n",
"\n",
+ "# Change working directory\n",
+ "%cd /content/kohya-trainer/diffuser_fine_tuning\n",
+ "\n",
+ "#@markdown ### Command-line Arguments\n",
+ "#@markdown The following command-line arguments are available:\n",
+ "#@markdown * `train_data_dir`: directory for train images.\n",
+ "#@markdown * `in_json`: metadata file to input.\n",
+ "#@markdown * `out_json`: metadata file to output.\n",
+ "#@markdown * `model_name_or_path`: model name or path to encode latents.\n",
+ "#@markdown * `--v2`: load Stable Diffusion v2.0 model.\n",
+ "#@markdown * `--batch_size`: batch size in inference.\n",
+ "#@markdown * `--max_resolution`: max resolution in fine tuning (width,height).\n",
+ "#@markdown * `--min_bucket_reso`: minimum resolution for buckets.\n",
+ "#@markdown * `--max_bucket_reso`: maximum resolution for buckets.\n",
+ "#@markdown * `--mixed_precision`: use mixed precision.\n",
+ "\n",
+ "#@markdown ### Define parameters\n",
+ "in_json = \"/content/kohya-trainer/meta_cap_dd.json\" #@param {'type' : 'string'} \n",
+ "out_json = \"/content/kohya-trainer/meta_lat.json\" #@param {'type' : 'string'} \n",
"model_dir = \"/content/kohya-trainer/checkpoint/Anything-V3.0-pruned.ckpt\" #@param {'type' : 'string'} \n",
"batch_size = 4 #@param {'type':'integer'}\n",
"max_resolution = \"512,512\" #@param [\"512,512\", \"768,768\"] {allow-input: false}\n",
"mixed_precision = \"no\" #@param [\"no\", \"fp16\", \"bf16\"] {allow-input: false}\n",
"\n",
- "!python prepare_buckets_latents.py train_data meta_cap_dd.json meta_lat.json {model_dir} \\\n",
+ "# Run script to prepare buckets and latents\n",
+ "!python prepare_buckets_latents.py \\\n",
+ " {train_data_dir} \\\n",
+ " {in_json} \\\n",
+ " {out_json} \\\n",
+ " {model_dir} \\\n",
" --batch_size {batch_size} \\\n",
" --max_resolution {max_resolution} \\\n",
" --mixed_precision {mixed_precision}\n",
"\n",
+ "\n",
+ "\n",
" "
- ],
- "metadata": {
- "id": "hhgatqF3leHJ",
- "cellView": "form"
- },
- "execution_count": null,
- "outputs": []
+ ]
},
{
"cell_type": "markdown",
+ "metadata": {
+ "id": "yHNbl3O_NSS0"
+ },
"source": [
"# Start Training\n",
"\n"
- ],
- "metadata": {
- "id": "yHNbl3O_NSS0"
- }
+ ]
},
{
"cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "cellView": "form",
+ "id": "X_Rd3Eh07xlA"
+ },
+ "outputs": [],
"source": [
"#@title Training begin\n",
"num_cpu_threads_per_process = 8 #@param {'type':'integer'}\n",
- "model_path =\"/content/kohya-trainer/checkpoint/Anything-V3.0-pruned.ckpt\" #@param {'type':'string'}\n",
+ "pre_trained_model_path =\"/content/kohya-trainer/checkpoint/Animefull-final-pruned.ckpt\" #@param {'type':'string'}\n",
+ "meta_lat_json_dir = \"/content/kohya-trainer/meta_lat.json\" #@param {'type':'string'}\n",
+ "train_data_dir = \"/content/kohya-trainer/train_data\" #@param {'type':'string'}\n",
"output_dir =\"/content/kohya-trainer/fine_tuned\" #@param {'type':'string'}\n",
+ "# resume_path = \"/content/kohya-trainer/last-state\" #@param {'type':'string'}\n",
"train_batch_size = 1 #@param {type: \"slider\", min: 1, max: 10}\n",
"learning_rate =\"2e-6\" #@param {'type':'string'}\n",
"max_token_length = \"225\" #@param [\"150\", \"225\"] {allow-input: false}\n",
"clip_skip = 2 #@param {type: \"slider\", min: 1, max: 10}\n",
"mixed_precision = \"fp16\" #@param [\"fp16\", \"bf16\"] {allow-input: false}\n",
- "max_train_steps = 10000 #@param {'type':'integer'}\n",
- "# save_precision = \"fp16\" #@param [\"float\", \"fp16\", \"bf16\"] {allow-input: false}\n",
+ "max_train_steps = 5000 #@param {'type':'integer'}\n",
+ "save_precision = \"fp16\" #@param [\"float\", \"fp16\", \"bf16\"] {allow-input: false}\n",
"save_every_n_epochs = 50 #@param {'type':'integer'}\n",
"gradient_accumulation_steps = 1 #@param {type: \"slider\", min: 1, max: 10}\n",
- "dataset_repeats = 1 #@param {'type':'integer'}\n",
- " \n",
- "%cd /content/kohya-trainer\n",
+ "\n",
+ "%cd /content/kohya-trainer/diffuser_fine_tuning\n",
"!accelerate launch --num_cpu_threads_per_process {num_cpu_threads_per_process} fine_tune.py \\\n",
- " --pretrained_model_name_or_path={model_path} \\\n",
- " --in_json meta_lat.json \\\n",
- " --train_data_dir=train_data \\\n",
+ " --pretrained_model_name_or_path={pre_trained_model_path} \\\n",
+ " --in_json {meta_lat_json_dir} \\\n",
+ " --train_data_dir={train_data_dir} \\\n",
" --output_dir={output_dir} \\\n",
" --shuffle_caption \\\n",
" --train_batch_size={train_batch_size} \\\n",
" --learning_rate={learning_rate} \\\n",
+ " --logging_dir=logs \\\n",
" --max_token_length={max_token_length} \\\n",
" --clip_skip={clip_skip} \\\n",
" --mixed_precision={mixed_precision} \\\n",
- " --max_train_steps={max_train_steps} \\\n",
+ " --max_train_steps={max_train_steps} \\\n",
" --use_8bit_adam \\\n",
" --xformers \\\n",
" --gradient_checkpointing \\\n",
- " --save_every_n_epochs={save_every_n_epochs} \\\n",
" --save_state \\\n",
" --gradient_accumulation_steps {gradient_accumulation_steps} \\\n",
- " --dataset_repeats {dataset_repeats} \n",
- " # --save_precision={save_precision} \n",
- " # --resume /content/kohya-trainer/checkpoint/last-state\n"
- ],
- "metadata": {
- "id": "X_Rd3Eh07xlA",
- "cellView": "form"
- },
- "execution_count": null,
- "outputs": []
+ " --save_precision={save_precision}\n",
+ " # --resume {resume_path} \\\n"
+ ]
},
{
"cell_type": "markdown",
- "source": [
- "#Miscellaneous"
- ],
"metadata": {
"id": "vqfgyL-thgdw"
- }
+ },
+ "source": [
+ "# Miscellaneous"
+ ]
},
{
"cell_type": "code",
- "source": [
- "#@title Convert diffuser model to ckpt\n",
- "\n",
- "#@markdown If you're using diffuser weight, this cell will convert output weight to checkpoint file so it can be used in Web UI like Auto1111's\n",
- "WEIGHTS_DIR = \"/content/drive/MyDrive/fine_tuned/last\" #@param {'type':'string'}\n",
- "#@markdown Run conversion.\n",
- "ckpt_path = WEIGHTS_DIR + \"/model.ckpt\"\n",
- "\n",
- "half_arg = \"\"\n",
- "#@markdown Whether to convert to fp16, takes half the space (2GB).\n",
- "fp16 = False #@param {type: \"boolean\"}\n",
- "if fp16:\n",
- " half_arg = \"--half\"\n",
- "!python convert_diffusers_to_original_stable_diffusion.py --model_path $WEIGHTS_DIR --checkpoint_path $ckpt_path $half_arg\n",
- "\n",
- "print(f\"[*] Converted ckpt saved at {ckpt_path}\")"
- ],
+ "execution_count": null,
"metadata": {
"cellView": "form",
"id": "nOhJCs3BeR_Q"
},
- "execution_count": null,
- "outputs": []
+ "outputs": [],
+ "source": [
+ "#@title Convert diffuser model to ckpt (Optional)\n",
+ "\n",
+ "#@markdown If you're using diffuser weight, this cell will convert output weight to checkpoint file so it can be used in Web UI like Auto1111's\n",
+ "\n",
+ "# Use a more descriptive variable name\n",
+ "diffuser_weights_dir = \"/content/drive/MyDrive/fine_tuned/last\" #@param {'type':'string'}\n",
+ "\n",
+ "# Use a more descriptive variable name\n",
+ "use_fp16 = False #@param {type: \"boolean\"}\n",
+ "\n",
+ "# Add a comment to explain what the code is doing\n",
+ "# Convert the diffuser weights to a checkpoint file\n",
+ "ckpt_path = diffuser_weights_dir + \"/model.ckpt\"\n",
+ "\n",
+ "# Use a more descriptive variable name\n",
+ "half_precision_arg = \"\"\n",
+ "if use_fp16:\n",
+ " # Use a more descriptive variable name\n",
+ " half_precision_arg = \"--half\"\n",
+ "\n",
+ "# Add a comment to explain what the code is doing\n",
+ "# Run the conversion script\n",
+ "!python convert_diffusers_to_original_stable_diffusion.py --model_path $diffuser_weights_dir --checkpoint_path $ckpt_path $half_precision_arg\n",
+ "\n",
+ "# Use string formatting and a more descriptive variable name\n",
+ "print(f\"[*] Converted checkpoint saved at {ckpt_path}\")"
+ ]
},
{
"cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "cellView": "form",
+ "id": "LUOG7BzQVLKp"
+ },
+ "outputs": [],
"source": [
- "#@markdown ```\n",
- "#@markdown @lopho\n",
+ "#@title Model Pruner (Optional)\n",
+ "\n",
+ "#@markdown ```python\n",
"#@markdown usage: prune.py [-h] [-p] [-e] [-c] [-a] input output\n",
"#@markdown \n",
"#@markdown Prune a stable diffusion checkpoint\n",
@@ -637,267 +792,140 @@
"#@markdown -c, --no-clip strip CLIP weights\n",
"#@markdown -a, --no-vae strip VAE weights\n",
"#@markdown ```\n",
- "#@title Model Pruner\n",
+ "\n",
"#@markdown Do you want to Prune a model?\n",
"%cd /content/ \n",
"\n",
- "prune = True #@param {'type':'boolean'}\n",
+ "# Use a more descriptive variable name\n",
+ "should_prune = False #@param {'type':'boolean'}\n",
+ "\n",
+ "# Use a more descriptive variable name\n",
+ "source_model_path = \"/content/kohya-trainer/fine_tuned/last.ckpt\" #@param {'type' : 'string'}\n",
"\n",
- "model_src = \"/content/kohya-trainer/fine_tuned/last.ckpt\" #@param {'type' : 'string'}\n",
- "model_dst = \"/content/kohya-trainer/fine_tuned/last-pruned.ckpt\" #@param {'type' : 'string'}\n",
+ "# Use a more descriptive variable name\n",
+ "pruned_model_path = \"/content/kohya-trainer/fine_tuned/last-pruned.ckpt\" #@param {'type' : 'string'}\n",
"\n",
- "if prune == True:\n",
+ "if should_prune:\n",
" import os\n",
" if os.path.isfile('/content/prune.py'):\n",
" pass\n",
" else:\n",
+ " # Add a comment to explain what the code is doing\n",
+ " # Download the pruning script if it doesn't already exist\n",
" !wget https://raw.githubusercontent.com/lopho/stable-diffusion-prune/main/prune.py\n",
"\n",
"\n",
- "!python3 prune.py -p {model_src} {model_dst}\n"
- ],
- "metadata": {
- "id": "LUOG7BzQVLKp",
- "cellView": "form"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "code",
- "source": [
- "#@title Mount to Google Drive\n",
- "mount_drive= True #@param {'type':'boolean'}\n",
- "\n",
- "if mount_drive== True:\n",
- " from google.colab import drive\n",
- " drive.mount('/content/drive')"
- ],
- "metadata": {
- "id": "OuRqOSp2eU6t",
- "cellView": "form"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "source": [
- "#Huggingface_hub Integration"
- ],
- "metadata": {
- "id": "QtVP2le8PL2T"
- }
+ "# Add a comment to explain what the code is doing\n",
+ "# Run the pruning script\n",
+ "!python3 prune.py {source_model_path} {pruned_model_path}"
+ ]
},
{
"cell_type": "markdown",
- "source": [
- "##Instruction:\n",
- "0. Of course you need a Huggingface Account first\n",
- "1. Create huggingface token, go to `Profile > Access Tokens > New Token > Create a new access token` with the `Write` role.\n",
- "2. All cells below are checked `opt-out` by default so you need to uncheck it if you want to running the cells."
- ],
- "metadata": {
- "id": "tbKgmh_AO5NG"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "#@title Login to Huggingface hub\n",
- "#@markdown Opt-out this cell when run all\n",
- "%cd /content/kohya-trainer\n",
- "from IPython.core.display import HTML\n",
- "\n",
- "opt_out= False #@param {'type':'boolean'}\n",
- "\n",
- "#@markdown Prepare your Huggingface token\n",
- "\n",
- "saved_token= \"save-your-write-token-here\" #@param {'type': 'string'}\n",
- "\n",
- "if opt_out == False:\n",
- " !pip install huggingface_hub\n",
- " \n",
- " from huggingface_hub import login\n",
- " login()\n",
- "\n"
- ],
"metadata": {
- "id": "Da7awoqAPJ3a",
- "cellView": "form"
+ "id": "jypUkLWc48R_"
},
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
"source": [
- "##Commit trained model to Huggingface"
- ],
- "metadata": {
- "id": "jypUkLWc48R_"
- }
+ "## Commit trained model to Huggingface"
+ ]
},
{
"cell_type": "markdown",
- "source": [
- "###Instruction:\n",
- "0. Create huggingface repository for model\n",
- "1. Clone your model to this colab session\n",
- "2. Move these necessary file to your repository to save your trained model to huggingface\n",
- "\n",
- ">in `content/kohya-trainer/fine-tuned`\n",
- "- File `epoch-nnnnn.ckpt` and/or\n",
- "- File `last.ckpt`, \n",
- "\n",
- "4. Commit your model to huggingface"
- ],
"metadata": {
"id": "TvZgRSmKVSRw"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "#@title Clone Model\n",
- "\n",
- "#@markdown Opt-out this cell when run all\n",
- "opt_out= True #@param {'type':'boolean'}\n",
- "\n",
- "if opt_out == False:\n",
- " %cd /content\n",
- " Repository_url = \"https://huggingface.co/Linaqruf/experimental\" #@param {'type': 'string'}\n",
- " !git clone {Repository_url}\n",
- "else:\n",
- " pass\n"
- ],
- "metadata": {
- "id": "182Law9oUiYN",
- "cellView": "form"
},
- "execution_count": null,
- "outputs": []
+ "source": [
+ "### To Commit models:\n",
+ "1. Create a huggingface repository for your model.\n",
+ "2. Clone your model to this Colab session.\n",
+ "3. Move the necessary files to your repository to save your trained model to huggingface. These files are located in `fine-tuned` folder:\n",
+ " - `epoch-nnnnn.ckpt` and/or\n",
+ " - `last.ckpt`\n",
+ "4. Commit your model to huggingface.\n",
+ "\n",
+ "### To Commit datasets:\n",
+ "1. Create a huggingface repository for your datasets.\n",
+ "2. Clone your datasets to this Colab session.\n",
+ "3. Move the necessary files to your repository so that you can resume training without rebuilding your dataset with this notebook:\n",
+ " - The `train_data` folder.\n",
+ " - The `meta_lat.json` file.\n",
+ " - The `last-state` folder.\n",
+ "4. Commit your datasets to huggingface.\n",
+ "\n"
+ ]
},
{
"cell_type": "code",
- "source": [
- "#@title Commit to Huggingface\n",
- "#@markdown Opt-out this cell when run all\n",
- "opt_out= True #@param {'type':'boolean'}\n",
- "\n",
- "if opt_out == False:\n",
- " %cd /content\n",
- " #@markdown Go to your model path\n",
- " model_path= \"alphanime-diffusion\" #@param {'type': 'string'}\n",
- "\n",
- " #@markdown Your path look like /content/**model_path**\n",
- " #@markdown ___\n",
- " #@markdown #Git Commit\n",
- "\n",
- " #@markdown Set **git commit identity**\n",
- "\n",
- " email= \"your-email\" #@param {'type': 'string'}\n",
- " name= \"your-name\" #@param {'type': 'string'}\n",
- " #@markdown Set **commit message**\n",
- " commit_m= \"this is commit message\" #@param {'type': 'string'}\n",
- "\n",
- " %cd \"/content/{model_path}\"\n",
- " !git lfs install\n",
- " !huggingface-cli lfs-enable-largefiles .\n",
- " !git add .\n",
- " !git lfs help smudge\n",
- " !git config --global user.email \"{email}\"\n",
- " !git config --global user.name \"{name}\"\n",
- " !git commit -m \"{commit_m}\"\n",
- " !git push\n",
- "\n",
- "else:\n",
- " pass"
- ],
- "metadata": {
- "id": "87wG7QIZbtZE",
- "cellView": "form"
- },
"execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "source": [
- "##Commit dataset to huggingface"
- ],
"metadata": {
- "id": "olP2yaK3OKcr"
- }
- },
- {
- "cell_type": "markdown",
+ "cellView": "form",
+ "id": "182Law9oUiYN"
+ },
+ "outputs": [],
"source": [
- "###Instruction:\n",
- "0. Create huggingface repository for datasets\n",
- "1. Clone your datasets to this colab session\n",
- "2. Move these necessary file to your repository so that you can do resume training next time without rebuild your dataset with this notebook\n",
+ "#@title Clone Model or Datasets\n",
"\n",
- ">in `content/kohya-trainer`\n",
- "- Folder `train_data`\n",
- "- File `meta_cap_dd.json`\n",
- "- File `meta_lat.json`\n",
+ "#@markdown Opt-out this cell when run all\n",
+ "opt_out = True #@param {'type':'boolean'}\n",
"\n",
- ">in `content/kohya-trainer/fine-tuned`\n",
- "- Folder `last-state`\n",
+ "#@markdown Type of item to clone (model or dataset)\n",
+ "type_of_item = \"model\" #@param [\"model\", \"dataset\"]\n",
"\n",
- "4. Commit your datasets to huggingface"
- ],
- "metadata": {
- "id": "jiSb0z2CVtc_"
- }
- },
- {
- "cell_type": "code",
- "source": [
- "#@title Clone Dataset\n",
- "#@markdown Opt-out this cell when run all\n",
- "opt_out= True #@param {'type':'boolean'}\n",
+ "#@markdown Install or uninstall git lfs\n",
+ "install_git_lfs = False #@param {'type':'boolean'}\n",
"\n",
"if opt_out == False:\n",
" %cd /content\n",
- " Repository_url = \"https://huggingface.co/datasets/Linaqruf/alphanime-diffusion-tag\" #@param {'type': 'string'}\n",
+ " username = \"your-huggingface-username\" #@param {'type': 'string'}\n",
+ " model_repo = \"your-huggingface-model-repo\" #@param {'type': 'string'}\n",
+ " datasets_repo = \"your-huggingface-datasets-repo\" #@param {'type': 'string'}\n",
+ " \n",
+ " if type_of_item == \"model\":\n",
+ " Repository_url = f\"https://huggingface.co/{username}/{model_repo}\"\n",
+ " elif type_of_item == \"dataset\":\n",
+ " Repository_url = f\"https://huggingface.co/datasets/{username}/{datasets_repo}\"\n",
+ "\n",
+ " if install_git_lfs:\n",
+ " !git lfs install\n",
+ " else:\n",
+ " !git lfs uninstall\n",
+ "\n",
" !git clone {Repository_url}\n",
"else:\n",
" pass\n"
- ],
- "metadata": {
- "id": "QhL6UgqDOURK",
- "cellView": "form"
- },
- "execution_count": null,
- "outputs": []
+ ]
},
{
"cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "cellView": "form",
+ "id": "87wG7QIZbtZE"
+ },
+ "outputs": [],
"source": [
- "#@title Commit to Huggingface\n",
+ "#@title Commit Model or Datasets to Huggingface\n",
+ "\n",
"#@markdown Opt-out this cell when run all\n",
+ "opt_out = True #@param {'type':'boolean'}\n",
"\n",
- "opt_out= True #@param {'type':'boolean'}\n",
+ "#@markdown Type of item to commit (model or dataset)\n",
+ "type_of_item = \"model\" #@param [\"model\", \"dataset\"]\n",
"\n",
"if opt_out == False:\n",
" %cd /content\n",
- " #@markdown Go to your model path\n",
- " dataset_path= \"alphanime-diffusion-tag\" #@param {'type': 'string'}\n",
+ " #@markdown Go to your model or dataset path\n",
+ " item_path = \"your-cloned-model-or-datasets-repo\" #@param {'type': 'string'}\n",
"\n",
- " #@markdown Your path look like /content/**dataset_path**\n",
- " #@markdown ___\n",
" #@markdown #Git Commit\n",
"\n",
" #@markdown Set **git commit identity**\n",
- "\n",
- " email= \"your-email\" #@param {'type': 'string'}\n",
- " name= \"your-name\" #@param {'type': 'string'}\n",
+ " email = \"your-email\" #@param {'type': 'string'}\n",
+ " name = \"your-username\" #@param {'type': 'string'}\n",
" #@markdown Set **commit message**\n",
- " commit_m= \"this is commit message\" #@param {'type': 'string'}\n",
+ " commit_m = \"feat: upload 6 epochs model\" #@param {'type': 'string'}\n",
"\n",
- " %cd \"/content/{dataset_path}\"\n",
+ " %cd {item_path}\n",
" !git lfs install\n",
" !huggingface-cli lfs-enable-largefiles .\n",
" !git add .\n",
@@ -909,13 +937,24 @@
"\n",
"else:\n",
" pass"
- ],
- "metadata": {
- "id": "abHLg4I0Os5T",
- "cellView": "form"
- },
- "execution_count": null,
- "outputs": []
+ ]
}
- ]
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "provenance": [],
+ "include_colab_link": true
+ },
+ "gpuClass": "standard",
+ "kernelspec": {
+ "display_name": "Python 3",
+ "name": "python3"
+ },
+ "language_info": {
+ "name": "python"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
}
\ No newline at end of file
diff --git a/script/detect_face_rotate.py b/script/detect_face_rotate_v2.py
similarity index 82%
rename from script/detect_face_rotate.py
rename to script/detect_face_rotate_v2.py
index 04c818c1..aba83b4d 100644
--- a/script/detect_face_rotate.py
+++ b/script/detect_face_rotate_v2.py
@@ -3,6 +3,8 @@
 # detect a face in a landscape image, rotate it upright, and crop a square centered on it
+# v2: extract the largest face if multiple faces are found
+
import argparse
import math
import cv2
@@ -15,24 +17,38 @@
KP_REYE = 11
KP_LEYE = 19
+SCORE_THRES = 0.90
+
def detect_face(detector, image):
preds = detector(image) # bgr
-
+ # print(len(preds))
if len(preds) == 0:
- return None, None
-
- left = preds[0]['bbox'][0]
- top = preds[0]['bbox'][1]
- right = preds[0]['bbox'][2]
- bottom = preds[0]['bbox'][3]
+ return None, None, None, None, None
+
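+  # v2 selection heuristic: prefer the highest-scoring face until one clears SCORE_THRES,
+  # then prefer the largest face among those at or above the threshold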
+ index = -1
+ max_score = 0
+ max_size = 0
+ for i in range(len(preds)):
+ bb = preds[i]['bbox']
+ score = bb[-1]
+ size = max(bb[2]-bb[0], bb[3]-bb[1])
+ if (score > max_score and max_score < SCORE_THRES) or (score >= SCORE_THRES and size > max_size):
+ index = i
+ max_score = score
+ max_size = size
+
+ left = preds[index]['bbox'][0]
+ top = preds[index]['bbox'][1]
+ right = preds[index]['bbox'][2]
+ bottom = preds[index]['bbox'][3]
cx = int((left + right) / 2)
cy = int((top + bottom) / 2)
fw = int(right - left)
fh = int(bottom - top)
- lex, ley = preds[0]['keypoints'][KP_LEYE, 0:2]
- rex, rey = preds[0]['keypoints'][KP_REYE, 0:2]
+ lex, ley = preds[index]['keypoints'][KP_LEYE, 0:2]
+ rex, rey = preds[index]['keypoints'][KP_REYE, 0:2]
angle = math.atan2(ley - rey, lex - rex)
angle = angle / math.pi * 180
return cx, cy, fw, fh, angle
@@ -81,7 +97,7 @@ def process(args):
output_extension = ".png"
os.makedirs(args.dst_dir, exist_ok=True)
- paths = glob.glob(args.src_dir + "/*.png") + glob.glob(args.src_dir + "/*.jpg")
+ paths = glob.glob(os.path.join(args.src_dir, "*.png")) + glob.glob(os.path.join(args.src_dir, "*.jpg"))
for path in tqdm(paths):
basename = os.path.splitext(os.path.basename(path))[0]
@@ -97,8 +113,9 @@ def process(args):
cx, cy, fw, fh, angle = detect_face(detector, image)
if cx is None:
- print(f"face not found: {path}")
- cx = cy = fw = fh = 0
+ print(f"face not found, skip: {path}")
+ # cx = cy = fw = fh = 0
+ continue # skip this image
 # rotate if the option is specified
if args.rotate and cx != 0:
@@ -114,10 +131,12 @@ def process(args):
 # resize based on the face size
scale = args.resize_face_size / max(fw, fh)
if scale < crop_width / w:
- print(f"image width too small in face size based resizing / 顔を基準にリサイズすると画像の幅がcrop sizeより小さい(顔が相対的に大きすぎる)ので顔サイズが変わります: {path}")
+ print(
+ f"image width too small in face size based resizing / 顔を基準にリサイズすると画像の幅がcrop sizeより小さい(顔が相対的に大きすぎる)ので顔サイズが変わります: {path}")
scale = crop_width / w
if scale < crop_height / h:
- print(f"image height too small in face size based resizing / 顔を基準にリサイズすると画像の高さがcrop sizeより小さい(顔が相対的に大きすぎる)ので顔サイズが変わります: {path}")
+ print(
+ f"image height too small in face size based resizing / 顔を基準にリサイズすると画像の高さがcrop sizeより小さい(顔が相対的に大きすぎる)ので顔サイズが変わります: {path}")
scale = crop_height / h
else:
if w < crop_width:
diff --git a/script/requirements.txt b/script/requirements.txt
index a71d3f18..dec1ac7c 100644
--- a/script/requirements.txt
+++ b/script/requirements.txt
@@ -6,6 +6,6 @@ opencv-python
einops
diffusers[torch]==0.9.0
pytorch_lightning
-bitsandbytes==0.35.0
+bitsandbytes
tensorboard
safetensors==0.2.5
diff --git a/train_db_fixed/model_util.py b/train_db_fixed/model_util.py
new file mode 100644
index 00000000..74650bf4
--- /dev/null
+++ b/train_db_fixed/model_util.py
@@ -0,0 +1,1166 @@
+# v1: split from train_db_fixed.py.
+# v2: support safetensors
+
+import math
+import os
+import torch
+from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig
+from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel
+from safetensors.torch import load_file, save_file
+
+# model parameters for the Diffusers version of Stable Diffusion
+NUM_TRAIN_TIMESTEPS = 1000
+BETA_START = 0.00085
+BETA_END = 0.0120
+
+UNET_PARAMS_MODEL_CHANNELS = 320
+UNET_PARAMS_CHANNEL_MULT = [1, 2, 4, 4]
+UNET_PARAMS_ATTENTION_RESOLUTIONS = [4, 2, 1]
+UNET_PARAMS_IMAGE_SIZE = 32 # unused
+UNET_PARAMS_IN_CHANNELS = 4
+UNET_PARAMS_OUT_CHANNELS = 4
+UNET_PARAMS_NUM_RES_BLOCKS = 2
+UNET_PARAMS_CONTEXT_DIM = 768
+UNET_PARAMS_NUM_HEADS = 8
+
+VAE_PARAMS_Z_CHANNELS = 4
+VAE_PARAMS_RESOLUTION = 256
+VAE_PARAMS_IN_CHANNELS = 3
+VAE_PARAMS_OUT_CH = 3
+VAE_PARAMS_CH = 128
+VAE_PARAMS_CH_MULT = [1, 2, 4, 4]
+VAE_PARAMS_NUM_RES_BLOCKS = 2
+
+# V2
+V2_UNET_PARAMS_ATTENTION_HEAD_DIM = [5, 10, 20, 20]
+V2_UNET_PARAMS_CONTEXT_DIM = 1024
+
+
+# region StableDiffusion->Diffusers conversion code
+# copied and modified from convert_original_stable_diffusion_to_diffusers (ASL 2.0)
+
+
+def shave_segments(path, n_shave_prefix_segments=1):
+ """
+ Removes segments. Positive values shave the first segments, negative shave the last segments.
+ """
+ if n_shave_prefix_segments >= 0:
+ return ".".join(path.split(".")[n_shave_prefix_segments:])
+ else:
+ return ".".join(path.split(".")[:n_shave_prefix_segments])
+
+
+def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
+ """
+ Updates paths inside resnets to the new naming scheme (local renaming)
+ """
+ mapping = []
+ for old_item in old_list:
+ new_item = old_item.replace("in_layers.0", "norm1")
+ new_item = new_item.replace("in_layers.2", "conv1")
+
+ new_item = new_item.replace("out_layers.0", "norm2")
+ new_item = new_item.replace("out_layers.3", "conv2")
+
+ new_item = new_item.replace("emb_layers.1", "time_emb_proj")
+ new_item = new_item.replace("skip_connection", "conv_shortcut")
+
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
+
+ mapping.append({"old": old_item, "new": new_item})
+
+ return mapping
+
+
+def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
+ """
+ Updates paths inside resnets to the new naming scheme (local renaming)
+ """
+ mapping = []
+ for old_item in old_list:
+ new_item = old_item
+
+ new_item = new_item.replace("nin_shortcut", "conv_shortcut")
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
+
+ mapping.append({"old": old_item, "new": new_item})
+
+ return mapping
+
+
+def renew_attention_paths(old_list, n_shave_prefix_segments=0):
+ """
+ Updates paths inside attentions to the new naming scheme (local renaming)
+ """
+ mapping = []
+ for old_item in old_list:
+ new_item = old_item
+
+ # new_item = new_item.replace('norm.weight', 'group_norm.weight')
+ # new_item = new_item.replace('norm.bias', 'group_norm.bias')
+
+ # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
+ # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
+
+ # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
+
+ mapping.append({"old": old_item, "new": new_item})
+
+ return mapping
+
+
+def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
+ """
+ Updates paths inside attentions to the new naming scheme (local renaming)
+ """
+ mapping = []
+ for old_item in old_list:
+ new_item = old_item
+
+ new_item = new_item.replace("norm.weight", "group_norm.weight")
+ new_item = new_item.replace("norm.bias", "group_norm.bias")
+
+ new_item = new_item.replace("q.weight", "query.weight")
+ new_item = new_item.replace("q.bias", "query.bias")
+
+ new_item = new_item.replace("k.weight", "key.weight")
+ new_item = new_item.replace("k.bias", "key.bias")
+
+ new_item = new_item.replace("v.weight", "value.weight")
+ new_item = new_item.replace("v.bias", "value.bias")
+
+ new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
+ new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
+
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
+
+ mapping.append({"old": old_item, "new": new_item})
+
+ return mapping
+
+
+def assign_to_checkpoint(
+ paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
+):
+ """
+ This does the final conversion step: take locally converted weights and apply a global renaming
+ to them. It splits attention layers, and takes into account additional replacements
+ that may arise.
+
+ Assigns the weights to the new checkpoint.
+ """
+ assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
+
+ # Splits the attention layers into three variables.
+ if attention_paths_to_split is not None:
+ for path, path_map in attention_paths_to_split.items():
+ old_tensor = old_checkpoint[path]
+ channels = old_tensor.shape[0] // 3
+
+ target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
+
+ num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
+
+ old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
+ query, key, value = old_tensor.split(channels // num_heads, dim=1)
+
+ checkpoint[path_map["query"]] = query.reshape(target_shape)
+ checkpoint[path_map["key"]] = key.reshape(target_shape)
+ checkpoint[path_map["value"]] = value.reshape(target_shape)
+
+ for path in paths:
+ new_path = path["new"]
+
+ # These have already been assigned
+ if attention_paths_to_split is not None and new_path in attention_paths_to_split:
+ continue
+
+ # Global renaming happens here
+ new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
+ new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
+ new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")
+
+ if additional_replacements is not None:
+ for replacement in additional_replacements:
+ new_path = new_path.replace(replacement["old"], replacement["new"])
+
+ # proj_attn.weight has to be converted from conv 1D to linear
+ if "proj_attn.weight" in new_path:
+ checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
+ else:
+ checkpoint[new_path] = old_checkpoint[path["old"]]
+
+
+def conv_attn_to_linear(checkpoint):
+ keys = list(checkpoint.keys())
+ attn_keys = ["query.weight", "key.weight", "value.weight"]
+ for key in keys:
+ if ".".join(key.split(".")[-2:]) in attn_keys:
+ if checkpoint[key].ndim > 2:
+ checkpoint[key] = checkpoint[key][:, :, 0, 0]
+ elif "proj_attn.weight" in key:
+ if checkpoint[key].ndim > 2:
+ checkpoint[key] = checkpoint[key][:, :, 0]
+
+
+def linear_transformer_to_conv(checkpoint):
+ keys = list(checkpoint.keys())
+ tf_keys = ["proj_in.weight", "proj_out.weight"]
+ for key in keys:
+ if ".".join(key.split(".")[-2:]) in tf_keys:
+ if checkpoint[key].ndim == 2:
+ checkpoint[key] = checkpoint[key].unsqueeze(2).unsqueeze(2)
+
+
+def convert_ldm_unet_checkpoint(v2, checkpoint, config):
+ """
+ Takes a state dict and a config, and returns a converted checkpoint.
+ """
+
+ # extract state_dict for UNet
+ unet_state_dict = {}
+ unet_key = "model.diffusion_model."
+ keys = list(checkpoint.keys())
+ for key in keys:
+ if key.startswith(unet_key):
+ unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
+
+ new_checkpoint = {}
+
+ new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
+ new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
+ new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
+ new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
+
+ new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
+ new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
+
+ new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
+ new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
+ new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
+ new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
+
+ # Retrieves the keys for the input blocks only
+ num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
+ input_blocks = {
+ layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}." in key]
+ for layer_id in range(num_input_blocks)
+ }
+
+ # Retrieves the keys for the middle blocks only
+ num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
+ middle_blocks = {
+ layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}." in key]
+ for layer_id in range(num_middle_blocks)
+ }
+
+ # Retrieves the keys for the output blocks only
+ num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
+ output_blocks = {
+ layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}." in key]
+ for layer_id in range(num_output_blocks)
+ }
+
+ for i in range(1, num_input_blocks):
+ block_id = (i - 1) // (config["layers_per_block"] + 1)
+ layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
+
+ resnets = [
+ key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
+ ]
+ attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
+
+ if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
+ f"input_blocks.{i}.0.op.weight"
+ )
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
+ f"input_blocks.{i}.0.op.bias"
+ )
+
+ paths = renew_resnet_paths(resnets)
+ meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
+ assign_to_checkpoint(
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+ )
+
+ if len(attentions):
+ paths = renew_attention_paths(attentions)
+ meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
+ assign_to_checkpoint(
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+ )
+
+ resnet_0 = middle_blocks[0]
+ attentions = middle_blocks[1]
+ resnet_1 = middle_blocks[2]
+
+ resnet_0_paths = renew_resnet_paths(resnet_0)
+ assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)
+
+ resnet_1_paths = renew_resnet_paths(resnet_1)
+ assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)
+
+ attentions_paths = renew_attention_paths(attentions)
+ meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
+ assign_to_checkpoint(
+ attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+ )
+
+ for i in range(num_output_blocks):
+ block_id = i // (config["layers_per_block"] + 1)
+ layer_in_block_id = i % (config["layers_per_block"] + 1)
+ output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
+ output_block_list = {}
+
+ for layer in output_block_layers:
+ layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
+ if layer_id in output_block_list:
+ output_block_list[layer_id].append(layer_name)
+ else:
+ output_block_list[layer_id] = [layer_name]
+
+ if len(output_block_list) > 1:
+ resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
+ attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
+
+ resnet_0_paths = renew_resnet_paths(resnets)
+ paths = renew_resnet_paths(resnets)
+
+ meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
+ assign_to_checkpoint(
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+ )
+
+ # original:
+ # if ["conv.weight", "conv.bias"] in output_block_list.values():
+ # index = list(output_block_list.values()).index(["conv.weight", "conv.bias"])
+
+ # avoid depending on the order of bias and weight (there is probably a better way)
+ for l in output_block_list.values():
+ l.sort()
+
+ if ["conv.bias", "conv.weight"] in output_block_list.values():
+ index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
+ f"output_blocks.{i}.{index}.conv.bias"
+ ]
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
+ f"output_blocks.{i}.{index}.conv.weight"
+ ]
+
+ # Clear attentions as they have been attributed above.
+ if len(attentions) == 2:
+ attentions = []
+
+ if len(attentions):
+ paths = renew_attention_paths(attentions)
+ meta_path = {
+ "old": f"output_blocks.{i}.1",
+ "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
+ }
+ assign_to_checkpoint(
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
+ )
+ else:
+ resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
+ for path in resnet_0_paths:
+ old_path = ".".join(["output_blocks", str(i), path["old"]])
+ new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
+
+ new_checkpoint[new_path] = unet_state_dict[old_path]
+
+ # in SD v2, the 1x1 conv2d layers were changed to linear layers, so convert linear -> conv
+ if v2:
+ linear_transformer_to_conv(new_checkpoint)
+
+ return new_checkpoint
+
+
+def convert_ldm_vae_checkpoint(checkpoint, config):
+ # extract state dict for VAE
+ vae_state_dict = {}
+ vae_key = "first_stage_model."
+ keys = list(checkpoint.keys())
+ for key in keys:
+ if key.startswith(vae_key):
+ vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
+ # if len(vae_state_dict) == 0:
+ # # the checkpoint passed in is a VAE state_dict, not a checkpoint loaded from a .ckpt
+ # vae_state_dict = checkpoint
+
+ new_checkpoint = {}
+
+ new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
+ new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
+ new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
+ new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
+ new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
+ new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]
+
+ new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
+ new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
+ new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
+ new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
+ new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
+ new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]
+
+ new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
+ new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
+ new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
+ new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
+
+ # Retrieves the keys for the encoder down blocks only
+ num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
+ down_blocks = {
+ layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
+ }
+
+ # Retrieves the keys for the decoder up blocks only
+ num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
+ up_blocks = {
+ layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
+ }
+
+ for i in range(num_down_blocks):
+ resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
+
+ if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
+ new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
+ f"encoder.down.{i}.downsample.conv.weight"
+ )
+ new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
+ f"encoder.down.{i}.downsample.conv.bias"
+ )
+
+ paths = renew_vae_resnet_paths(resnets)
+ meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
+
+ mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
+ num_mid_res_blocks = 2
+ for i in range(1, num_mid_res_blocks + 1):
+ resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
+
+ paths = renew_vae_resnet_paths(resnets)
+ meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
+
+ mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
+ paths = renew_vae_attention_paths(mid_attentions)
+ meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
+ conv_attn_to_linear(new_checkpoint)
+
+ for i in range(num_up_blocks):
+ block_id = num_up_blocks - 1 - i
+ resnets = [
+ key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
+ ]
+
+ if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
+ new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
+ f"decoder.up.{block_id}.upsample.conv.weight"
+ ]
+ new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
+ f"decoder.up.{block_id}.upsample.conv.bias"
+ ]
+
+ paths = renew_vae_resnet_paths(resnets)
+ meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
+
+ mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
+ num_mid_res_blocks = 2
+ for i in range(1, num_mid_res_blocks + 1):
+ resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
+
+ paths = renew_vae_resnet_paths(resnets)
+ meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
+
+ mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
+ paths = renew_vae_attention_paths(mid_attentions)
+ meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
+ conv_attn_to_linear(new_checkpoint)
+ return new_checkpoint
+
+
+def create_unet_diffusers_config(v2):
+ """
+ Creates a config for the diffusers based on the config of the LDM model.
+ """
+ # unet_params = original_config.model.params.unet_config.params
+
+ block_out_channels = [UNET_PARAMS_MODEL_CHANNELS * mult for mult in UNET_PARAMS_CHANNEL_MULT]
+
+ down_block_types = []
+ resolution = 1
+ for i in range(len(block_out_channels)):
+ block_type = "CrossAttnDownBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "DownBlock2D"
+ down_block_types.append(block_type)
+ if i != len(block_out_channels) - 1:
+ resolution *= 2
+
+ up_block_types = []
+ for i in range(len(block_out_channels)):
+ block_type = "CrossAttnUpBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "UpBlock2D"
+ up_block_types.append(block_type)
+ resolution //= 2
+
+ config = dict(
+ sample_size=UNET_PARAMS_IMAGE_SIZE,
+ in_channels=UNET_PARAMS_IN_CHANNELS,
+ out_channels=UNET_PARAMS_OUT_CHANNELS,
+ down_block_types=tuple(down_block_types),
+ up_block_types=tuple(up_block_types),
+ block_out_channels=tuple(block_out_channels),
+ layers_per_block=UNET_PARAMS_NUM_RES_BLOCKS,
+ cross_attention_dim=UNET_PARAMS_CONTEXT_DIM if not v2 else V2_UNET_PARAMS_CONTEXT_DIM,
+ attention_head_dim=UNET_PARAMS_NUM_HEADS if not v2 else V2_UNET_PARAMS_ATTENTION_HEAD_DIM,
+ )
+
+ return config
+
+
+def create_vae_diffusers_config():
+ """
+ Creates a config for the diffusers based on the config of the LDM model.
+ """
+ # vae_params = original_config.model.params.first_stage_config.params.ddconfig
+ # _ = original_config.model.params.first_stage_config.params.embed_dim
+ block_out_channels = [VAE_PARAMS_CH * mult for mult in VAE_PARAMS_CH_MULT]
+ down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
+ up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
+
+ config = dict(
+ sample_size=VAE_PARAMS_RESOLUTION,
+ in_channels=VAE_PARAMS_IN_CHANNELS,
+ out_channels=VAE_PARAMS_OUT_CH,
+ down_block_types=tuple(down_block_types),
+ up_block_types=tuple(up_block_types),
+ block_out_channels=tuple(block_out_channels),
+ latent_channels=VAE_PARAMS_Z_CHANNELS,
+ layers_per_block=VAE_PARAMS_NUM_RES_BLOCKS,
+ )
+ return config
+
+
+def convert_ldm_clip_checkpoint_v1(checkpoint):
+ keys = list(checkpoint.keys())
+ text_model_dict = {}
+ for key in keys:
+ if key.startswith("cond_stage_model.transformer"):
+ text_model_dict[key[len("cond_stage_model.transformer."):]] = checkpoint[key]
+ return text_model_dict
+
+
+def convert_ldm_clip_checkpoint_v2(checkpoint, max_length):
+ # the key layout here is annoyingly different!
+ def convert_key(key):
+ if not key.startswith("cond_stage_model"):
+ return None
+
+ # common conversion
+ key = key.replace("cond_stage_model.model.transformer.", "text_model.encoder.")
+ key = key.replace("cond_stage_model.model.", "text_model.")
+
+ if "resblocks" in key:
+ # resblocks conversion
+ key = key.replace(".resblocks.", ".layers.")
+ if ".ln_" in key:
+ key = key.replace(".ln_", ".layer_norm")
+ elif ".mlp." in key:
+ key = key.replace(".c_fc.", ".fc1.")
+ key = key.replace(".c_proj.", ".fc2.")
+ elif '.attn.out_proj' in key:
+ key = key.replace(".attn.out_proj.", ".self_attn.out_proj.")
+ elif '.attn.in_proj' in key:
+ key = None # special case, handled later
+ else:
+ raise ValueError(f"unexpected key in SD: {key}")
+ elif '.positional_embedding' in key:
+ key = key.replace(".positional_embedding", ".embeddings.position_embedding.weight")
+ elif '.text_projection' in key:
+ key = None # appears to be unused???
+ elif '.logit_scale' in key:
+ key = None # appears to be unused???
+ elif '.token_embedding' in key:
+ key = key.replace(".token_embedding.weight", ".embeddings.token_embedding.weight")
+ elif '.ln_final' in key:
+ key = key.replace(".ln_final", ".final_layer_norm")
+ return key
+
+ keys = list(checkpoint.keys())
+ new_sd = {}
+ for key in keys:
+ # remove resblocks 23
+ if '.resblocks.23.' in key:
+ continue
+ new_key = convert_key(key)
+ if new_key is None:
+ continue
+ new_sd[new_key] = checkpoint[key]
+
+ # convert the attention weights
+ for key in keys:
+ if '.resblocks.23.' in key:
+ continue
+ if '.resblocks' in key and '.attn.in_proj_' in key:
+ # split into three (q, k, v)
+ values = torch.chunk(checkpoint[key], 3)
+
+ key_suffix = ".weight" if "weight" in key else ".bias"
+ key_pfx = key.replace("cond_stage_model.model.transformer.resblocks.", "text_model.encoder.layers.")
+ key_pfx = key_pfx.replace("_weight", "")
+ key_pfx = key_pfx.replace("_bias", "")
+ key_pfx = key_pfx.replace(".attn.in_proj", ".self_attn.")
+ new_sd[key_pfx + "q_proj" + key_suffix] = values[0]
+ new_sd[key_pfx + "k_proj" + key_suffix] = values[1]
+ new_sd[key_pfx + "v_proj" + key_suffix] = values[2]
+
+ # add position_ids
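+ # (transformers' CLIPTextModel keeps position_ids as a buffer in its state_dict, which the SD checkpoint does not store)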
+ new_sd["text_model.embeddings.position_ids"] = torch.Tensor([list(range(max_length))]).to(torch.int64)
+ return new_sd
+
+# endregion
+
+
+# region Diffusers->StableDiffusion conversion code
+# copied and modified from convert_diffusers_to_original_stable_diffusion (ASL 2.0)
+
+def conv_transformer_to_linear(checkpoint):
+ keys = list(checkpoint.keys())
+ tf_keys = ["proj_in.weight", "proj_out.weight"]
+ for key in keys:
+ if ".".join(key.split(".")[-2:]) in tf_keys:
+ if checkpoint[key].ndim > 2:
+ checkpoint[key] = checkpoint[key][:, :, 0, 0]
+
+
+def convert_unet_state_dict_to_sd(v2, unet_state_dict):
+ unet_conversion_map = [
+ # (stable-diffusion, HF Diffusers)
+ ("time_embed.0.weight", "time_embedding.linear_1.weight"),
+ ("time_embed.0.bias", "time_embedding.linear_1.bias"),
+ ("time_embed.2.weight", "time_embedding.linear_2.weight"),
+ ("time_embed.2.bias", "time_embedding.linear_2.bias"),
+ ("input_blocks.0.0.weight", "conv_in.weight"),
+ ("input_blocks.0.0.bias", "conv_in.bias"),
+ ("out.0.weight", "conv_norm_out.weight"),
+ ("out.0.bias", "conv_norm_out.bias"),
+ ("out.2.weight", "conv_out.weight"),
+ ("out.2.bias", "conv_out.bias"),
+ ]
+
+ unet_conversion_map_resnet = [
+ # (stable-diffusion, HF Diffusers)
+ ("in_layers.0", "norm1"),
+ ("in_layers.2", "conv1"),
+ ("out_layers.0", "norm2"),
+ ("out_layers.3", "conv2"),
+ ("emb_layers.1", "time_emb_proj"),
+ ("skip_connection", "conv_shortcut"),
+ ]
+
+ unet_conversion_map_layer = []
+ for i in range(4):
+ # loop over downblocks/upblocks
+
+ for j in range(2):
+ # loop over resnets/attentions for downblocks
+ hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
+ sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
+ unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
+
+ if i < 3:
+ # no attention layers in down_blocks.3
+ hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
+ sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
+ unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
+
+ for j in range(3):
+ # loop over resnets/attentions for upblocks
+ hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
+ sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
+ unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
+
+ if i > 0:
+ # no attention layers in up_blocks.0
+ hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
+ sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
+ unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
+
+ if i < 3:
+ # no downsample in down_blocks.3
+ hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
+ sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
+ unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
+
+ # no upsample in up_blocks.3
+ hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
+ sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
+ unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
+
+ hf_mid_atn_prefix = "mid_block.attentions.0."
+ sd_mid_atn_prefix = "middle_block.1."
+ unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
+
+ for j in range(2):
+ hf_mid_res_prefix = f"mid_block.resnets.{j}."
+ sd_mid_res_prefix = f"middle_block.{2*j}."
+ unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
+
+ # buyer beware: this is a *brittle* function,
+ # and correct output requires that all of these pieces interact in
+ # the exact order in which I have arranged them.
+ mapping = {k: k for k in unet_state_dict.keys()}
+ for sd_name, hf_name in unet_conversion_map:
+ mapping[hf_name] = sd_name
+ for k, v in mapping.items():
+ if "resnets" in k:
+ for sd_part, hf_part in unet_conversion_map_resnet:
+ v = v.replace(hf_part, sd_part)
+ mapping[k] = v
+ for k, v in mapping.items():
+ for sd_part, hf_part in unet_conversion_map_layer:
+ v = v.replace(hf_part, sd_part)
+ mapping[k] = v
+ new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
+
+ if v2:
+ conv_transformer_to_linear(new_state_dict)
+
+ return new_state_dict
+
+
+# ================#
+# VAE Conversion #
+# ================#
+
+def reshape_weight_for_sd(w):
+ # convert HF linear weights to SD conv2d weights
+ return w.reshape(*w.shape, 1, 1)
+
+
+def convert_vae_state_dict(vae_state_dict):
+ vae_conversion_map = [
+ # (stable-diffusion, HF Diffusers)
+ ("nin_shortcut", "conv_shortcut"),
+ ("norm_out", "conv_norm_out"),
+ ("mid.attn_1.", "mid_block.attentions.0."),
+ ]
+
+ for i in range(4):
+ # down_blocks have two resnets
+ for j in range(2):
+ hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}."
+ sd_down_prefix = f"encoder.down.{i}.block.{j}."
+ vae_conversion_map.append((sd_down_prefix, hf_down_prefix))
+
+ if i < 3:
+ hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0."
+ sd_downsample_prefix = f"down.{i}.downsample."
+ vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))
+
+ hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
+ sd_upsample_prefix = f"up.{3-i}.upsample."
+ vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))
+
+ # up_blocks have three resnets
+ # also, up blocks in hf are numbered in reverse from sd
+ for j in range(3):
+ hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
+ sd_up_prefix = f"decoder.up.{3-i}.block.{j}."
+ vae_conversion_map.append((sd_up_prefix, hf_up_prefix))
+
+ # this part accounts for mid blocks in both the encoder and the decoder
+ for i in range(2):
+ hf_mid_res_prefix = f"mid_block.resnets.{i}."
+ sd_mid_res_prefix = f"mid.block_{i+1}."
+ vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
+
+ vae_conversion_map_attn = [
+ # (stable-diffusion, HF Diffusers)
+ ("norm.", "group_norm."),
+ ("q.", "query."),
+ ("k.", "key."),
+ ("v.", "value."),
+ ("proj_out.", "proj_attn."),
+ ]
+
+ mapping = {k: k for k in vae_state_dict.keys()}
+ for k, v in mapping.items():
+ for sd_part, hf_part in vae_conversion_map:
+ v = v.replace(hf_part, sd_part)
+ mapping[k] = v
+ for k, v in mapping.items():
+ if "attentions" in k:
+ for sd_part, hf_part in vae_conversion_map_attn:
+ v = v.replace(hf_part, sd_part)
+ mapping[k] = v
+ new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
+ weights_to_convert = ["q", "k", "v", "proj_out"]
+ for k, v in new_state_dict.items():
+ for weight_name in weights_to_convert:
+ if f"mid.attn_1.{weight_name}.weight" in k:
+ # print(f"Reshaping {k} for SD format")
+ new_state_dict[k] = reshape_weight_for_sd(v)
+
+ return new_state_dict
+
+
+# endregion
+
+# region custom model loading and saving
+
+def is_safetensors(path):
+ return os.path.splitext(path)[1].lower() == '.safetensors'
+
+
+def load_checkpoint_with_text_encoder_conversion(ckpt_path):
+ # support models whose text encoder keys are stored in a different format (missing 'text_model')
+ TEXT_ENCODER_KEY_REPLACEMENTS = [
+ ('cond_stage_model.transformer.embeddings.', 'cond_stage_model.transformer.text_model.embeddings.'),
+ ('cond_stage_model.transformer.encoder.', 'cond_stage_model.transformer.text_model.encoder.'),
+ ('cond_stage_model.transformer.final_layer_norm.', 'cond_stage_model.transformer.text_model.final_layer_norm.')
+ ]
+
+ if is_safetensors(ckpt_path):
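+  # safetensors files hold only tensors, so there is no surrounding checkpoint dict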
+ checkpoint = None
+ state_dict = load_file(ckpt_path, "cpu")
+ else:
+ checkpoint = torch.load(ckpt_path, map_location="cpu")
+ if "state_dict" in checkpoint:
+ state_dict = checkpoint["state_dict"]
+ else:
+ state_dict = checkpoint
+ checkpoint = None
+
+ key_reps = []
+ for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS:
+ for key in state_dict.keys():
+ if key.startswith(rep_from):
+ new_key = rep_to + key[len(rep_from):]
+ key_reps.append((key, new_key))
+
+ for key, new_key in key_reps:
+ state_dict[new_key] = state_dict[key]
+ del state_dict[key]
+
+ return checkpoint, state_dict
+
+
+# TODO the behavior of the dtype argument is questionable and should be verified; whether the text_encoder can be created with the specified dtype is also unconfirmed
+def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, dtype=None):
+ _, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path)
+ if dtype is not None:
+ for k, v in state_dict.items():
+ if type(v) is torch.Tensor:
+ state_dict[k] = v.to(dtype)
+
+ # Convert the UNet2DConditionModel model.
+ unet_config = create_unet_diffusers_config(v2)
+ converted_unet_checkpoint = convert_ldm_unet_checkpoint(v2, state_dict, unet_config)
+
+ unet = UNet2DConditionModel(**unet_config)
+ info = unet.load_state_dict(converted_unet_checkpoint)
+ print("loading u-net:", info)
+
+ # Convert the VAE model.
+ vae_config = create_vae_diffusers_config()
+ converted_vae_checkpoint = convert_ldm_vae_checkpoint(state_dict, vae_config)
+
+ vae = AutoencoderKL(**vae_config)
+ info = vae.load_state_dict(converted_vae_checkpoint)
+ print("loadint vae:", info)
+
+ # convert text_model
+ if v2:
+ converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v2(state_dict, 77)
+ cfg = CLIPTextConfig(
+ vocab_size=49408,
+ hidden_size=1024,
+ intermediate_size=4096,
+ num_hidden_layers=23,
+ num_attention_heads=16,
+ max_position_embeddings=77,
+ hidden_act="gelu",
+ layer_norm_eps=1e-05,
+ dropout=0.0,
+ attention_dropout=0.0,
+ initializer_range=0.02,
+ initializer_factor=1.0,
+ pad_token_id=1,
+ bos_token_id=0,
+ eos_token_id=2,
+ model_type="clip_text_model",
+ projection_dim=512,
+ torch_dtype="float32",
+ transformers_version="4.25.0.dev0",
+ )
+ text_model = CLIPTextModel._from_config(cfg)
+ info = text_model.load_state_dict(converted_text_encoder_checkpoint)
+ else:
+ converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v1(state_dict)
+ text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
+ info = text_model.load_state_dict(converted_text_encoder_checkpoint)
+ print("loading text encoder:", info)
+
+ return text_model, vae, unet
+
+
+def convert_text_encoder_state_dict_to_sd_v2(checkpoint, make_dummy_weights=False):
+ def convert_key(key):
+ # drop position_ids
+ if ".position_ids" in key:
+ return None
+
+ # common
+ key = key.replace("text_model.encoder.", "transformer.")
+ key = key.replace("text_model.", "")
+ if "layers" in key:
+ # resblocks conversion
+ key = key.replace(".layers.", ".resblocks.")
+ if ".layer_norm" in key:
+ key = key.replace(".layer_norm", ".ln_")
+ elif ".mlp." in key:
+ key = key.replace(".fc1.", ".c_fc.")
+ key = key.replace(".fc2.", ".c_proj.")
+ elif '.self_attn.out_proj' in key:
+ key = key.replace(".self_attn.out_proj.", ".attn.out_proj.")
+ elif '.self_attn.' in key:
+ key = None # special case, handled later
+ else:
+ raise ValueError(f"unexpected key in DiffUsers model: {key}")
+ elif '.position_embedding' in key:
+ key = key.replace("embeddings.position_embedding.weight", "positional_embedding")
+ elif '.token_embedding' in key:
+ key = key.replace("embeddings.token_embedding.weight", "token_embedding.weight")
+ elif 'final_layer_norm' in key:
+ key = key.replace("final_layer_norm", "ln_final")
+ return key
+
+ keys = list(checkpoint.keys())
+ new_sd = {}
+ for key in keys:
+ new_key = convert_key(key)
+ if new_key is None:
+ continue
+ new_sd[new_key] = checkpoint[key]
+
+ # convert the attention weights
+ for key in keys:
+ if 'layers' in key and 'q_proj' in key:
+ # concatenate the three (q, k, v)
+ key_q = key
+ key_k = key.replace("q_proj", "k_proj")
+ key_v = key.replace("q_proj", "v_proj")
+
+ value_q = checkpoint[key_q]
+ value_k = checkpoint[key_k]
+ value_v = checkpoint[key_v]
+ value = torch.cat([value_q, value_k, value_v])
+
+ new_key = key.replace("text_model.encoder.layers.", "transformer.resblocks.")
+ new_key = new_key.replace(".self_attn.q_proj.", ".attn.in_proj_")
+ new_sd[new_key] = value
+
+ # whether to fabricate dummy weights (resblock 23, text_projection, logit_scale)
+ if make_dummy_weights:
+ print("make dummy weights for resblock.23, text_projection and logit scale.")
+ keys = list(new_sd.keys())
+ for key in keys:
+ if key.startswith("transformer.resblocks.22."):
+ new_sd[key.replace(".22.", ".23.")] = new_sd[key]
+
+ # create the weights that are not included in the Diffusers model
+ new_sd['text_projection'] = torch.ones((1024, 1024), dtype=new_sd[keys[0]].dtype, device=new_sd[keys[0]].device)
+ new_sd['logit_scale'] = torch.tensor(1)
+
+ return new_sd
+
+
+def save_stable_diffusion_checkpoint(v2, output_file, text_encoder, unet, ckpt_path, epochs, steps, save_dtype=None, vae=None):
+ if ckpt_path is not None:
+ # reference epoch/step, and reload the checkpoint again including the VAE (e.g. when the VAE is not in memory)
+ checkpoint, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path)
+ if checkpoint is None: # safetensors or a state_dict-only ckpt
+ checkpoint = {}
+ strict = False
+ else:
+ strict = True
+ if "state_dict" in state_dict:
+ del state_dict["state_dict"]
+ else:
+ # create a new checkpoint from scratch
+ checkpoint = {}
+ state_dict = {}
+ strict = False
+
+ def update_sd(prefix, sd):
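+  # copy weights into state_dict under the given key prefix, optionally casting to save_dtype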
+ for k, v in sd.items():
+ key = prefix + k
+ assert not strict or key in state_dict, f"Illegal key in save SD: {key}"
+ if save_dtype is not None:
+ v = v.detach().clone().to("cpu").to(save_dtype)
+ state_dict[key] = v
+
+ # Convert the UNet model
+ unet_state_dict = convert_unet_state_dict_to_sd(v2, unet.state_dict())
+ update_sd("model.diffusion_model.", unet_state_dict)
+
+ # Convert the text encoder model
+ if v2:
+ make_dummy = ckpt_path is None # if there is no reference checkpoint, insert dummy weights (e.g. create the last layer by duplicating the previous one)
+ text_enc_dict = convert_text_encoder_state_dict_to_sd_v2(text_encoder.state_dict(), make_dummy)
+ update_sd("cond_stage_model.model.", text_enc_dict)
+ else:
+ text_enc_dict = text_encoder.state_dict()
+ update_sd("cond_stage_model.transformer.", text_enc_dict)
+
+ # Convert the VAE
+ if vae is not None:
+ vae_dict = convert_vae_state_dict(vae.state_dict())
+ update_sd("first_stage_model.", vae_dict)
+
+ # Put together new checkpoint
+ key_count = len(state_dict.keys())
+ new_ckpt = {'state_dict': state_dict}
+
+ if 'epoch' in checkpoint:
+ epochs += checkpoint['epoch']
+ if 'global_step' in checkpoint:
+ steps += checkpoint['global_step']
+
+ new_ckpt['epoch'] = epochs
+ new_ckpt['global_step'] = steps
+
+ if is_safetensors(output_file):
+ # TODO should values other than Tensors be removed from the dict?
+ save_file(state_dict, output_file)
+ else:
+ torch.save(new_ckpt, output_file)
+
+ return key_count
+
+
+def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, pretrained_model_name_or_path, vae=None):
+ if vae is None:
+ vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")
+ pipeline = StableDiffusionPipeline(
+ unet=unet,
+ text_encoder=text_encoder,
+ vae=vae,
+ scheduler=DDIMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler"),
+ tokenizer=CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer"),
+ safety_checker=None,
+ feature_extractor=None,
+ requires_safety_checker=None,
+ )
+ pipeline.save_pretrained(output_dir)
+
+
+VAE_PREFIX = "first_stage_model."
+
+
+def load_vae(vae_id, dtype):
+ print(f"load VAE: {vae_id}")
+ if os.path.isdir(vae_id) or not os.path.isfile(vae_id):
+ # Diffusers local/remote
+ try:
+ vae = AutoencoderKL.from_pretrained(vae_id, subfolder=None, torch_dtype=dtype)
+ except EnvironmentError as e:
+ print(f"exception occurs in loading vae: {e}")
+ print("retry with subfolder='vae'")
+ vae = AutoencoderKL.from_pretrained(vae_id, subfolder="vae", torch_dtype=dtype)
+ return vae
+
+ # local
+ vae_config = create_vae_diffusers_config()
+
+ if vae_id.endswith(".bin"):
+ # SD 1.5 VAE on Huggingface
+ vae_sd = torch.load(vae_id, map_location="cpu")
+ converted_vae_checkpoint = vae_sd
+ else:
+ # StableDiffusion
+ vae_model = torch.load(vae_id, map_location="cpu")
+ vae_sd = vae_model['state_dict']
+
+ # vae only or full model
+ full_model = False
+ for vae_key in vae_sd:
+ if vae_key.startswith(VAE_PREFIX):
+ full_model = True
+ break
+ if not full_model:
+ sd = {}
+ for key, value in vae_sd.items():
+ sd[VAE_PREFIX + key] = value
+ vae_sd = sd
+ del sd
+
+ # Convert the VAE model.
+ converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_sd, vae_config)
+
+ vae = AutoencoderKL(**vae_config)
+ vae.load_state_dict(converted_vae_checkpoint)
+ return vae
+
+
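`load_vae` accepts either a Diffusers directory / Hub id or a single checkpoint file; a short sketch of both paths (the ids and file paths below are placeholders):

```python
import torch
import model_util

# Diffusers format (local directory or Hub id); retried with subfolder="vae" if the first attempt fails
vae = model_util.load_vae("stabilityai/sd-vae-ft-mse", dtype=torch.float32)

# single-file VAE checkpoint: converted to the Diffusers layout via convert_ldm_vae_checkpoint
vae = model_util.load_vae("models/vae.ckpt", dtype=torch.float16)
```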
+def get_epoch_ckpt_name(use_safetensors, epoch):
+ return f"epoch-{epoch:06d}" + (".safetensors" if use_safetensors else ".ckpt")
+
+
+def get_last_ckpt_name(use_safetensors):
+ return f"last" + (".safetensors" if use_safetensors else ".ckpt")
+
+# endregion
+
+
+def make_bucket_resolutions(max_reso, min_size=256, max_size=1024, divisible=64):
+ max_width, max_height = max_reso
+ max_area = (max_width // divisible) * (max_height // divisible)
+
+ resos = set()
+
+ size = int(math.sqrt(max_area)) * divisible
+ resos.add((size, size))
+
+ size = min_size
+ while size <= max_size:
+ width = size
+ height = min(max_size, (max_area // (width // divisible)) * divisible)
+ resos.add((width, height))
+ resos.add((height, width))
+
+ # # make additional resos
+ # if width >= height and width - divisible >= min_size:
+ # resos.add((width - divisible, height))
+ # resos.add((height, width - divisible))
+ # if height >= width and height - divisible >= min_size:
+ # resos.add((width, height - divisible))
+ # resos.add((height - divisible, width))
+
+ size += divisible
+
+ resos = list(resos)
+ resos.sort()
+
+ aspect_ratios = [w / h for w, h in resos]
+ return resos, aspect_ratios
+
+
+if __name__ == '__main__':
+ resos, aspect_ratios = make_bucket_resolutions((512, 768))
+ print(len(resos))
+ print(resos)
+ print(aspect_ratios)
+
+ ars = set()
+ for ar in aspect_ratios:
+ if ar in ars:
+ print("error! duplicate ar:", ar)
+ ars.add(ar)
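The resolutions produced by `make_bucket_resolutions` are used for aspect-ratio bucketing in the training script below: each image is assigned to the bucket whose aspect ratio is closest to its own. A minimal sketch of that assignment (the image size here is just an example):

```python
import numpy as np
import model_util

resos, aspect_ratios = model_util.make_bucket_resolutions((512, 768))
aspect_ratios = np.array(aspect_ratios)

# hypothetical image size
image_width, image_height = 640, 480
ar = image_width / image_height

# nearest-aspect-ratio bucket, as in DreamBoothOrFineTuningDataset.make_buckets_with_caching
bucket_id = int(np.abs(aspect_ratios - ar).argmin())
print("bucket resolution:", resos[bucket_id], "ar error:", float(aspect_ratios[bucket_id] - ar))
```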
diff --git a/train_db_fixed/train_db_fixed_v13.py b/train_db_fixed/train_db_fixed.py
similarity index 57%
rename from train_db_fixed/train_db_fixed_v13.py
rename to train_db_fixed/train_db_fixed.py
index 36f2b8cb..c19b6dce 100644
--- a/train_db_fixed/train_db_fixed_v13.py
+++ b/train_db_fixed/train_db_fixed.py
@@ -11,7 +11,10 @@
# support save_every_n_epochs/save_state in DiffUsers model
# fix the issue that prior_loss_weight is applied to train images
# v12: stop training the text encoder, tqdm smoothing
+# v13: bug fix
+# v14: refactor to use model_util, add log prefix, support safetensors, support vae loading, keep the vae on CPU so the loaded vae can be saved with the checkpoint
+import gc
import time
from torch.autograd.function import Function
import argparse
@@ -26,9 +29,9 @@
from torchvision import transforms
from accelerate import Accelerator
from accelerate.utils import set_seed
-from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig
+from transformers import CLIPTokenizer
import diffusers
-from diffusers import AutoencoderKL, DDPMScheduler, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel
+from diffusers import DDPMScheduler, StableDiffusionPipeline
import albumentations as albu
import numpy as np
from PIL import Image
@@ -36,46 +39,24 @@
from einops import rearrange
from torch import einsum
+import model_util
+
# Tokenizer: use the pre-provided tokenizer instead of loading it from the checkpoint
TOKENIZER_PATH = "openai/clip-vit-large-patch14"
V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # only the tokenizer is used from this model
+# CLIP_ID_L14_336 = "openai/clip-vit-large-patch14-336"
+
# checkpoint file names
-LAST_CHECKPOINT_NAME = "last.ckpt"
-LAST_STATE_NAME = "last-state"
-LAST_DIFFUSERS_DIR_NAME = "last"
-EPOCH_CHECKPOINT_NAME = "epoch-{:06d}.ckpt"
EPOCH_STATE_NAME = "epoch-{:06d}-state"
+LAST_STATE_NAME = "last-state"
+
EPOCH_DIFFUSERS_DIR_NAME = "epoch-{:06d}"
+LAST_DIFFUSERS_DIR_NAME = "last"
# region dataset
-
-def make_bucket_resolutions(max_reso, min_size=256, max_size=1024, divisible=64):
- max_width, max_height = max_reso
- max_area = (max_width // divisible) * (max_height // divisible)
-
- resos = set()
-
- size = int(math.sqrt(max_area)) * divisible
- resos.add((size, size))
-
- size = min_size
- while size <= max_size:
- width = size
- height = min(max_size, (max_area // (width // divisible)) * divisible)
- resos.add((width, height))
- resos.add((height, width))
- size += divisible
-
- resos = list(resos)
- resos.sort()
-
- aspect_ratios = [w / h for w, h in resos]
- return resos, aspect_ratios
-
-
class DreamBoothOrFineTuningDataset(torch.utils.data.Dataset):
def __init__(self, batch_size, fine_tuning, train_img_path_captions, reg_img_path_captions, tokenizer, resolution, prior_loss_weight, flip_aug, color_aug, face_crop_aug_range, random_crop, shuffle_caption, disable_padding, debug_dataset) -> None:
super().__init__()
@@ -149,7 +130,7 @@ def make_buckets_with_caching(self, enable_bucket, vae, min_size, max_size):
# prepare bucketing
if enable_bucket:
- bucket_resos, bucket_aspect_ratios = make_bucket_resolutions((self.width, self.height), min_size, max_size)
+ bucket_resos, bucket_aspect_ratios = model_util.make_bucket_resolutions((self.width, self.height), min_size, max_size)
else:
# only one bucket; all images use the same resolution
bucket_resos = [(self.width, self.height)]
@@ -665,928 +646,16 @@ def forward_xformers(self, x, context=None, mask=None):
# endregion
-# region checkpoint変換、読み込み、書き込み ###############################
-
-# DiffUsers版StableDiffusionのモデルパラメータ
-NUM_TRAIN_TIMESTEPS = 1000
-BETA_START = 0.00085
-BETA_END = 0.0120
-
-UNET_PARAMS_MODEL_CHANNELS = 320
-UNET_PARAMS_CHANNEL_MULT = [1, 2, 4, 4]
-UNET_PARAMS_ATTENTION_RESOLUTIONS = [4, 2, 1]
-UNET_PARAMS_IMAGE_SIZE = 32 # unused
-UNET_PARAMS_IN_CHANNELS = 4
-UNET_PARAMS_OUT_CHANNELS = 4
-UNET_PARAMS_NUM_RES_BLOCKS = 2
-UNET_PARAMS_CONTEXT_DIM = 768
-UNET_PARAMS_NUM_HEADS = 8
-
-VAE_PARAMS_Z_CHANNELS = 4
-VAE_PARAMS_RESOLUTION = 256
-VAE_PARAMS_IN_CHANNELS = 3
-VAE_PARAMS_OUT_CH = 3
-VAE_PARAMS_CH = 128
-VAE_PARAMS_CH_MULT = [1, 2, 4, 4]
-VAE_PARAMS_NUM_RES_BLOCKS = 2
-
-# V2
-V2_UNET_PARAMS_ATTENTION_HEAD_DIM = [5, 10, 20, 20]
-V2_UNET_PARAMS_CONTEXT_DIM = 1024
-
-
-# region StableDiffusion->Diffusersの変換コード
-# convert_original_stable_diffusion_to_diffusers をコピーしている(ASL 2.0)
-
-
-def shave_segments(path, n_shave_prefix_segments=1):
- """
- Removes segments. Positive values shave the first segments, negative shave the last segments.
- """
- if n_shave_prefix_segments >= 0:
- return ".".join(path.split(".")[n_shave_prefix_segments:])
- else:
- return ".".join(path.split(".")[:n_shave_prefix_segments])
-
-
-def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
- """
- Updates paths inside resnets to the new naming scheme (local renaming)
- """
- mapping = []
- for old_item in old_list:
- new_item = old_item.replace("in_layers.0", "norm1")
- new_item = new_item.replace("in_layers.2", "conv1")
-
- new_item = new_item.replace("out_layers.0", "norm2")
- new_item = new_item.replace("out_layers.3", "conv2")
-
- new_item = new_item.replace("emb_layers.1", "time_emb_proj")
- new_item = new_item.replace("skip_connection", "conv_shortcut")
-
- new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
-
- mapping.append({"old": old_item, "new": new_item})
-
- return mapping
-
-
-def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
- """
- Updates paths inside resnets to the new naming scheme (local renaming)
- """
- mapping = []
- for old_item in old_list:
- new_item = old_item
-
- new_item = new_item.replace("nin_shortcut", "conv_shortcut")
- new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
-
- mapping.append({"old": old_item, "new": new_item})
-
- return mapping
-
-
-def renew_attention_paths(old_list, n_shave_prefix_segments=0):
- """
- Updates paths inside attentions to the new naming scheme (local renaming)
- """
- mapping = []
- for old_item in old_list:
- new_item = old_item
-
- # new_item = new_item.replace('norm.weight', 'group_norm.weight')
- # new_item = new_item.replace('norm.bias', 'group_norm.bias')
-
- # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
- # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
-
- # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
-
- mapping.append({"old": old_item, "new": new_item})
-
- return mapping
-
-
-def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
- """
- Updates paths inside attentions to the new naming scheme (local renaming)
- """
- mapping = []
- for old_item in old_list:
- new_item = old_item
-
- new_item = new_item.replace("norm.weight", "group_norm.weight")
- new_item = new_item.replace("norm.bias", "group_norm.bias")
-
- new_item = new_item.replace("q.weight", "query.weight")
- new_item = new_item.replace("q.bias", "query.bias")
-
- new_item = new_item.replace("k.weight", "key.weight")
- new_item = new_item.replace("k.bias", "key.bias")
-
- new_item = new_item.replace("v.weight", "value.weight")
- new_item = new_item.replace("v.bias", "value.bias")
-
- new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
- new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
-
- new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
-
- mapping.append({"old": old_item, "new": new_item})
-
- return mapping
-
-
-def assign_to_checkpoint(
- paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
-):
- """
- This does the final conversion step: take locally converted weights and apply a global renaming
- to them. It splits attention layers, and takes into account additional replacements
- that may arise.
-
- Assigns the weights to the new checkpoint.
- """
- assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
-
- # Splits the attention layers into three variables.
- if attention_paths_to_split is not None:
- for path, path_map in attention_paths_to_split.items():
- old_tensor = old_checkpoint[path]
- channels = old_tensor.shape[0] // 3
-
- target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
-
- num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
-
- old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
- query, key, value = old_tensor.split(channels // num_heads, dim=1)
-
- checkpoint[path_map["query"]] = query.reshape(target_shape)
- checkpoint[path_map["key"]] = key.reshape(target_shape)
- checkpoint[path_map["value"]] = value.reshape(target_shape)
-
- for path in paths:
- new_path = path["new"]
-
- # These have already been assigned
- if attention_paths_to_split is not None and new_path in attention_paths_to_split:
- continue
-
- # Global renaming happens here
- new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
- new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
- new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")
-
- if additional_replacements is not None:
- for replacement in additional_replacements:
- new_path = new_path.replace(replacement["old"], replacement["new"])
-
- # proj_attn.weight has to be converted from conv 1D to linear
- if "proj_attn.weight" in new_path:
- checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
- else:
- checkpoint[new_path] = old_checkpoint[path["old"]]
-
-
-def conv_attn_to_linear(checkpoint):
- keys = list(checkpoint.keys())
- attn_keys = ["query.weight", "key.weight", "value.weight"]
- for key in keys:
- if ".".join(key.split(".")[-2:]) in attn_keys:
- if checkpoint[key].ndim > 2:
- checkpoint[key] = checkpoint[key][:, :, 0, 0]
- elif "proj_attn.weight" in key:
- if checkpoint[key].ndim > 2:
- checkpoint[key] = checkpoint[key][:, :, 0]
-
-
-def linear_transformer_to_conv(checkpoint):
- keys = list(checkpoint.keys())
- tf_keys = ["proj_in.weight", "proj_out.weight"]
- for key in keys:
- if ".".join(key.split(".")[-2:]) in tf_keys:
- if checkpoint[key].ndim == 2:
- checkpoint[key] = checkpoint[key].unsqueeze(2).unsqueeze(2)
-
-
-def convert_ldm_unet_checkpoint(v2, checkpoint, config):
- """
- Takes a state dict and a config, and returns a converted checkpoint.
- """
-
- # extract state_dict for UNet
- unet_state_dict = {}
- unet_key = "model.diffusion_model."
- keys = list(checkpoint.keys())
- for key in keys:
- if key.startswith(unet_key):
- unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
-
- new_checkpoint = {}
-
- new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
- new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
- new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
- new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
-
- new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
- new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
-
- new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
- new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
- new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
- new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
-
- # Retrieves the keys for the input blocks only
- num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
- input_blocks = {
- layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
- for layer_id in range(num_input_blocks)
- }
-
- # Retrieves the keys for the middle blocks only
- num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
- middle_blocks = {
- layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
- for layer_id in range(num_middle_blocks)
- }
-
- # Retrieves the keys for the output blocks only
- num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
- output_blocks = {
- layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
- for layer_id in range(num_output_blocks)
- }
-
- for i in range(1, num_input_blocks):
- block_id = (i - 1) // (config["layers_per_block"] + 1)
- layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
-
- resnets = [
- key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
- ]
- attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
-
- if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
- new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
- f"input_blocks.{i}.0.op.weight"
- )
- new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
- f"input_blocks.{i}.0.op.bias"
- )
-
- paths = renew_resnet_paths(resnets)
- meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
- assign_to_checkpoint(
- paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
- )
-
- if len(attentions):
- paths = renew_attention_paths(attentions)
- meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
- assign_to_checkpoint(
- paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
- )
-
- resnet_0 = middle_blocks[0]
- attentions = middle_blocks[1]
- resnet_1 = middle_blocks[2]
-
- resnet_0_paths = renew_resnet_paths(resnet_0)
- assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)
-
- resnet_1_paths = renew_resnet_paths(resnet_1)
- assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)
-
- attentions_paths = renew_attention_paths(attentions)
- meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
- assign_to_checkpoint(
- attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
- )
-
- for i in range(num_output_blocks):
- block_id = i // (config["layers_per_block"] + 1)
- layer_in_block_id = i % (config["layers_per_block"] + 1)
- output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
- output_block_list = {}
-
- for layer in output_block_layers:
- layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
- if layer_id in output_block_list:
- output_block_list[layer_id].append(layer_name)
- else:
- output_block_list[layer_id] = [layer_name]
-
- if len(output_block_list) > 1:
- resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
- attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
-
- resnet_0_paths = renew_resnet_paths(resnets)
- paths = renew_resnet_paths(resnets)
-
- meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
- assign_to_checkpoint(
- paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
- )
-
- if ["conv.weight", "conv.bias"] in output_block_list.values():
- index = list(output_block_list.values()).index(["conv.weight", "conv.bias"])
- new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
- f"output_blocks.{i}.{index}.conv.weight"
- ]
- new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
- f"output_blocks.{i}.{index}.conv.bias"
- ]
-
- # Clear attentions as they have been attributed above.
- if len(attentions) == 2:
- attentions = []
-
- if len(attentions):
- paths = renew_attention_paths(attentions)
- meta_path = {
- "old": f"output_blocks.{i}.1",
- "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
- }
- assign_to_checkpoint(
- paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
- )
- else:
- resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
- for path in resnet_0_paths:
- old_path = ".".join(["output_blocks", str(i), path["old"]])
- new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
-
- new_checkpoint[new_path] = unet_state_dict[old_path]
-
- # SDのv2では1*1のconv2dがlinearに変わっているので、linear->convに変換する
- if v2:
- linear_transformer_to_conv(new_checkpoint)
-
- return new_checkpoint
-
-
-def convert_ldm_vae_checkpoint(checkpoint, config):
- # extract state dict for VAE
- vae_state_dict = {}
- vae_key = "first_stage_model."
- keys = list(checkpoint.keys())
- for key in keys:
- if key.startswith(vae_key):
- vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
-
- new_checkpoint = {}
-
- new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
- new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
- new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
- new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
- new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
- new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]
-
- new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
- new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
- new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
- new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
- new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
- new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]
-
- new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
- new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
- new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
- new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
-
- # Retrieves the keys for the encoder down blocks only
- num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
- down_blocks = {
- layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
- }
-
- # Retrieves the keys for the decoder up blocks only
- num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
- up_blocks = {
- layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
- }
-
- for i in range(num_down_blocks):
- resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
-
- if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
- new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
- f"encoder.down.{i}.downsample.conv.weight"
- )
- new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
- f"encoder.down.{i}.downsample.conv.bias"
- )
-
- paths = renew_vae_resnet_paths(resnets)
- meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
- assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
-
- mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
- num_mid_res_blocks = 2
- for i in range(1, num_mid_res_blocks + 1):
- resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
-
- paths = renew_vae_resnet_paths(resnets)
- meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
- assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
-
- mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
- paths = renew_vae_attention_paths(mid_attentions)
- meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
- assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
- conv_attn_to_linear(new_checkpoint)
-
- for i in range(num_up_blocks):
- block_id = num_up_blocks - 1 - i
- resnets = [
- key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
- ]
-
- if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
- new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
- f"decoder.up.{block_id}.upsample.conv.weight"
- ]
- new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
- f"decoder.up.{block_id}.upsample.conv.bias"
- ]
-
- paths = renew_vae_resnet_paths(resnets)
- meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
- assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
-
- mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
- num_mid_res_blocks = 2
- for i in range(1, num_mid_res_blocks + 1):
- resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
-
- paths = renew_vae_resnet_paths(resnets)
- meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
- assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
-
- mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
- paths = renew_vae_attention_paths(mid_attentions)
- meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
- assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
- conv_attn_to_linear(new_checkpoint)
- return new_checkpoint
-
-
-def create_unet_diffusers_config(v2):
- """
- Creates a config for the diffusers based on the config of the LDM model.
- """
- # unet_params = original_config.model.params.unet_config.params
-
- block_out_channels = [UNET_PARAMS_MODEL_CHANNELS * mult for mult in UNET_PARAMS_CHANNEL_MULT]
-
- down_block_types = []
- resolution = 1
- for i in range(len(block_out_channels)):
- block_type = "CrossAttnDownBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "DownBlock2D"
- down_block_types.append(block_type)
- if i != len(block_out_channels) - 1:
- resolution *= 2
-
- up_block_types = []
- for i in range(len(block_out_channels)):
- block_type = "CrossAttnUpBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "UpBlock2D"
- up_block_types.append(block_type)
- resolution //= 2
-
- config = dict(
- sample_size=UNET_PARAMS_IMAGE_SIZE,
- in_channels=UNET_PARAMS_IN_CHANNELS,
- out_channels=UNET_PARAMS_OUT_CHANNELS,
- down_block_types=tuple(down_block_types),
- up_block_types=tuple(up_block_types),
- block_out_channels=tuple(block_out_channels),
- layers_per_block=UNET_PARAMS_NUM_RES_BLOCKS,
- cross_attention_dim=UNET_PARAMS_CONTEXT_DIM if not v2 else V2_UNET_PARAMS_CONTEXT_DIM,
- attention_head_dim=UNET_PARAMS_NUM_HEADS if not v2 else V2_UNET_PARAMS_ATTENTION_HEAD_DIM,
- )
-
- return config
-
-
-def create_vae_diffusers_config():
- """
- Creates a config for the diffusers based on the config of the LDM model.
- """
- # vae_params = original_config.model.params.first_stage_config.params.ddconfig
- # _ = original_config.model.params.first_stage_config.params.embed_dim
- block_out_channels = [VAE_PARAMS_CH * mult for mult in VAE_PARAMS_CH_MULT]
- down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
- up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
-
- config = dict(
- sample_size=VAE_PARAMS_RESOLUTION,
- in_channels=VAE_PARAMS_IN_CHANNELS,
- out_channels=VAE_PARAMS_OUT_CH,
- down_block_types=tuple(down_block_types),
- up_block_types=tuple(up_block_types),
- block_out_channels=tuple(block_out_channels),
- latent_channels=VAE_PARAMS_Z_CHANNELS,
- layers_per_block=VAE_PARAMS_NUM_RES_BLOCKS,
- )
- return config
-
-
-def convert_ldm_clip_checkpoint_v1(checkpoint):
- keys = list(checkpoint.keys())
- text_model_dict = {}
- for key in keys:
- if key.startswith("cond_stage_model.transformer"):
- text_model_dict[key[len("cond_stage_model.transformer."):]] = checkpoint[key]
- return text_model_dict
-
-
-def convert_ldm_clip_checkpoint_v2(checkpoint, max_length):
- # 嫌になるくらい違うぞ!
- def convert_key(key):
- if not key.startswith("cond_stage_model"):
- return None
-
- # common conversion
- key = key.replace("cond_stage_model.model.transformer.", "text_model.encoder.")
- key = key.replace("cond_stage_model.model.", "text_model.")
-
- if "resblocks" in key:
- # resblocks conversion
- key = key.replace(".resblocks.", ".layers.")
- if ".ln_" in key:
- key = key.replace(".ln_", ".layer_norm")
- elif ".mlp." in key:
- key = key.replace(".c_fc.", ".fc1.")
- key = key.replace(".c_proj.", ".fc2.")
- elif '.attn.out_proj' in key:
- key = key.replace(".attn.out_proj.", ".self_attn.out_proj.")
- elif '.attn.in_proj' in key:
- key = None # 特殊なので後で処理する
- else:
- raise ValueError(f"unexpected key in SD: {key}")
- elif '.positional_embedding' in key:
- key = key.replace(".positional_embedding", ".embeddings.position_embedding.weight")
- elif '.text_projection' in key:
- key = None # 使われない???
- elif '.logit_scale' in key:
- key = None # 使われない???
- elif '.token_embedding' in key:
- key = key.replace(".token_embedding.weight", ".embeddings.token_embedding.weight")
- elif '.ln_final' in key:
- key = key.replace(".ln_final", ".final_layer_norm")
- return key
-
- keys = list(checkpoint.keys())
- new_sd = {}
- for key in keys:
- # remove resblocks 23
- if '.resblocks.23.' in key:
- continue
- new_key = convert_key(key)
- if new_key is None:
- continue
- new_sd[new_key] = checkpoint[key]
-
- # attnの変換
- for key in keys:
- if '.resblocks.23.' in key:
- continue
- if '.resblocks' in key and '.attn.in_proj_' in key:
- # 三つに分割
- values = torch.chunk(checkpoint[key], 3)
-
- key_suffix = ".weight" if "weight" in key else ".bias"
- key_pfx = key.replace("cond_stage_model.model.transformer.resblocks.", "text_model.encoder.layers.")
- key_pfx = key_pfx.replace("_weight", "")
- key_pfx = key_pfx.replace("_bias", "")
- key_pfx = key_pfx.replace(".attn.in_proj", ".self_attn.")
- new_sd[key_pfx + "q_proj" + key_suffix] = values[0]
- new_sd[key_pfx + "k_proj" + key_suffix] = values[1]
- new_sd[key_pfx + "v_proj" + key_suffix] = values[2]
-
- # position_idsの追加
- new_sd["text_model.embeddings.position_ids"] = torch.Tensor([list(range(max_length))]).to(torch.int64)
- return new_sd
-
-# endregion
-
-
-# region Diffusers->StableDiffusion の変換コード
-# convert_diffusers_to_original_stable_diffusion をコピーしている(ASL 2.0)
-
-def conv_transformer_to_linear(checkpoint):
- keys = list(checkpoint.keys())
- tf_keys = ["proj_in.weight", "proj_out.weight"]
- for key in keys:
- if ".".join(key.split(".")[-2:]) in tf_keys:
- if checkpoint[key].ndim > 2:
- checkpoint[key] = checkpoint[key][:, :, 0, 0]
-
-
-def convert_unet_state_dict_to_sd(v2, unet_state_dict):
- unet_conversion_map = [
- # (stable-diffusion, HF Diffusers)
- ("time_embed.0.weight", "time_embedding.linear_1.weight"),
- ("time_embed.0.bias", "time_embedding.linear_1.bias"),
- ("time_embed.2.weight", "time_embedding.linear_2.weight"),
- ("time_embed.2.bias", "time_embedding.linear_2.bias"),
- ("input_blocks.0.0.weight", "conv_in.weight"),
- ("input_blocks.0.0.bias", "conv_in.bias"),
- ("out.0.weight", "conv_norm_out.weight"),
- ("out.0.bias", "conv_norm_out.bias"),
- ("out.2.weight", "conv_out.weight"),
- ("out.2.bias", "conv_out.bias"),
- ]
-
- unet_conversion_map_resnet = [
- # (stable-diffusion, HF Diffusers)
- ("in_layers.0", "norm1"),
- ("in_layers.2", "conv1"),
- ("out_layers.0", "norm2"),
- ("out_layers.3", "conv2"),
- ("emb_layers.1", "time_emb_proj"),
- ("skip_connection", "conv_shortcut"),
- ]
-
- unet_conversion_map_layer = []
- for i in range(4):
- # loop over downblocks/upblocks
-
- for j in range(2):
- # loop over resnets/attentions for downblocks
- hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
- sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
- unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
-
- if i < 3:
- # no attention layers in down_blocks.3
- hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
- sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
- unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
-
- for j in range(3):
- # loop over resnets/attentions for upblocks
- hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
- sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
- unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
-
- if i > 0:
- # no attention layers in up_blocks.0
- hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
- sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
- unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
-
- if i < 3:
- # no downsample in down_blocks.3
- hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
- sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
- unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
-
- # no upsample in up_blocks.3
- hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
- sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
- unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
-
- hf_mid_atn_prefix = "mid_block.attentions.0."
- sd_mid_atn_prefix = "middle_block.1."
- unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
-
- for j in range(2):
- hf_mid_res_prefix = f"mid_block.resnets.{j}."
- sd_mid_res_prefix = f"middle_block.{2*j}."
- unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
-
- # buyer beware: this is a *brittle* function,
- # and correct output requires that all of these pieces interact in
- # the exact order in which I have arranged them.
- mapping = {k: k for k in unet_state_dict.keys()}
- for sd_name, hf_name in unet_conversion_map:
- mapping[hf_name] = sd_name
- for k, v in mapping.items():
- if "resnets" in k:
- for sd_part, hf_part in unet_conversion_map_resnet:
- v = v.replace(hf_part, sd_part)
- mapping[k] = v
- for k, v in mapping.items():
- for sd_part, hf_part in unet_conversion_map_layer:
- v = v.replace(hf_part, sd_part)
- mapping[k] = v
- new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
-
- if v2:
- conv_transformer_to_linear(new_state_dict)
-
- return new_state_dict
-
-# endregion
-
-
-def load_checkpoint_with_text_encoder_conversion(ckpt_path):
- # text encoderの格納形式が違うモデルに対応する ('text_model'がない)
- TEXT_ENCODER_KEY_REPLACEMENTS = [
- ('cond_stage_model.transformer.embeddings.', 'cond_stage_model.transformer.text_model.embeddings.'),
- ('cond_stage_model.transformer.encoder.', 'cond_stage_model.transformer.text_model.encoder.'),
- ('cond_stage_model.transformer.final_layer_norm.', 'cond_stage_model.transformer.text_model.final_layer_norm.')
- ]
-
- checkpoint = torch.load(ckpt_path, map_location="cpu")
- state_dict = checkpoint["state_dict"]
-
- key_reps = []
- for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS:
- for key in state_dict.keys():
- if key.startswith(rep_from):
- new_key = rep_to + key[len(rep_from):]
- key_reps.append((key, new_key))
-
- for key, new_key in key_reps:
- state_dict[new_key] = state_dict[key]
- del state_dict[key]
-
- return checkpoint
-
-
-def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, dtype=None):
- checkpoint = load_checkpoint_with_text_encoder_conversion(ckpt_path)
- state_dict = checkpoint["state_dict"]
- if dtype is not None:
- for k, v in state_dict.items():
- if type(v) is torch.Tensor:
- state_dict[k] = v.to(dtype)
-
- # Convert the UNet2DConditionModel model.
- unet_config = create_unet_diffusers_config(v2)
- converted_unet_checkpoint = convert_ldm_unet_checkpoint(v2, state_dict, unet_config)
-
- unet = UNet2DConditionModel(**unet_config)
- info = unet.load_state_dict(converted_unet_checkpoint)
- print("loading u-net:", info)
-
- # Convert the VAE model.
- vae_config = create_vae_diffusers_config()
- converted_vae_checkpoint = convert_ldm_vae_checkpoint(state_dict, vae_config)
-
- vae = AutoencoderKL(**vae_config)
- info = vae.load_state_dict(converted_vae_checkpoint)
- print("loadint vae:", info)
-
- # convert text_model
- if v2:
- converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v2(state_dict, 77)
- cfg = CLIPTextConfig(
- vocab_size=49408,
- hidden_size=1024,
- intermediate_size=4096,
- num_hidden_layers=23,
- num_attention_heads=16,
- max_position_embeddings=77,
- hidden_act="gelu",
- layer_norm_eps=1e-05,
- dropout=0.0,
- attention_dropout=0.0,
- initializer_range=0.02,
- initializer_factor=1.0,
- pad_token_id=1,
- bos_token_id=0,
- eos_token_id=2,
- model_type="clip_text_model",
- projection_dim=512,
- torch_dtype="float32",
- transformers_version="4.25.0.dev0",
- )
- text_model = CLIPTextModel._from_config(cfg)
- info = text_model.load_state_dict(converted_text_encoder_checkpoint)
- else:
- converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v1(state_dict)
- text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
- info = text_model.load_state_dict(converted_text_encoder_checkpoint)
- print("loading text encoder:", info)
-
- return text_model, vae, unet
-
-
-def convert_text_encoder_state_dict_to_sd_v2(checkpoint):
- def convert_key(key):
- # position_idsの除去
- if ".position_ids" in key:
- return None
-
- # common
- key = key.replace("text_model.encoder.", "transformer.")
- key = key.replace("text_model.", "")
- if "layers" in key:
- # resblocks conversion
- key = key.replace(".layers.", ".resblocks.")
- if ".layer_norm" in key:
- key = key.replace(".layer_norm", ".ln_")
- elif ".mlp." in key:
- key = key.replace(".fc1.", ".c_fc.")
- key = key.replace(".fc2.", ".c_proj.")
- elif '.self_attn.out_proj' in key:
- key = key.replace(".self_attn.out_proj.", ".attn.out_proj.")
- elif '.self_attn.' in key:
- key = None # 特殊なので後で処理する
- else:
- raise ValueError(f"unexpected key in DiffUsers model: {key}")
- elif '.position_embedding' in key:
- key = key.replace("embeddings.position_embedding.weight", "positional_embedding")
- elif '.token_embedding' in key:
- key = key.replace("embeddings.token_embedding.weight", "token_embedding.weight")
- elif 'final_layer_norm' in key:
- key = key.replace("final_layer_norm", "ln_final")
- return key
-
- keys = list(checkpoint.keys())
- new_sd = {}
- for key in keys:
- new_key = convert_key(key)
- if new_key is None:
- continue
- new_sd[new_key] = checkpoint[key]
-
- # attnの変換
- for key in keys:
- if 'layers' in key and 'q_proj' in key:
- # 三つを結合
- key_q = key
- key_k = key.replace("q_proj", "k_proj")
- key_v = key.replace("q_proj", "v_proj")
-
- value_q = checkpoint[key_q]
- value_k = checkpoint[key_k]
- value_v = checkpoint[key_v]
- value = torch.cat([value_q, value_k, value_v])
-
- new_key = key.replace("text_model.encoder.layers.", "transformer.resblocks.")
- new_key = new_key.replace(".self_attn.q_proj.", ".attn.in_proj_")
- new_sd[new_key] = value
-
- return new_sd
-
-
-def save_stable_diffusion_checkpoint(v2, output_file, text_encoder, unet, ckpt_path, epochs, steps, save_dtype=None):
- # VAEがメモリ上にないので、もう一度VAEを含めて読み込む
- checkpoint = load_checkpoint_with_text_encoder_conversion(ckpt_path)
- state_dict = checkpoint["state_dict"]
-
- def assign_new_sd(prefix, sd):
- for k, v in sd.items():
- key = prefix + k
- assert key in state_dict, f"Illegal key in save SD: {key}"
- if save_dtype is not None:
- v = v.detach().clone().to("cpu").to(save_dtype)
- state_dict[key] = v
-
- # Convert the UNet model
- unet_state_dict = convert_unet_state_dict_to_sd(v2, unet.state_dict())
- assign_new_sd("model.diffusion_model.", unet_state_dict)
-
- # Convert the text encoder model
- if v2:
- text_enc_dict = convert_text_encoder_state_dict_to_sd_v2(text_encoder.state_dict())
- assign_new_sd("cond_stage_model.model.", text_enc_dict)
- else:
- text_enc_dict = text_encoder.state_dict()
- assign_new_sd("cond_stage_model.transformer.", text_enc_dict)
-
- # Put together new checkpoint
- new_ckpt = {'state_dict': state_dict}
-
- if 'epoch' in checkpoint:
- epochs += checkpoint['epoch']
- if 'global_step' in checkpoint:
- steps += checkpoint['global_step']
-
- new_ckpt['epoch'] = epochs
- new_ckpt['global_step'] = steps
-
- torch.save(new_ckpt, output_file)
-
-
-def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, pretrained_model_name_or_path, save_dtype):
- pipeline = StableDiffusionPipeline(
- unet=unet,
- text_encoder=text_encoder,
- vae=AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae"),
- scheduler=DDIMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler"),
- tokenizer=CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer"),
- safety_checker=None,
- feature_extractor=None,
- requires_safety_checker=None,
- )
- pipeline.save_pretrained(output_dir)
-
-# endregion
-
-
def collate_fn(examples):
return examples[0]
+# def load_clip_l14_336(dtype):
+# print(f"loading CLIP: {CLIP_ID_L14_336}")
+# text_encoder = CLIPTextModel.from_pretrained(CLIP_ID_L14_336, torch_dtype=dtype)
+# return text_encoder
+
+
def train(args):
if args.caption_extention is not None:
args.caption_extension = args.caption_extention
@@ -1747,7 +816,8 @@ def load_dreambooth_dir(dir):
logging_dir = None
else:
log_with = "tensorboard"
- logging_dir = args.logging_dir + "/" + time.strftime('%Y%m%d%H%M%S', time.localtime())
+ log_prefix = "" if args.log_prefix is None else args.log_prefix
+ logging_dir = args.logging_dir + "/" + log_prefix + time.strftime('%Y%m%d%H%M%S', time.localtime())
accelerator = Accelerator(gradient_accumulation_steps=1, mixed_precision=args.mixed_precision,
log_with=log_with, logging_dir=logging_dir)
@@ -1769,7 +839,7 @@ def load_dreambooth_dir(dir):
# load the models
if use_stable_diffusion_format:
print("load StableDiffusion checkpoint")
- text_encoder, vae, unet = load_models_from_stable_diffusion_checkpoint(args.v2, args.pretrained_model_name_or_path)
+ text_encoder, vae, unet = model_util.load_models_from_stable_diffusion_checkpoint(args.v2, args.pretrained_model_name_or_path)
else:
print("load Diffusers pretrained models")
pipe = StableDiffusionPipeline.from_pretrained(args.pretrained_model_name_or_path, tokenizer=None, safety_checker=None)
@@ -1779,6 +849,16 @@ def load_dreambooth_dir(dir):
unet = pipe.unet
del pipe
+ # # load the replacement CLIP
+ # if args.replace_clip_l14_336:
+ # text_encoder = load_clip_l14_336(weight_dtype)
+ # print(f"large clip {CLIP_ID_L14_336} is loaded")
+
+ # load the VAE
+ if args.vae is not None:
+ vae = model_util.load_vae(args.vae, weight_dtype)
+ print("additional VAE loaded")
+
# hook xformers or memory-efficient attention into the model
replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
@@ -1789,9 +869,10 @@ def load_dreambooth_dir(dir):
vae.eval()
with torch.no_grad():
train_dataset.make_buckets_with_caching(args.enable_bucket, vae, args.min_bucket_reso, args.max_bucket_reso)
- del vae
+ vae.to("cpu")
if torch.cuda.is_available():
torch.cuda.empty_cache()
+ gc.collect()
else:
train_dataset.make_buckets_with_caching(args.enable_bucket, None, args.min_bucket_reso, args.max_bucket_reso)
vae.requires_grad_(False)
@@ -1878,7 +959,7 @@ def load_dreambooth_dir(dir):
print(f"epoch {epoch+1}/{num_train_epochs}")
# train the Text Encoder only up to the specified number of steps: state at the start of the epoch
- train_text_encoder = args.stop_text_encoder_training is None or global_step < args.stop_text_encoder_training
+ train_text_encoder = args.stop_text_encoder_training is None or global_step < args.stop_text_encoder_training
unet.train()
if train_text_encoder:
text_encoder.train()
@@ -1886,7 +967,7 @@ def load_dreambooth_dir(dir):
loss_total = 0
for step, batch in enumerate(train_dataloader):
# stop Text Encoder training at the specified step
- stop_text_encoder_training = args.stop_text_encoder_training is not None and global_step == args.stop_text_encoder_training
+ stop_text_encoder_training = args.stop_text_encoder_training is not None and global_step == args.stop_text_encoder_training
if stop_text_encoder_training:
print(f"stop text encoder training at step {global_step}")
text_encoder.train(False)
@@ -1999,14 +1080,14 @@ def load_dreambooth_dir(dir):
print("saving checkpoint.")
if use_stable_diffusion_format:
os.makedirs(args.output_dir, exist_ok=True)
- ckpt_file = os.path.join(args.output_dir, EPOCH_CHECKPOINT_NAME.format(epoch + 1))
- save_stable_diffusion_checkpoint(args.v2, ckpt_file, accelerator.unwrap_model(text_encoder), accelerator.unwrap_model(unet),
- args.pretrained_model_name_or_path, epoch + 1, global_step, save_dtype)
+ ckpt_file = os.path.join(args.output_dir, model_util.get_epoch_ckpt_name(args.use_safetensors, epoch + 1))
+ model_util.save_stable_diffusion_checkpoint(args.v2, ckpt_file, accelerator.unwrap_model(text_encoder), accelerator.unwrap_model(unet),
+ args.pretrained_model_name_or_path, epoch + 1, global_step, save_dtype, vae)
else:
out_dir = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(epoch + 1))
os.makedirs(out_dir, exist_ok=True)
- save_diffusers_checkpoint(args.v2, out_dir, accelerator.unwrap_model(text_encoder),
- accelerator.unwrap_model(unet), args.pretrained_model_name_or_path, save_dtype)
+ model_util.save_diffusers_checkpoint(args.v2, out_dir, accelerator.unwrap_model(text_encoder),
+ accelerator.unwrap_model(unet), args.pretrained_model_name_or_path)
if args.save_state:
print("saving state.")
@@ -2028,16 +1109,16 @@ def load_dreambooth_dir(dir):
if is_main_process:
os.makedirs(args.output_dir, exist_ok=True)
if use_stable_diffusion_format:
- ckpt_file = os.path.join(args.output_dir, LAST_CHECKPOINT_NAME)
+ ckpt_file = os.path.join(args.output_dir, model_util.get_last_ckpt_name(args.use_safetensors))
print(f"save trained model as StableDiffusion checkpoint to {ckpt_file}")
- save_stable_diffusion_checkpoint(args.v2, ckpt_file, text_encoder, unet,
- args.pretrained_model_name_or_path, epoch, global_step, save_dtype)
+ model_util.save_stable_diffusion_checkpoint(args.v2, ckpt_file, text_encoder, unet,
+ args.pretrained_model_name_or_path, epoch, global_step, save_dtype, vae)
else:
# Create the pipeline using the trained modules and save it.
print(f"save trained model as Diffusers to {args.output_dir}")
out_dir = os.path.join(args.output_dir, LAST_DIFFUSERS_DIR_NAME)
os.makedirs(out_dir, exist_ok=True)
- save_diffusers_checkpoint(args.v2, out_dir, text_encoder, unet, args.pretrained_model_name_or_path, save_dtype)
+ model_util.save_diffusers_checkpoint(args.v2, out_dir, text_encoder, unet, args.pretrained_model_name_or_path)
print("model saved.")
@@ -2050,6 +1131,8 @@ def load_dreambooth_dir(dir):
help='enable v-parameterization training / v-parameterization学習を有効にする')
parser.add_argument("--pretrained_model_name_or_path", type=str, default=None,
help="pretrained model to train, directory to Diffusers model or StableDiffusion checkpoint / 学習元モデル、Diffusers形式モデルのディレクトリまたはStableDiffusionのckptファイル")
+ # parser.add_argument("--replace_clip_l14_336", action='store_true',
+ # help="Replace CLIP (Text Encoder) to l/14@336 / CLIP(Text Encoder)をl/14@336に入れ替える")
parser.add_argument("--fine_tuning", action="store_true",
help="fine tune the model instead of DreamBooth / DreamBoothではなくfine tuningする")
parser.add_argument("--shuffle_caption", action="store_true",
@@ -2062,9 +1145,9 @@ def load_dreambooth_dir(dir):
parser.add_argument("--dataset_repeats", type=int, default=None,
help="repeat dataset in fine tuning / fine tuning時にデータセットを繰り返す回数")
parser.add_argument("--output_dir", type=str, default=None,
- help="directory to output trained model (default format is same to input) / 学習後のモデル出力先ディレクトリ(デフォルトの保存形式は読み込んだ形式と同じ)")
- # parser.add_argument("--save_as_sd", action='store_true',
- # help="save the model as StableDiffusion checkpoint / 学習後のモデルをStableDiffusionのcheckpointとして保存する")
+ help="directory to output trained model / 学習後のモデル出力先ディレクトリ")
+ parser.add_argument("--use_safetensors", action='store_true',
+ help="use safetensors format for StableDiffusion checkpoint / StableDiffusionのcheckpointをsafetensors形式で保存する")
parser.add_argument("--save_every_n_epochs", type=int, default=None,
help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存します")
parser.add_argument("--save_state", action="store_true",
@@ -2073,7 +1156,8 @@ def load_dreambooth_dir(dir):
parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="loss weight for regularization images / 正則化画像のlossの重み")
parser.add_argument("--no_token_padding", action="store_true",
help="disable token padding (same as Diffuser's DreamBooth) / トークンのpaddingを無効にする(Diffusers版DreamBoothと同じ動作)")
- parser.add_argument("--stop_text_encoder_training", type=int, default=None, help="steps to stop text encoder training / Text Encoderの学習を止めるステップ数")
+ parser.add_argument("--stop_text_encoder_training", type=int, default=None,
+ help="steps to stop text encoder training / Text Encoderの学習を止めるステップ数")
parser.add_argument("--color_aug", action="store_true", help="enable weak color augmentation / 学習時に色合いのaugmentationを有効にする")
parser.add_argument("--flip_aug", action="store_true", help="enable horizontal flip augmentation / 学習時に左右反転のaugmentationを有効にする")
parser.add_argument("--face_crop_aug_range", type=str, default=None,
@@ -2092,6 +1176,8 @@ def load_dreambooth_dir(dir):
help="use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う")
parser.add_argument("--xformers", action="store_true",
help="use xformers for CrossAttention / CrossAttentionにxformersを使う")
+ parser.add_argument("--vae", type=str, default=None,
+ help="path to checkpoint of vae to replace / VAEを入れ替える場合、VAEのcheckpointファイルまたはディレクトリ")
parser.add_argument("--cache_latents", action="store_true",
help="cache latents to reduce memory (augmentations must be disabled) / メモリ削減のためにlatentをcacheする(augmentationは使用不可)")
parser.add_argument("--enable_bucket", action="store_true",
@@ -2111,6 +1197,7 @@ def load_dreambooth_dir(dir):
help="use output of nth layer from back of text encoder (n>=1) / text encoderの後ろからn番目の層の出力を用いる(nは1以上)")
parser.add_argument("--logging_dir", type=str, default=None,
help="enable logging and output TensorBoard log to this directory / ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する")
+ parser.add_argument("--log_prefix", type=str, default=None, help="add prefix for each log directory / ログディレクトリ名の先頭に追加する文字列")
parser.add_argument("--lr_scheduler", type=str, default="constant",
help="scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts, polynomial, constant (default), constant_with_warmup")
parser.add_argument("--lr_warmup_steps", type=int, default=0,
diff --git a/train_db_fixed/train_db_fixed_v10.py b/train_db_fixed/train_db_fixed_v10.py
deleted file mode 100644
index 2037e786..00000000
--- a/train_db_fixed/train_db_fixed_v10.py
+++ /dev/null
@@ -1,1830 +0,0 @@
-# このスクリプトのライセンスは、train_dreambooth.pyと同じくApache License 2.0とします
-# (c) 2022 Kohya S. @kohya_ss
-
-# v7: another text encoder ckpt format, average loss, save epochs/global steps, show num of train/reg images,
-# enable reg images in fine-tuning, add dataset_repeats option
-# v8: supports Diffusers 0.7.2
-# v9: add bucketing option
-# v10: add min_bucket_reso/max_bucket_reso options, read captions for train/reg images in DreamBooth
-
-import time
-from torch.autograd.function import Function
-import argparse
-import glob
-import itertools
-import math
-import os
-import random
-
-from tqdm import tqdm
-import torch
-from torchvision import transforms
-from accelerate import Accelerator
-from accelerate.utils import set_seed
-from transformers import CLIPTextModel, CLIPTokenizer
-import diffusers
-from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
-import albumentations as albu
-import numpy as np
-from PIL import Image
-import cv2
-from einops import rearrange
-from torch import einsum
-
-# Tokenizer: checkpointから読み込むのではなくあらかじめ提供されているものを使う
-TOKENIZER_PATH = "openai/clip-vit-large-patch14"
-
-# StableDiffusionのモデルパラメータ
-NUM_TRAIN_TIMESTEPS = 1000
-BETA_START = 0.00085
-BETA_END = 0.0120
-
-UNET_PARAMS_MODEL_CHANNELS = 320
-UNET_PARAMS_CHANNEL_MULT = [1, 2, 4, 4]
-UNET_PARAMS_ATTENTION_RESOLUTIONS = [4, 2, 1]
-UNET_PARAMS_IMAGE_SIZE = 32 # unused
-UNET_PARAMS_IN_CHANNELS = 4
-UNET_PARAMS_OUT_CHANNELS = 4
-UNET_PARAMS_NUM_RES_BLOCKS = 2
-UNET_PARAMS_CONTEXT_DIM = 768
-UNET_PARAMS_NUM_HEADS = 8
-
-VAE_PARAMS_Z_CHANNELS = 4
-VAE_PARAMS_RESOLUTION = 256
-VAE_PARAMS_IN_CHANNELS = 3
-VAE_PARAMS_OUT_CH = 3
-VAE_PARAMS_CH = 128
-VAE_PARAMS_CH_MULT = [1, 2, 4, 4]
-VAE_PARAMS_NUM_RES_BLOCKS = 2
-
-# checkpointファイル名
-LAST_CHECKPOINT_NAME = "last.ckpt"
-LAST_STATE_NAME = "last-state"
-EPOCH_CHECKPOINT_NAME = "epoch-{:06d}.ckpt"
-EPOCH_STATE_NAME = "epoch-{:06d}-state"
-
-
-def make_bucket_resolutions(max_reso, min_size=256, max_size=1024, divisible=64):
- max_width, max_height = max_reso
- max_area = (max_width // divisible) * (max_height // divisible)
-
- resos = set()
-
- size = int(math.sqrt(max_area)) * divisible
- resos.add((size, size))
-
- size = min_size
- while size <= max_size:
- width = size
- height = min(max_size, (max_area // (width // divisible)) * divisible)
- resos.add((width, height))
- resos.add((height, width))
- size += divisible
-
- resos = list(resos)
- resos.sort()
-
- aspect_ratios = [w / h for w, h in resos]
- return resos, aspect_ratios
-
-
-class DreamBoothOrFineTuningDataset(torch.utils.data.Dataset):
- def __init__(self, batch_size, fine_tuning, train_img_path_captions, reg_img_path_captions, tokenizer, resolution, prior_loss_weight, flip_aug, color_aug, face_crop_aug_range, random_crop, shuffle_caption, disable_padding, debug_dataset) -> None:
- super().__init__()
-
- self.batch_size = batch_size
- self.fine_tuning = fine_tuning
- self.train_img_path_captions = train_img_path_captions
- self.reg_img_path_captions = reg_img_path_captions
- self.tokenizer = tokenizer
- self.width, self.height = resolution
- self.size = min(self.width, self.height) # 短いほう
- self.prior_loss_weight = prior_loss_weight
- self.face_crop_aug_range = face_crop_aug_range
- self.random_crop = random_crop
- self.debug_dataset = debug_dataset
- self.shuffle_caption = shuffle_caption
- self.disable_padding = disable_padding
- self.latents_cache = None
- self.enable_bucket = False
-
- # augmentation
- flip_p = 0.5 if flip_aug else 0.0
- if color_aug:
- # わりと弱めの色合いaugmentation:brightness/contrastあたりは画像のpixel valueの最大値・最小値を変えてしまうのでよくないのではという想定でgamma/hue/saturationあたりを触る
- self.aug = albu.Compose([
- albu.OneOf([
- # albu.RandomBrightnessContrast(0.05, 0.05, p=.2),
- albu.HueSaturationValue(5, 8, 0, p=.2),
- # albu.RGBShift(5, 5, 5, p=.1),
- albu.RandomGamma((95, 105), p=.5),
- ], p=.33),
- albu.HorizontalFlip(p=flip_p)
- ], p=1.)
- elif flip_aug:
- self.aug = albu.Compose([
- albu.HorizontalFlip(p=flip_p)
- ], p=1.)
- else:
- self.aug = None
-
- self.num_train_images = len(self.train_img_path_captions)
- self.num_reg_images = len(self.reg_img_path_captions)
-
- self.enable_reg_images = self.num_reg_images > 0
-
- if self.enable_reg_images and self.num_train_images < self.num_reg_images:
- print("some of reg images are not used / 正則化画像の数が多いので、一部使用されない正則化画像があります")
-
- self.image_transforms = transforms.Compose(
- [
- transforms.ToTensor(),
- transforms.Normalize([0.5], [0.5]),
- ]
- )
-
- # bucketingを行わない場合も呼び出し必須(ひとつだけbucketを作る)
- def make_buckets_with_caching(self, enable_bucket, vae, min_size, max_size):
- self.enable_bucket = enable_bucket
-
- cache_latents = vae is not None
- if cache_latents:
- if enable_bucket:
- print("cache latents with bucketing")
- else:
- print("cache latents")
- else:
- if enable_bucket:
- print("make buckets")
- else:
- print("prepare dataset")
-
- # bucketingを用意する
- if enable_bucket:
- bucket_resos, bucket_aspect_ratios = make_bucket_resolutions((self.width, self.height), min_size, max_size)
- else:
- # bucketはひとつだけ、すべての画像は同じ解像度
- bucket_resos = [(self.width, self.height)]
- bucket_aspect_ratios = [self.width / self.height]
- bucket_aspect_ratios = np.array(bucket_aspect_ratios)
-
- # 画像の解像度、latentをあらかじめ取得する
- img_ar_errors = []
- self.size_lat_cache = {}
- for image_path, _ in tqdm(self.train_img_path_captions + self.reg_img_path_captions):
- if image_path in self.size_lat_cache:
- continue
-
- image = self.load_image(image_path)[0]
- image_height, image_width = image.shape[0:2]
-
- if not enable_bucket:
- # assert image_width == self.width and image_height == self.height, \
- # f"all images must have specific resolution when bucketing is disabled / bucketを使わない場合、すべての画像のサイズを統一してください: {image_path}"
- reso = (self.width, self.height)
- else:
- # bucketを決める
- aspect_ratio = image_width / image_height
- ar_errors = bucket_aspect_ratios - aspect_ratio
- bucket_id = np.abs(ar_errors).argmin()
- reso = bucket_resos[bucket_id]
- ar_error = ar_errors[bucket_id]
- img_ar_errors.append(ar_error)
-
- if cache_latents:
- image = self.resize_and_trim(image, reso)
-
- # latentを取得する
- if cache_latents:
- img_tensor = self.image_transforms(image)
- img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype)
- latents = vae.encode(img_tensor).latent_dist.sample().squeeze(0).to("cpu")
- else:
- latents = None
-
- self.size_lat_cache[image_path] = (reso, latents)
-
- # 画像をbucketに分割する
- self.buckets = [[] for _ in range(len(bucket_resos))]
- reso_to_index = {}
- for i, reso in enumerate(bucket_resos):
- reso_to_index[reso] = i
-
- def split_to_buckets(is_reg, img_path_captions):
- for image_path, caption in img_path_captions:
- reso, _ = self.size_lat_cache[image_path]
- bucket_index = reso_to_index[reso]
- self.buckets[bucket_index].append((is_reg, image_path, caption))
-
- split_to_buckets(False, self.train_img_path_captions)
-
- if self.enable_reg_images:
- l = []
- while len(l) < len(self.train_img_path_captions):
- l += self.reg_img_path_captions
- l = l[:len(self.train_img_path_captions)]
- split_to_buckets(True, l)
-
- if enable_bucket:
- print("number of images with repeats / 繰り返し回数込みの各bucketの画像枚数")
- for i, (reso, imgs) in enumerate(zip(bucket_resos, self.buckets)):
- print(f"bucket {i}: resolution {reso}, count: {len(imgs)}")
- img_ar_errors = np.array(img_ar_errors)
- print(f"mean ar error: {np.mean(np.abs(img_ar_errors))}")
-
- # build the lookup indices
- self.buckets_indices = []
- for bucket_index, bucket in enumerate(self.buckets):
- batch_count = int(math.ceil(len(bucket) / self.batch_size))
- for batch_index in range(batch_count):
- self.buckets_indices.append((bucket_index, batch_index))
-
- self.shuffle_buckets()
- self._length = len(self.buckets_indices)
-
- # decide the resize target: resize to cover the bucket resolution, then trim (crop) to it
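- # illustrative example (numbers chosen for illustration, not from the original): a 1000x800 image with reso=(512, 640)
- # has ar_img=1.25 > ar_reso=0.8, so scale=640/800=0.8 -> resized to (800, 640); the width is then
- # center-trimmed from 800 to 512, yielding exactly (width, height) = (512, 640)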
- def resize_and_trim(self, image, reso):
- image_height, image_width = image.shape[0:2]
- ar_img = image_width / image_height
- ar_reso = reso[0] / reso[1]
- if ar_img > ar_reso: # image is wider than the target: fit the height
- scale = reso[1] / image_height
- else: # image is taller (or equal): fit the width
- scale = reso[0] / image_width
- resized_size = (int(image_width * scale + .5), int(image_height * scale + .5))
-
- image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA) # resize with cv2 so INTER_AREA can be used
- if resized_size[0] > reso[0]:
- trim_size = resized_size[0] - reso[0]
- image = image[:, trim_size//2:trim_size//2 + reso[0]]
- elif resized_size[1] > reso[1]:
- trim_size = resized_size[1] - reso[1]
- image = image[trim_size//2:trim_size//2 + reso[1]]
- assert image.shape[0] == reso[1] and image.shape[1] == reso[0], \
- f"internal error, illegal trimmed size: {image.shape}, {reso}"
- return image
-
- def shuffle_buckets(self):
- random.shuffle(self.buckets_indices)
- for bucket in self.buckets:
- random.shuffle(bucket)
-
- def load_image(self, image_path):
- image = Image.open(image_path)
- if not image.mode == "RGB":
- image = image.convert("RGB")
- img = np.array(image, np.uint8)
-
- face_cx = face_cy = face_w = face_h = 0
- if self.face_crop_aug_range is not None:
- tokens = os.path.splitext(os.path.basename(image_path))[0].split('_')
- if len(tokens) >= 5:
- face_cx = int(tokens[-4])
- face_cy = int(tokens[-3])
- face_w = int(tokens[-2])
- face_h = int(tokens[-1])
-
- return img, face_cx, face_cy, face_w, face_h
-
- # crop a nice region around the face
- def crop_target(self, image, face_cx, face_cy, face_w, face_h):
- height, width = image.shape[0:2]
- if height == self.height and width == self.width:
- return image
-
- # the image is larger than the target size, so resize it
- face_size = max(face_w, face_h)
- min_scale = max(self.height / height, self.width / width) # scale at which the image exactly fits the model input size (the smallest allowed scale)
- min_scale = min(1.0, max(min_scale, self.size / (face_size * self.face_crop_aug_range[1]))) # from the specified minimum face size
- max_scale = min(1.0, max(min_scale, self.size / (face_size * self.face_crop_aug_range[0]))) # from the specified maximum face size
- if min_scale >= max_scale: # the specified range has min == max
- scale = min_scale
- else:
- scale = random.uniform(min_scale, max_scale)
-
- nh = int(height * scale + .5)
- nw = int(width * scale + .5)
- assert nh >= self.height and nw >= self.width, f"internal error. small scale {scale}, {width}*{height}"
- image = cv2.resize(image, (nw, nh), interpolation=cv2.INTER_AREA)
- face_cx = int(face_cx * scale + .5)
- face_cy = int(face_cy * scale + .5)
- height, width = nh, nw
-
- # crop to the target size (e.g. 448*640) centered on the face
- for axis, (target_size, length, face_p) in enumerate(zip((self.height, self.width), (height, width), (face_cy, face_cx))):
- p1 = face_p - target_size // 2 # crop position that puts the face at the center
-
- if self.random_crop:
- # shift the crop so background is included too, while keeping a higher probability of centering the face
- range = max(length - face_p, face_p) # the longer distance from the face center to an image edge
- p1 = p1 + (random.randint(0, range) + random.randint(0, range)) - range # a face-biased (triangular) random offset in -range .. +range
- else:
- # only when a range is specified, add a small random shift (fairly ad hoc)
- if self.face_crop_aug_range[0] != self.face_crop_aug_range[1]:
- if face_size > self.size // 10 and face_size >= 40:
- p1 = p1 + random.randint(-face_size // 20, +face_size // 20)
-
- p1 = max(0, min(p1, length - target_size))
-
- if axis == 0:
- image = image[p1:p1 + target_size, :]
- else:
- image = image[:, p1:p1 + target_size]
-
- return image
-
- def __len__(self):
- return self._length
-
- def __getitem__(self, index):
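- # index 0 marks the start of a new pass over the dataset, so the buckets are reshuffled once per epoch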
- if index == 0:
- self.shuffle_buckets()
-
- bucket = self.buckets[self.buckets_indices[index][0]]
- image_index = self.buckets_indices[index][1] * self.batch_size
-
- latents_list = []
- images = []
- captions = []
- loss_weights = []
-
- for is_reg, image_path, caption in bucket[image_index:image_index + self.batch_size]:
- loss_weights.append(1.0 if is_reg else self.prior_loss_weight)
-
- # process image / latents
- reso, latents = self.size_lat_cache[image_path]
-
- if latents is None:
- # load the image and crop it if necessary
- img, face_cx, face_cy, face_w, face_h = self.load_image(image_path)
- im_h, im_w = img.shape[0:2]
-
- if self.enable_bucket:
- img = self.resize_and_trim(img, reso)
- else:
- if face_cx > 0: # face position info is available
- img = self.crop_target(img, face_cx, face_cy, face_w, face_h)
- elif im_h > self.height or im_w > self.width:
- assert self.random_crop, f"image too large, and face_crop_aug_range and random_crop are disabled / 画像サイズが大きいのでface_crop_aug_rangeかrandom_cropを有効にしてください"
- if im_h > self.height:
- p = random.randint(0, im_h - self.height)
- img = img[p:p + self.height]
- if im_w > self.width:
- p = random.randint(0, im_w - self.width)
- img = img[:, p:p + self.width]
-
- im_h, im_w = img.shape[0:2]
- assert im_h == self.height and im_w == self.width, f"image size is small / 画像サイズが小さいようです: {image_path}"
-
- # augmentation
- if self.aug is not None:
- img = self.aug(image=img)['image']
-
- image = self.image_transforms(img) # becomes a torch.Tensor in the range -1.0 to 1.0
- else:
- image = None
-
- images.append(image)
- latents_list.append(latents)
-
- # process the caption
- if self.shuffle_caption: # shuffle the caption
- tokens = caption.strip().split(",")
- random.shuffle(tokens)
- caption = ",".join(tokens).strip()
- captions.append(caption)
-
- # pad input_ids and convert them to a Tensor
- if self.disable_padding:
- # no padding here: padding=True only pads to the longest sequence in the batch (still looks like a bug...?)
- input_ids = self.tokenizer(captions, padding=True, truncation=True, return_tensors="pt").input_ids
- else:
- # pad to the maximum length
- input_ids = self.tokenizer(captions, padding='max_length', truncation=True, return_tensors='pt').input_ids
-
- example = {}
- example['loss_weights'] = torch.FloatTensor(loss_weights)
- example['input_ids'] = input_ids
- if images[0] is not None:
- images = torch.stack(images)
- images = images.to(memory_format=torch.contiguous_format).float()
- else:
- images = None
- example['images'] = images
- example['latents'] = torch.stack(latents_list) if latents_list[0] is not None else None
- if self.debug_dataset:
- example['image_paths'] = [image_path for _, image_path, _ in bucket[image_index:image_index + self.batch_size]]
- example['captions'] = captions
- return example
-
-
- # region checkpoint conversion, loading and saving ###############################
-
- # region StableDiffusion -> Diffusers conversion code
- # copied from convert_original_stable_diffusion_to_diffusers (ASL 2.0)
-
-def shave_segments(path, n_shave_prefix_segments=1):
- """
- Removes segments. Positive values shave the first segments, negative shave the last segments.
- """
- if n_shave_prefix_segments >= 0:
- return ".".join(path.split(".")[n_shave_prefix_segments:])
- else:
- return ".".join(path.split(".")[:n_shave_prefix_segments])
-
-
-def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
- """
- Updates paths inside resnets to the new naming scheme (local renaming)
- """
- mapping = []
- for old_item in old_list:
- new_item = old_item.replace("in_layers.0", "norm1")
- new_item = new_item.replace("in_layers.2", "conv1")
-
- new_item = new_item.replace("out_layers.0", "norm2")
- new_item = new_item.replace("out_layers.3", "conv2")
-
- new_item = new_item.replace("emb_layers.1", "time_emb_proj")
- new_item = new_item.replace("skip_connection", "conv_shortcut")
-
- new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
-
- mapping.append({"old": old_item, "new": new_item})
-
- return mapping
-
-
-def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
- """
- Updates paths inside resnets to the new naming scheme (local renaming)
- """
- mapping = []
- for old_item in old_list:
- new_item = old_item
-
- new_item = new_item.replace("nin_shortcut", "conv_shortcut")
- new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
-
- mapping.append({"old": old_item, "new": new_item})
-
- return mapping
-
-
-def renew_attention_paths(old_list, n_shave_prefix_segments=0):
- """
- Updates paths inside attentions to the new naming scheme (local renaming)
- """
- mapping = []
- for old_item in old_list:
- new_item = old_item
-
- # new_item = new_item.replace('norm.weight', 'group_norm.weight')
- # new_item = new_item.replace('norm.bias', 'group_norm.bias')
-
- # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
- # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
-
- # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
-
- mapping.append({"old": old_item, "new": new_item})
-
- return mapping
-
-
-def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
- """
- Updates paths inside attentions to the new naming scheme (local renaming)
- """
- mapping = []
- for old_item in old_list:
- new_item = old_item
-
- new_item = new_item.replace("norm.weight", "group_norm.weight")
- new_item = new_item.replace("norm.bias", "group_norm.bias")
-
- new_item = new_item.replace("q.weight", "query.weight")
- new_item = new_item.replace("q.bias", "query.bias")
-
- new_item = new_item.replace("k.weight", "key.weight")
- new_item = new_item.replace("k.bias", "key.bias")
-
- new_item = new_item.replace("v.weight", "value.weight")
- new_item = new_item.replace("v.bias", "value.bias")
-
- new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
- new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
-
- new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
-
- mapping.append({"old": old_item, "new": new_item})
-
- return mapping
-
-
-def assign_to_checkpoint(
- paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
-):
- """
- This does the final conversion step: take locally converted weights and apply a global renaming
- to them. It splits attention layers, and takes into account additional replacements
- that may arise.
-
- Assigns the weights to the new checkpoint.
- """
- assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
-
- # Splits the attention layers into three variables.
- if attention_paths_to_split is not None:
- for path, path_map in attention_paths_to_split.items():
- old_tensor = old_checkpoint[path]
- channels = old_tensor.shape[0] // 3
-
- target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
-
- num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
-
- old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
- query, key, value = old_tensor.split(channels // num_heads, dim=1)
-
- checkpoint[path_map["query"]] = query.reshape(target_shape)
- checkpoint[path_map["key"]] = key.reshape(target_shape)
- checkpoint[path_map["value"]] = value.reshape(target_shape)
-
- for path in paths:
- new_path = path["new"]
-
- # These have already been assigned
- if attention_paths_to_split is not None and new_path in attention_paths_to_split:
- continue
-
- # Global renaming happens here
- new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
- new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
- new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")
-
- if additional_replacements is not None:
- for replacement in additional_replacements:
- new_path = new_path.replace(replacement["old"], replacement["new"])
-
- # proj_attn.weight has to be converted from conv 1D to linear
- if "proj_attn.weight" in new_path:
- checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
- else:
- checkpoint[new_path] = old_checkpoint[path["old"]]
-
-
-def conv_attn_to_linear(checkpoint):
- keys = list(checkpoint.keys())
- attn_keys = ["query.weight", "key.weight", "value.weight"]
- for key in keys:
- if ".".join(key.split(".")[-2:]) in attn_keys:
- if checkpoint[key].ndim > 2:
- checkpoint[key] = checkpoint[key][:, :, 0, 0]
- elif "proj_attn.weight" in key:
- if checkpoint[key].ndim > 2:
- checkpoint[key] = checkpoint[key][:, :, 0]
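- # note: the original VAE attention uses 1x1 convolutions, so these weights carry trailing singleton
- # spatial dimensions; slicing them off turns them into the 2D weights of the Linear layers diffusers uses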
-
-
-def convert_ldm_unet_checkpoint(checkpoint, config):
- """
- Takes a state dict and a config, and returns a converted checkpoint.
- """
-
- # extract state_dict for UNet
- unet_state_dict = {}
- unet_key = "model.diffusion_model."
- keys = list(checkpoint.keys())
- for key in keys:
- if key.startswith(unet_key):
- unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
-
- new_checkpoint = {}
-
- new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
- new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
- new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
- new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
-
- new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
- new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
-
- new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
- new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
- new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
- new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
-
- # Retrieves the keys for the input blocks only
- num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
- input_blocks = {
- layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
- for layer_id in range(num_input_blocks)
- }
-
- # Retrieves the keys for the middle blocks only
- num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
- middle_blocks = {
- layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
- for layer_id in range(num_middle_blocks)
- }
-
- # Retrieves the keys for the output blocks only
- num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
- output_blocks = {
- layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
- for layer_id in range(num_output_blocks)
- }
-
- for i in range(1, num_input_blocks):
- block_id = (i - 1) // (config["layers_per_block"] + 1)
- layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
-
- resnets = [
- key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
- ]
- attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
-
- if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
- new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
- f"input_blocks.{i}.0.op.weight"
- )
- new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
- f"input_blocks.{i}.0.op.bias"
- )
-
- paths = renew_resnet_paths(resnets)
- meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
- assign_to_checkpoint(
- paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
- )
-
- if len(attentions):
- paths = renew_attention_paths(attentions)
- meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
- assign_to_checkpoint(
- paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
- )
-
- resnet_0 = middle_blocks[0]
- attentions = middle_blocks[1]
- resnet_1 = middle_blocks[2]
-
- resnet_0_paths = renew_resnet_paths(resnet_0)
- assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)
-
- resnet_1_paths = renew_resnet_paths(resnet_1)
- assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)
-
- attentions_paths = renew_attention_paths(attentions)
- meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
- assign_to_checkpoint(
- attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
- )
-
- for i in range(num_output_blocks):
- block_id = i // (config["layers_per_block"] + 1)
- layer_in_block_id = i % (config["layers_per_block"] + 1)
- output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
- output_block_list = {}
-
- for layer in output_block_layers:
- layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
- if layer_id in output_block_list:
- output_block_list[layer_id].append(layer_name)
- else:
- output_block_list[layer_id] = [layer_name]
-
- if len(output_block_list) > 1:
- resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
- attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
-
- resnet_0_paths = renew_resnet_paths(resnets)
- paths = renew_resnet_paths(resnets)
-
- meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
- assign_to_checkpoint(
- paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
- )
-
- if ["conv.weight", "conv.bias"] in output_block_list.values():
- index = list(output_block_list.values()).index(["conv.weight", "conv.bias"])
- new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
- f"output_blocks.{i}.{index}.conv.weight"
- ]
- new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
- f"output_blocks.{i}.{index}.conv.bias"
- ]
-
- # Clear attentions as they have been attributed above.
- if len(attentions) == 2:
- attentions = []
-
- if len(attentions):
- paths = renew_attention_paths(attentions)
- meta_path = {
- "old": f"output_blocks.{i}.1",
- "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
- }
- assign_to_checkpoint(
- paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
- )
- else:
- resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
- for path in resnet_0_paths:
- old_path = ".".join(["output_blocks", str(i), path["old"]])
- new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
-
- new_checkpoint[new_path] = unet_state_dict[old_path]
-
- return new_checkpoint
-
-
-def convert_ldm_vae_checkpoint(checkpoint, config):
- # extract state dict for VAE
- vae_state_dict = {}
- vae_key = "first_stage_model."
- keys = list(checkpoint.keys())
- for key in keys:
- if key.startswith(vae_key):
- vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
-
- new_checkpoint = {}
-
- new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
- new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
- new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
- new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
- new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
- new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]
-
- new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
- new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
- new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
- new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
- new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
- new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]
-
- new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
- new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
- new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
- new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
-
- # Retrieves the keys for the encoder down blocks only
- num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
- down_blocks = {
- layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
- }
-
- # Retrieves the keys for the decoder up blocks only
- num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
- up_blocks = {
- layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
- }
-
- for i in range(num_down_blocks):
- resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
-
- if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
- new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
- f"encoder.down.{i}.downsample.conv.weight"
- )
- new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
- f"encoder.down.{i}.downsample.conv.bias"
- )
-
- paths = renew_vae_resnet_paths(resnets)
- meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
- assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
-
- mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
- num_mid_res_blocks = 2
- for i in range(1, num_mid_res_blocks + 1):
- resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
-
- paths = renew_vae_resnet_paths(resnets)
- meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
- assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
-
- mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
- paths = renew_vae_attention_paths(mid_attentions)
- meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
- assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
- conv_attn_to_linear(new_checkpoint)
-
- for i in range(num_up_blocks):
- block_id = num_up_blocks - 1 - i
- resnets = [
- key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
- ]
-
- if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
- new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
- f"decoder.up.{block_id}.upsample.conv.weight"
- ]
- new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
- f"decoder.up.{block_id}.upsample.conv.bias"
- ]
-
- paths = renew_vae_resnet_paths(resnets)
- meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
- assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
-
- mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
- num_mid_res_blocks = 2
- for i in range(1, num_mid_res_blocks + 1):
- resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
-
- paths = renew_vae_resnet_paths(resnets)
- meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
- assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
-
- mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
- paths = renew_vae_attention_paths(mid_attentions)
- meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
- assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
- conv_attn_to_linear(new_checkpoint)
- return new_checkpoint
-
-
-def create_unet_diffusers_config():
- """
- Creates a config for the diffusers based on the config of the LDM model.
- """
- # unet_params = original_config.model.params.unet_config.params
-
- block_out_channels = [UNET_PARAMS_MODEL_CHANNELS * mult for mult in UNET_PARAMS_CHANNEL_MULT]
-
- down_block_types = []
- resolution = 1
- for i in range(len(block_out_channels)):
- block_type = "CrossAttnDownBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "DownBlock2D"
- down_block_types.append(block_type)
- if i != len(block_out_channels) - 1:
- resolution *= 2
-
- up_block_types = []
- for i in range(len(block_out_channels)):
- block_type = "CrossAttnUpBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "UpBlock2D"
- up_block_types.append(block_type)
- resolution //= 2
-
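- # for reference (assuming the usual SD v1 constants defined elsewhere in this file, e.g. channel mult
- # [1, 2, 4, 4] and attention resolutions [4, 2, 1]), this yields three CrossAttnDownBlock2D blocks
- # followed by one DownBlock2D, and the mirrored sequence on the up path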
- config = dict(
- sample_size=UNET_PARAMS_IMAGE_SIZE,
- in_channels=UNET_PARAMS_IN_CHANNELS,
- out_channels=UNET_PARAMS_OUT_CHANNELS,
- down_block_types=tuple(down_block_types),
- up_block_types=tuple(up_block_types),
- block_out_channels=tuple(block_out_channels),
- layers_per_block=UNET_PARAMS_NUM_RES_BLOCKS,
- cross_attention_dim=UNET_PARAMS_CONTEXT_DIM,
- attention_head_dim=UNET_PARAMS_NUM_HEADS,
- )
-
- return config
-
-
-def create_vae_diffusers_config():
- """
- Creates a config for the diffusers based on the config of the LDM model.
- """
- # vae_params = original_config.model.params.first_stage_config.params.ddconfig
- # _ = original_config.model.params.first_stage_config.params.embed_dim
- block_out_channels = [VAE_PARAMS_CH * mult for mult in VAE_PARAMS_CH_MULT]
- down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
- up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
-
- config = dict(
- sample_size=VAE_PARAMS_RESOLUTION,
- in_channels=VAE_PARAMS_IN_CHANNELS,
- out_channels=VAE_PARAMS_OUT_CH,
- down_block_types=tuple(down_block_types),
- up_block_types=tuple(up_block_types),
- block_out_channels=tuple(block_out_channels),
- latent_channels=VAE_PARAMS_Z_CHANNELS,
- layers_per_block=VAE_PARAMS_NUM_RES_BLOCKS,
- )
- return config
-
-
-def convert_ldm_clip_checkpoint(checkpoint):
- text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
-
- keys = list(checkpoint.keys())
-
- text_model_dict = {}
-
- for key in keys:
- if key.startswith("cond_stage_model.transformer"):
- text_model_dict[key[len("cond_stage_model.transformer."):]] = checkpoint[key]
-
- text_model.load_state_dict(text_model_dict)
-
- return text_model
-
-# endregion
-
-
- # region Diffusers -> StableDiffusion conversion code
- # copied from convert_diffusers_to_original_stable_diffusion (ASL 2.0)
-
-def convert_unet_state_dict(unet_state_dict):
- unet_conversion_map = [
- # (stable-diffusion, HF Diffusers)
- ("time_embed.0.weight", "time_embedding.linear_1.weight"),
- ("time_embed.0.bias", "time_embedding.linear_1.bias"),
- ("time_embed.2.weight", "time_embedding.linear_2.weight"),
- ("time_embed.2.bias", "time_embedding.linear_2.bias"),
- ("input_blocks.0.0.weight", "conv_in.weight"),
- ("input_blocks.0.0.bias", "conv_in.bias"),
- ("out.0.weight", "conv_norm_out.weight"),
- ("out.0.bias", "conv_norm_out.bias"),
- ("out.2.weight", "conv_out.weight"),
- ("out.2.bias", "conv_out.bias"),
- ]
-
- unet_conversion_map_resnet = [
- # (stable-diffusion, HF Diffusers)
- ("in_layers.0", "norm1"),
- ("in_layers.2", "conv1"),
- ("out_layers.0", "norm2"),
- ("out_layers.3", "conv2"),
- ("emb_layers.1", "time_emb_proj"),
- ("skip_connection", "conv_shortcut"),
- ]
-
- unet_conversion_map_layer = []
- for i in range(4):
- # loop over downblocks/upblocks
-
- for j in range(2):
- # loop over resnets/attentions for downblocks
- hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
- sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
- unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
-
- if i < 3:
- # no attention layers in down_blocks.3
- hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
- sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
- unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
-
- for j in range(3):
- # loop over resnets/attentions for upblocks
- hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
- sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
- unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
-
- if i > 0:
- # no attention layers in up_blocks.0
- hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
- sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
- unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
-
- if i < 3:
- # no downsample in down_blocks.3
- hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
- sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
- unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
-
- # no upsample in up_blocks.3
- hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
- sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
- unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
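- # up_blocks.0 has no attention layers, so in the SD layout its upsampler is the block's 2nd module
- # (sub-index 1); for the other levels the upsampler follows the attention and sits at sub-index 2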
-
- hf_mid_atn_prefix = "mid_block.attentions.0."
- sd_mid_atn_prefix = "middle_block.1."
- unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
-
- for j in range(2):
- hf_mid_res_prefix = f"mid_block.resnets.{j}."
- sd_mid_res_prefix = f"middle_block.{2*j}."
- unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
-
- # buyer beware: this is a *brittle* function,
- # and correct output requires that all of these pieces interact in
- # the exact order in which I have arranged them.
- mapping = {k: k for k in unet_state_dict.keys()}
- for sd_name, hf_name in unet_conversion_map:
- mapping[hf_name] = sd_name
- for k, v in mapping.items():
- if "resnets" in k:
- for sd_part, hf_part in unet_conversion_map_resnet:
- v = v.replace(hf_part, sd_part)
- mapping[k] = v
- for k, v in mapping.items():
- for sd_part, hf_part in unet_conversion_map_layer:
- v = v.replace(hf_part, sd_part)
- mapping[k] = v
- new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
- return new_state_dict
-
-# endregion
-
-
-def load_checkpoint_with_conversion(ckpt_path):
- # support models that store the text encoder in a different layout (without 'text_model')
- TEXT_ENCODER_KEY_REPLACEMENTS = [
- ('cond_stage_model.transformer.embeddings.', 'cond_stage_model.transformer.text_model.embeddings.'),
- ('cond_stage_model.transformer.encoder.', 'cond_stage_model.transformer.text_model.encoder.'),
- ('cond_stage_model.transformer.final_layer_norm.', 'cond_stage_model.transformer.text_model.final_layer_norm.')
- ]
-
- checkpoint = torch.load(ckpt_path, map_location="cpu")
- state_dict = checkpoint["state_dict"]
-
- key_reps = []
- for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS:
- for key in state_dict.keys():
- if key.startswith(rep_from):
- new_key = rep_to + key[len(rep_from):]
- key_reps.append((key, new_key))
-
- for key, new_key in key_reps:
- state_dict[new_key] = state_dict[key]
- del state_dict[key]
-
- return checkpoint
-
-
-def load_models_from_stable_diffusion_checkpoint(ckpt_path):
- checkpoint = load_checkpoint_with_conversion(ckpt_path)
- state_dict = checkpoint["state_dict"]
-
- # Convert the UNet2DConditionModel model.
- unet_config = create_unet_diffusers_config()
- converted_unet_checkpoint = convert_ldm_unet_checkpoint(state_dict, unet_config)
-
- unet = UNet2DConditionModel(**unet_config)
- unet.load_state_dict(converted_unet_checkpoint)
-
- # Convert the VAE model.
- vae_config = create_vae_diffusers_config()
- converted_vae_checkpoint = convert_ldm_vae_checkpoint(state_dict, vae_config)
-
- vae = AutoencoderKL(**vae_config)
- vae.load_state_dict(converted_vae_checkpoint)
-
- # convert text_model
- text_model = convert_ldm_clip_checkpoint(state_dict)
-
- return text_model, vae, unet
-
-
-def save_stable_diffusion_checkpoint(output_file, text_encoder, unet, ckpt_path, epochs, steps, save_dtype=None):
- # the VAE is not kept in memory, so load the checkpoint again, including the VAE
- checkpoint = load_checkpoint_with_conversion(ckpt_path)
- state_dict = checkpoint["state_dict"]
-
- # Convert the UNet model
- unet_state_dict = convert_unet_state_dict(unet.state_dict())
- for k, v in unet_state_dict.items():
- key = "model.diffusion_model." + k
- assert key in state_dict, f"Illegal key in save SD: {key}"
- if save_dtype is not None:
- v = v.detach().clone().to("cpu").to(save_dtype)
- state_dict[key] = v
-
- # Convert the text encoder model
- text_enc_dict = text_encoder.state_dict() # no conversion needed
- for k, v in text_enc_dict.items():
- key = "cond_stage_model.transformer." + k
- assert key in state_dict, f"Illegal key in save SD: {key}"
- if save_dtype is not None:
- v = v.detach().clone().to("cpu").to(save_dtype)
- state_dict[key] = v
-
- # Put together new checkpoint
- new_ckpt = {'state_dict': state_dict}
-
- if 'epoch' in checkpoint:
- epochs += checkpoint['epoch']
- if 'global_step' in checkpoint:
- steps += checkpoint['global_step']
-
- new_ckpt['epoch'] = epochs
- new_ckpt['global_step'] = steps
-
- torch.save(new_ckpt, output_file)
-# endregion
-
-
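- # the dataset already returns a fully assembled batch per index, so the DataLoader runs with
- # batch_size=1 and collate_fn simply unwraps that single element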
-def collate_fn(examples):
- return examples[0]
-
-
-def train(args):
- fine_tuning = args.fine_tuning
- cache_latents = args.cache_latents
-
- # check option settings for caching latents
- if cache_latents:
- # assert args.face_crop_aug_range is None and not args.random_crop, "when caching latents, crop aug cannot be used / latentをキャッシュするときは切り出しは使えません"
- # -> allow it anyway (the crop is applied to the initial image)
- assert not args.flip_aug and not args.color_aug, "when caching latents, augmentation cannot be used / latentをキャッシュするときはaugmentationは使えません"
-
- # check model format options
- use_stable_diffusion_format = os.path.isfile(args.pretrained_model_name_or_path)
- if not use_stable_diffusion_format:
- assert os.path.exists(
- args.pretrained_model_name_or_path), f"no pretrained model / 学習元モデルがありません : {args.pretrained_model_name_or_path}"
-
- assert args.save_every_n_epochs is None or use_stable_diffusion_format, "when loading Diffusers model, save_every_n_epochs does not work / Diffusersのモデルを読み込むときにはsave_every_n_epochsオプションは無効になります"
-
- if args.seed is not None:
- set_seed(args.seed)
-
- # prepare training data
- def read_caption(img_path):
- # build candidate caption file names
- base_name = os.path.splitext(img_path)[0]
- base_name_face_det = base_name
- tokens = base_name.split("_")
- if len(tokens) >= 5:
- base_name_face_det = "_".join(tokens[:-4])
- cap_paths = [base_name + args.caption_extention, base_name_face_det + args.caption_extention]
-
- caption = None
- for cap_path in cap_paths:
- if os.path.isfile(cap_path):
- with open(cap_path, "rt", encoding='utf-8') as f:
- caption = f.readlines()[0].strip()
- break
- return caption
-
- def load_dreambooth_dir(dir):
- tokens = os.path.basename(dir).split('_')
- try:
- n_repeats = int(tokens[0])
- except ValueError as e:
- return 0, []
-
- caption = '_'.join(tokens[1:])
-
- print(f"found directory {n_repeats}_{caption}")
-
- img_paths = glob.glob(os.path.join(dir, "*.png")) + glob.glob(os.path.join(dir, "*.jpg")) + \
- glob.glob(os.path.join(dir, "*.webp"))
-
- # read the caption for each image file and, if present, concatenate it
- captions = []
- for img_path in img_paths:
- cap_for_img = read_caption(img_path)
- captions.append(caption + ("" if cap_for_img is None else cap_for_img))
-
- return n_repeats, list(zip(img_paths, captions))
-
- print("prepare train images.")
- train_img_path_captions = []
-
- if fine_tuning:
- img_paths = glob.glob(os.path.join(args.train_data_dir, "*.png")) + \
- glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + glob.glob(os.path.join(args.train_data_dir, "*.webp"))
- for img_path in tqdm(img_paths):
- caption = read_caption(img_path)
- assert caption is not None and len(
- caption) > 0, f"no caption for image. check caption_extention option / キャプションファイルが見つからないかcaptionが空です。caption_extentionオプションを確認してください: {img_path}"
-
- train_img_path_captions.append((img_path, caption))
-
- if args.dataset_repeats is not None:
- l = []
- for _ in range(args.dataset_repeats):
- l.extend(train_img_path_captions)
- train_img_path_captions = l
- else:
- train_dirs = os.listdir(args.train_data_dir)
- for dir in train_dirs:
- n_repeats, img_caps = load_dreambooth_dir(os.path.join(args.train_data_dir, dir))
- for _ in range(n_repeats):
- train_img_path_captions.extend(img_caps)
- print(f"{len(train_img_path_captions)} train images with repeating.")
-
- reg_img_path_captions = []
- if args.reg_data_dir:
- print("prepare reg images.")
- reg_dirs = os.listdir(args.reg_data_dir)
- for dir in reg_dirs:
- n_repeats, img_caps = load_dreambooth_dir(os.path.join(args.reg_data_dir, dir))
- for _ in range(n_repeats):
- reg_img_path_captions.extend(img_caps)
- print(f"{len(reg_img_path_captions)} reg images.")
-
- # prepare the dataset
- resolution = tuple([int(r) for r in args.resolution.split(',')])
- if len(resolution) == 1:
- resolution = (resolution[0], resolution[0])
- assert len(resolution) == 2, \
- f"resolution must be 'size' or 'width,height' / resolutionは'サイズ'または'幅','高さ'で指定してください: {args.resolution}"
-
- if args.enable_bucket:
- assert min(resolution) >= args.min_bucket_reso, f"min_bucket_reso must be equal or less than resolution / min_bucket_resoは解像度の数値以下で指定してください"
- assert max(resolution) <= args.max_bucket_reso, f"max_bucket_reso must be equal or greater than resolution / max_bucket_resoは解像度の数値以上で指定してください"
-
- if args.face_crop_aug_range is not None:
- face_crop_aug_range = tuple([float(r) for r in args.face_crop_aug_range.split(',')])
- assert len(
- face_crop_aug_range) == 2, f"face_crop_aug_range must be two floats / face_crop_aug_rangeは'下限,上限'で指定してください: {args.face_crop_aug_range}"
- else:
- face_crop_aug_range = None
-
- # load the tokenizer
- print("prepare tokenizer")
- tokenizer = CLIPTokenizer.from_pretrained(TOKENIZER_PATH)
-
- print("prepare dataset")
- train_dataset = DreamBoothOrFineTuningDataset(args.train_batch_size, fine_tuning, train_img_path_captions, reg_img_path_captions, tokenizer, resolution,
- args.prior_loss_weight, args.flip_aug, args.color_aug, face_crop_aug_range, args.random_crop,
- args.shuffle_caption, args.no_token_padding, args.debug_dataset)
-
- if args.debug_dataset:
- train_dataset.make_buckets_with_caching(args.enable_bucket, None, args.min_bucket_reso,
- args.max_bucket_reso) # build without caching, for debugging
- print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}")
- print("Escape for exit. / Escキーで中断、終了します")
- for example in train_dataset:
- for im, cap, lw in zip(example['images'], example['captions'], example['loss_weights']):
- im = ((im.numpy() + 1.0) * 127.5).astype(np.uint8)
- im = np.transpose(im, (1, 2, 0)) # c,H,W -> H,W,c
- im = im[:, :, ::-1] # RGB -> BGR (OpenCV)
- print(f'size: {im.shape[1]}*{im.shape[0]}, caption: "{cap}", loss weight: {lw}')
- cv2.imshow("img", im)
- k = cv2.waitKey()
- cv2.destroyAllWindows()
- if k == 27:
- break
- if k == 27:
- break
- return
-
- # prepare the accelerator
- # gradient accumulation reportedly does not support training multiple models, so it is fixed at 1
- print("prepare accelerator")
- if args.logging_dir is None:
- log_with = None
- logging_dir = None
- else:
- log_with = "tensorboard"
- logging_dir = args.logging_dir + "/" + time.strftime('%Y%m%d%H%M%S', time.localtime())
- accelerator = Accelerator(gradient_accumulation_steps=1, mixed_precision=args.mixed_precision,
- log_with=log_with, logging_dir=logging_dir)
-
- # load the models
- if use_stable_diffusion_format:
- print("load StableDiffusion checkpoint")
- text_encoder, vae, unet = load_models_from_stable_diffusion_checkpoint(args.pretrained_model_name_or_path)
- else:
- print("load Diffusers pretrained models")
- text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
- vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
- unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
-
- # plug xformers or memory-efficient attention into the model
- replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
-
- # prepare dtypes matching the mixed-precision settings and cast as appropriate
- weight_dtype = torch.float32
- if args.mixed_precision == "fp16":
- weight_dtype = torch.float16
- elif args.mixed_precision == "bf16":
- weight_dtype = torch.bfloat16
-
- save_dtype = None
- if args.save_precision == "fp16":
- save_dtype = torch.float16
- elif args.save_precision == "bf16":
- save_dtype = torch.bfloat16
- elif args.save_precision == "float":
- save_dtype = torch.float32
-
- # prepare for training
- if cache_latents:
- vae.to(accelerator.device, dtype=weight_dtype)
- vae.requires_grad_(False)
- vae.eval()
- with torch.no_grad():
- train_dataset.make_buckets_with_caching(args.enable_bucket, vae, args.min_bucket_reso, args.max_bucket_reso)
- del vae
- if torch.cuda.is_available():
- torch.cuda.empty_cache()
- else:
- train_dataset.make_buckets_with_caching(args.enable_bucket, None, args.min_bucket_reso, args.max_bucket_reso)
- vae.requires_grad_(False)
- vae.eval()
-
- unet.requires_grad_(True) # added just in case
- text_encoder.requires_grad_(True)
-
- if args.gradient_checkpointing:
- unet.enable_gradient_checkpointing()
- text_encoder.gradient_checkpointing_enable()
-
- # prepare the classes needed for training
- print("prepare optimizer, data loader etc.")
-
- # use 8-bit Adam
- if args.use_8bit_adam:
- try:
- import bitsandbytes as bnb
- except ImportError:
- raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです")
- print("use 8-bit Adam optimizer")
- optimizer_class = bnb.optim.AdamW8bit
- else:
- optimizer_class = torch.optim.AdamW
-
- trainable_params = (itertools.chain(unet.parameters(), text_encoder.parameters()))
-
- # both the diffusers DreamBooth and DreamBooth SD seem to use the default betas and weight decay, so those options are omitted for now
- optimizer = optimizer_class(trainable_params, lr=args.learning_rate)
-
- # prepare the dataloader
- # number of DataLoader worker processes: 0 means the main process
- n_workers = min(8, os.cpu_count() - 1) # cpu_count - 1, capped at 8
- train_dataloader = torch.utils.data.DataLoader(
- train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers)
-
- # prepare the lr scheduler
- lr_scheduler = diffusers.optimization.get_scheduler("constant", optimizer, num_training_steps=args.max_train_steps)
-
- # the accelerator apparently takes care of wrapping all of these for us
- unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
- unet, text_encoder, optimizer, train_dataloader, lr_scheduler)
-
- if not cache_latents:
- vae.to(accelerator.device, dtype=weight_dtype)
-
- # resume training
- if args.resume is not None:
- print(f"resume training from state: {args.resume}")
- accelerator.load_state(args.resume)
-
- # compute the number of epochs
- num_train_epochs = math.ceil(args.max_train_steps / len(train_dataloader))
-
- # run training
- total_batch_size = args.train_batch_size # * accelerator.num_processes
- print("running training / 学習開始")
- print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset.num_train_images}")
- print(f" num reg images / 正則化画像の数: {train_dataset.num_reg_images}")
- print(f" num examples / サンプル数: {train_dataset.num_train_images * (2 if train_dataset.enable_reg_images else 1)}")
- print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
- print(f" num epochs / epoch数: {num_train_epochs}")
- print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
- print(f" total train batch size (with parallel & distributed) / 総バッチサイズ(並列学習含む): {total_batch_size}")
- print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
-
- progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process, desc="steps")
- global_step = 0
-
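- # these are the standard Stable Diffusion v1 training-noise settings (scaled-linear betas over 1000 steps)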
- noise_scheduler = DDPMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)
-
- if accelerator.is_main_process:
- accelerator.init_trackers("dreambooth")
-
- # the code below is mostly copied from train_dreambooth.py
- for epoch in range(num_train_epochs):
- print(f"epoch {epoch+1}/{num_train_epochs}")
- unet.train()
- text_encoder.train() # apparently only the unet was supposed to need this? -> fixed in the latest version
-
- loss_total = 0
- for step, batch in enumerate(train_dataloader):
- with accelerator.accumulate(unet):
- with torch.no_grad():
- # convert to latents
- if cache_latents:
- latents = batch["latents"].to(accelerator.device)
- else:
- latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
- latents = latents * 0.18215
-
- # Sample noise that we'll add to the latents
- noise = torch.randn_like(latents, device=latents.device)
- b_size = latents.shape[0]
-
- # Sample a random timestep for each image
- timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device)
- timesteps = timesteps.long()
-
- # Add noise to the latents according to the noise magnitude at each timestep
- # (this is the forward diffusion process)
- noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
-
- # Get the text embedding for conditioning
- if args.clip_skip is None:
- encoder_hidden_states = text_encoder(batch["input_ids"])[0]
- else:
- enc_out = text_encoder(batch["input_ids"], output_hidden_states=True, return_dict=True)
- encoder_hidden_states = enc_out['hidden_states'][-args.clip_skip]
- encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states)
-
- # Predict the noise residual
- noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
-
- loss = torch.nn.functional.mse_loss(noise_pred.float(), noise.float(), reduction="none")
- loss = loss.mean([1, 2, 3])
-
- loss_weights = batch["loss_weights"] # 各sampleごとのweight
- loss = loss * loss_weights
-
- loss = loss.mean()
-
- accelerator.backward(loss)
- if accelerator.sync_gradients:
- params_to_clip = (itertools.chain(unet.parameters(), text_encoder.parameters()))
- accelerator.clip_grad_norm_(params_to_clip, 1.0) # args.max_grad_norm)
-
- optimizer.step()
- lr_scheduler.step()
- optimizer.zero_grad(set_to_none=True)
-
- # Checks if the accelerator has performed an optimization step behind the scenes
- if accelerator.sync_gradients:
- progress_bar.update(1)
- global_step += 1
-
- current_loss = loss.detach().item()
- if args.logging_dir is not None:
- logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
- accelerator.log(logs, step=global_step)
-
- loss_total += current_loss
- avr_loss = loss_total / (step+1)
- logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
- progress_bar.set_postfix(**logs)
-
- if global_step >= args.max_train_steps:
- break
-
- if args.logging_dir is not None:
- logs = {"epoch_loss": loss_total / len(train_dataloader)}
- accelerator.log(logs, step=epoch+1)
-
- accelerator.wait_for_everyone()
-
- if use_stable_diffusion_format and args.save_every_n_epochs is not None:
- if (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs:
- print("saving check point.")
- os.makedirs(args.output_dir, exist_ok=True)
- ckpt_file = os.path.join(args.output_dir, EPOCH_CHECKPOINT_NAME.format(epoch + 1))
- save_stable_diffusion_checkpoint(ckpt_file, accelerator.unwrap_model(text_encoder), accelerator.unwrap_model(unet),
- args.pretrained_model_name_or_path, epoch + 1, global_step, save_dtype)
-
- if args.save_state:
- print("saving state.")
- accelerator.save_state(os.path.join(args.output_dir, EPOCH_STATE_NAME.format(epoch + 1)))
-
- is_main_process = accelerator.is_main_process
- if is_main_process:
- unet = accelerator.unwrap_model(unet)
- text_encoder = accelerator.unwrap_model(text_encoder)
-
- accelerator.end_training()
-
- if args.save_state:
- print("saving last state.")
- accelerator.save_state(os.path.join(args.output_dir, LAST_STATE_NAME))
-
- del accelerator # free this because memory is needed below
-
- if is_main_process:
- os.makedirs(args.output_dir, exist_ok=True)
- if use_stable_diffusion_format:
- ckpt_file = os.path.join(args.output_dir, LAST_CHECKPOINT_NAME)
- print(f"save trained model as StableDiffusion checkpoint to {ckpt_file}")
- save_stable_diffusion_checkpoint(ckpt_file, text_encoder, unet,
- args.pretrained_model_name_or_path, epoch, global_step, save_dtype)
- else:
- # Create the pipeline using using the trained modules and save it.
- print(f"save trained model as Diffusers to {args.output_dir}")
- pipeline = StableDiffusionPipeline.from_pretrained(
- args.pretrained_model_name_or_path,
- unet=unet,
- text_encoder=text_encoder,
- )
- pipeline.save_pretrained(args.output_dir)
- print("model saved.")
-
-
- # region module replacement
- """
- Module replacement for speedup
- """
-
- # CrossAttention that uses FlashAttention
-# based on https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/memory_efficient_attention_pytorch/flash_attention.py
-# LICENSE MIT https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/LICENSE
-
-# constants
-
-EPSILON = 1e-6
-
-# helper functions
-
-
-def exists(val):
- return val is not None
-
-
-def default(val, d):
- return val if exists(val) else d
-
-# flash attention forwards and backwards
-
-# https://arxiv.org/abs/2205.14135
-
-
-class FlashAttentionFunction(Function):
- @ staticmethod
- @ torch.no_grad()
- def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
- """ Algorithm 2 in the paper """
-
- device = q.device
- dtype = q.dtype
- max_neg_value = -torch.finfo(q.dtype).max
- qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
-
- o = torch.zeros_like(q)
- all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device)
- all_row_maxes = torch.full((*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device)
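- # running softmax statistics: the per-row max and sum are maintained and rescaled block by block,
- # so the full (q_len x k_len) attention score matrix is never materialized at once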
-
- scale = (q.shape[-1] ** -0.5)
-
- if not exists(mask):
- mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size)
- else:
- mask = rearrange(mask, 'b n -> b 1 1 n')
- mask = mask.split(q_bucket_size, dim=-1)
-
- row_splits = zip(
- q.split(q_bucket_size, dim=-2),
- o.split(q_bucket_size, dim=-2),
- mask,
- all_row_sums.split(q_bucket_size, dim=-2),
- all_row_maxes.split(q_bucket_size, dim=-2),
- )
-
- for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits):
- q_start_index = ind * q_bucket_size - qk_len_diff
-
- col_splits = zip(
- k.split(k_bucket_size, dim=-2),
- v.split(k_bucket_size, dim=-2),
- )
-
- for k_ind, (kc, vc) in enumerate(col_splits):
- k_start_index = k_ind * k_bucket_size
-
- attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc) * scale
-
- if exists(row_mask):
- attn_weights.masked_fill_(~row_mask, max_neg_value)
-
- if causal and q_start_index < (k_start_index + k_bucket_size - 1):
- causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool,
- device=device).triu(q_start_index - k_start_index + 1)
- attn_weights.masked_fill_(causal_mask, max_neg_value)
-
- block_row_maxes = attn_weights.amax(dim=-1, keepdims=True)
- attn_weights -= block_row_maxes
- exp_weights = torch.exp(attn_weights)
-
- if exists(row_mask):
- exp_weights.masked_fill_(~row_mask, 0.)
-
- block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp(min=EPSILON)
-
- new_row_maxes = torch.maximum(block_row_maxes, row_maxes)
-
- exp_values = einsum('... i j, ... j d -> ... i d', exp_weights, vc)
-
- exp_row_max_diff = torch.exp(row_maxes - new_row_maxes)
- exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes)
-
- new_row_sums = exp_row_max_diff * row_sums + exp_block_row_max_diff * block_row_sums
-
- oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_((exp_block_row_max_diff / new_row_sums) * exp_values)
-
- row_maxes.copy_(new_row_maxes)
- row_sums.copy_(new_row_sums)
-
- ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size)
- ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes)
-
- return o
-
- @ staticmethod
- @ torch.no_grad()
- def backward(ctx, do):
- """ Algorithm 4 in the paper """
-
- causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args
- q, k, v, o, l, m = ctx.saved_tensors
-
- device = q.device
-
- max_neg_value = -torch.finfo(q.dtype).max
- qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
-
- dq = torch.zeros_like(q)
- dk = torch.zeros_like(k)
- dv = torch.zeros_like(v)
-
- row_splits = zip(
- q.split(q_bucket_size, dim=-2),
- o.split(q_bucket_size, dim=-2),
- do.split(q_bucket_size, dim=-2),
- mask,
- l.split(q_bucket_size, dim=-2),
- m.split(q_bucket_size, dim=-2),
- dq.split(q_bucket_size, dim=-2)
- )
-
- for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits):
- q_start_index = ind * q_bucket_size - qk_len_diff
-
- col_splits = zip(
- k.split(k_bucket_size, dim=-2),
- v.split(k_bucket_size, dim=-2),
- dk.split(k_bucket_size, dim=-2),
- dv.split(k_bucket_size, dim=-2),
- )
-
- for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits):
- k_start_index = k_ind * k_bucket_size
-
- attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc) * scale
-
- if causal and q_start_index < (k_start_index + k_bucket_size - 1):
- causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool,
- device=device).triu(q_start_index - k_start_index + 1)
- attn_weights.masked_fill_(causal_mask, max_neg_value)
-
- exp_attn_weights = torch.exp(attn_weights - mc)
-
- if exists(row_mask):
- exp_attn_weights.masked_fill_(~row_mask, 0.)
-
- p = exp_attn_weights / lc
-
- dv_chunk = einsum('... i j, ... i d -> ... j d', p, doc)
- dp = einsum('... i d, ... j d -> ... i j', doc, vc)
-
- D = (doc * oc).sum(dim=-1, keepdims=True)
- ds = p * scale * (dp - D)
-
- dq_chunk = einsum('... i j, ... j d -> ... i d', ds, kc)
- dk_chunk = einsum('... i j, ... i d -> ... j d', ds, qc)
-
- dqc.add_(dq_chunk)
- dkc.add_(dk_chunk)
- dvc.add_(dv_chunk)
-
- return dq, dk, dv, None, None, None, None
-
-
-def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers):
- if mem_eff_attn:
- replace_unet_cross_attn_to_memory_efficient()
- elif xformers:
- replace_unet_cross_attn_to_xformers()
-
-
-def replace_unet_cross_attn_to_memory_efficient():
- print("Replace CrossAttention.forward to use FlashAttention")
- flash_func = FlashAttentionFunction
-
- def forward_flash_attn(self, x, context=None, mask=None):
- q_bucket_size = 512
- k_bucket_size = 1024
-
- h = self.heads
- q = self.to_q(x)
-
- context = context if context is not None else x
- context = context.to(x.dtype)
- k = self.to_k(context)
- v = self.to_v(context)
- del context, x
-
- q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v))
-
- out = flash_func.apply(q, k, v, mask, False, q_bucket_size, k_bucket_size)
-
- out = rearrange(out, 'b h n d -> b n (h d)')
-
- # diffusers 0.6.0
- if type(self.to_out) is torch.nn.Sequential:
- return self.to_out(out)
-
- # diffusers 0.7.0~
- out = self.to_out[0](out)
- out = self.to_out[1](out)
- return out
-
- diffusers.models.attention.CrossAttention.forward = forward_flash_attn
-
-
-def replace_unet_cross_attn_to_xformers():
- print("Replace CrossAttention.forward to use xformers")
- try:
- import xformers.ops
- except ImportError:
- raise ImportError("No xformers / xformersがインストールされていないようです")
-
- def forward_xformers(self, x, context=None, mask=None):
- h = self.heads
- q_in = self.to_q(x)
-
- context = default(context, x)
- context = context.to(x.dtype)
-
- k_in = self.to_k(context)
- v_in = self.to_v(context)
-
- q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in)) # new format
- # q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in)) # legacy format
- del q_in, k_in, v_in
- out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) # picks the best available kernel automatically
-
- out = rearrange(out, 'b n h d -> b n (h d)', h=h)
- # out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
-
- # diffusers 0.6.0
- if type(self.to_out) is torch.nn.Sequential:
- return self.to_out(out)
-
- # diffusers 0.7.0~
- out = self.to_out[0](out)
- out = self.to_out[1](out)
- return out
-
- diffusers.models.attention.CrossAttention.forward = forward_xformers
-# endregion
-
-
-if __name__ == '__main__':
- # torch.cuda.set_per_process_memory_fraction(0.48)
- parser = argparse.ArgumentParser()
- parser.add_argument("--pretrained_model_name_or_path", type=str, default=None,
- help="pretrained model to train, directory to Diffusers model or StableDiffusion checkpoint / 学習元モデル、Diffusers形式モデルのディレクトリまたはStableDiffusionのckptファイル")
- parser.add_argument("--fine_tuning", action="store_true",
- help="fine tune the model instead of DreamBooth / DreamBoothではなくfine tuningする")
- parser.add_argument("--shuffle_caption", action="store_true",
- help="shuffle comma-separated caption / コンマで区切られたcaptionの各要素をshuffleする")
- parser.add_argument("--caption_extention", type=str, default=".caption", help="extention of caption files / 読み込むcaptionファイルの拡張子")
- parser.add_argument("--train_data_dir", type=str, default=None, help="directory for train images / 学習画像データのディレクトリ")
- parser.add_argument("--reg_data_dir", type=str, default=None, help="directory for regularization images / 正則化画像データのディレクトリ")
- parser.add_argument("--dataset_repeats", type=int, default=None,
- help="repeat dataset in fine tuning / fine tuning時にデータセットを繰り返す回数")
- parser.add_argument("--output_dir", type=str, default=None,
- help="directory to output trained model, save as same format as input / 学習後のモデル出力先ディレクトリ(入力と同じ形式で保存)")
- parser.add_argument("--save_every_n_epochs", type=int, default=None,
- help="save checkpoint every N epochs (only supports in StableDiffusion checkpoint) / 学習中のモデルを指定エポックごとに保存します(StableDiffusion形式のモデルを読み込んだ場合のみ有効)")
- parser.add_argument("--save_state", action="store_true",
- help="save training state additionally (including optimizer states etc.) / optimizerなど学習状態も含めたstateを追加で保存する")
- parser.add_argument("--resume", type=str, default=None, help="saved state to resume training / 学習再開するモデルのstate")
- parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="loss weight for regularization images / 正則化画像のlossの重み")
- parser.add_argument("--no_token_padding", action="store_true",
- help="disable token padding (same as Diffuser's DreamBooth) / トークンのpaddingを無効にする(Diffusers版DreamBoothと同じ動作)")
- parser.add_argument("--color_aug", action="store_true", help="enable weak color augmentation / 学習時に色合いのaugmentationを有効にする")
- parser.add_argument("--flip_aug", action="store_true", help="enable horizontal flip augmentation / 学習時に左右反転のaugmentationを有効にする")
- parser.add_argument("--face_crop_aug_range", type=str, default=None,
- help="enable face-centered crop augmentation and its range (e.g. 2.0,4.0) / 学習時に顔を中心とした切り出しaugmentationを有効にするときは倍率を指定する(例:2.0,4.0)")
- parser.add_argument("--random_crop", action="store_true",
- help="enable random crop (for style training in face-centered crop augmentation) / ランダムな切り出しを有効にする(顔を中心としたaugmentationを行うときに画風の学習用に指定する)")
- parser.add_argument("--debug_dataset", action="store_true",
- help="show images for debugging (do not train) / デバッグ用に学習データを画面表示する(学習は行わない)")
- parser.add_argument("--resolution", type=str, default=None,
- help="resolution in training ('size' or 'width,height') / 学習時の画像解像度('サイズ'指定、または'幅,高さ'指定)")
- parser.add_argument("--train_batch_size", type=int, default=1,
- help="batch size for training (1 means one train or reg data, not train/reg pair) / 学習時のバッチサイズ(1でtrain/regをそれぞれ1件ずつ学習)")
- parser.add_argument("--use_8bit_adam", action="store_true",
- help="use 8bit Adam optimizer (requires bitsandbytes) / 8bit Adamオプティマイザを使う(bitsandbytesのインストールが必要)")
- parser.add_argument("--mem_eff_attn", action="store_true",
- help="use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う")
- parser.add_argument("--xformers", action="store_true",
- help="use xformers for CrossAttention / CrossAttentionにxformersを使う")
- parser.add_argument("--cache_latents", action="store_true",
- help="cache latents to reduce memory (augmentations must be disabled) / メモリ削減のためにlatentをcacheする(augmentationは使用不可)")
- parser.add_argument("--enable_bucket", action="store_true",
- help="enable buckets for multi aspect ratio training / 複数解像度学習のためのbucketを有効にする")
- parser.add_argument("--min_bucket_reso", type=int, default=256, help="minimum resolution for buckets / bucketの最小解像度")
- parser.add_argument("--max_bucket_reso", type=int, default=1024, help="maximum resolution for buckets / bucketの最小解像度")
- parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率")
- parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数")
- parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed")
- parser.add_argument("--gradient_checkpointing", action="store_true",
- help="enable gradient checkpointing / grandient checkpointingを有効にする")
- parser.add_argument("--mixed_precision", type=str, default="no",
- choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度")
- parser.add_argument("--save_precision", type=str, default=None,
- choices=[None, "float", "fp16", "bf16"], help="precision in saving / 保存時に精度を変更して保存する")
- parser.add_argument("--clip_skip", type=int, default=None,
- help="use output of nth layer from back of text encoder (n>=1) / text encoderの後ろからn番目の層の出力を用いる(nは1以上)")
- parser.add_argument("--logging_dir", type=str, default=None,
- help="enable logging and output TensorBoard log to this directory / ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する")
-
- args = parser.parse_args()
- train(args)
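-# Usage sketch (paths, filename and hyperparameter values below are illustrative placeholders,
-# assuming the script is launched through accelerate as usual for Accelerator-based training):
-#   accelerate launch <this script> --pretrained_model_name_or_path=<model.ckpt or Diffusers dir> \
-#     --train_data_dir=<train images> --reg_data_dir=<reg images> --output_dir=<output dir> \
-#     --resolution=512 --train_batch_size=1 --learning_rate=1e-6 --max_train_steps=1600 \
-#     --use_8bit_adam --xformers --cache_latents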
diff --git a/train_db_fixed/train_db_fixed_v11.py b/train_db_fixed/train_db_fixed_v11.py
deleted file mode 100644
index 25e39fab..00000000
--- a/train_db_fixed/train_db_fixed_v11.py
+++ /dev/null
@@ -1,2098 +0,0 @@
-# This script is licensed under Apache License 2.0, the same as train_dreambooth.py
-# (c) 2022 Kohya S. @kohya_ss
-
-# v7: another text encoder ckpt format, average loss, save epochs/global steps, show num of train/reg images,
-# enable reg images in fine-tuning, add dataset_repeats option
-# v8: supports Diffusers 0.7.2
-# v9: add bucketing option
-# v10: add min_bucket_reso/max_bucket_reso options, read captions for train/reg images in DreamBooth
-# v11: Diffusers 0.9.0 is required. support for Stable Diffusion 2.0/v-parameterization
-# add lr scheduler options, change handling of folder/file captions, support loading Diffusers models from Hugging Face
-# support save_every_n_epochs/save_state for Diffusers format models
-# fix the issue that prior_loss_weight was applied to train images
-
-import time
-from torch.autograd.function import Function
-import argparse
-import glob
-import itertools
-import math
-import os
-import random
-
-from tqdm import tqdm
-import torch
-from torchvision import transforms
-from accelerate import Accelerator
-from accelerate.utils import set_seed
-from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig
-import diffusers
-from diffusers import AutoencoderKL, DDPMScheduler, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel
-import albumentations as albu
-import numpy as np
-from PIL import Image
-import cv2
-from einops import rearrange
-from torch import einsum
-
-# Tokenizer: use the pre-distributed tokenizer instead of loading it from the checkpoint
-TOKENIZER_PATH = "openai/clip-vit-large-patch14"
-V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # only the tokenizer is taken from here
-
-# model parameters of the Diffusers version of Stable Diffusion
-NUM_TRAIN_TIMESTEPS = 1000
-BETA_START = 0.00085
-BETA_END = 0.0120
-
-UNET_PARAMS_MODEL_CHANNELS = 320
-UNET_PARAMS_CHANNEL_MULT = [1, 2, 4, 4]
-UNET_PARAMS_ATTENTION_RESOLUTIONS = [4, 2, 1]
-UNET_PARAMS_IMAGE_SIZE = 32 # unused
-UNET_PARAMS_IN_CHANNELS = 4
-UNET_PARAMS_OUT_CHANNELS = 4
-UNET_PARAMS_NUM_RES_BLOCKS = 2
-UNET_PARAMS_CONTEXT_DIM = 768
-UNET_PARAMS_NUM_HEADS = 8
-
-VAE_PARAMS_Z_CHANNELS = 4
-VAE_PARAMS_RESOLUTION = 256
-VAE_PARAMS_IN_CHANNELS = 3
-VAE_PARAMS_OUT_CH = 3
-VAE_PARAMS_CH = 128
-VAE_PARAMS_CH_MULT = [1, 2, 4, 4]
-VAE_PARAMS_NUM_RES_BLOCKS = 2
-
-# V2
-V2_UNET_PARAMS_ATTENTION_HEAD_DIM = [5, 10, 20, 20]
-V2_UNET_PARAMS_CONTEXT_DIM = 1024
-
-# checkpoint file names
-LAST_CHECKPOINT_NAME = "last.ckpt"
-LAST_STATE_NAME = "last-state"
-LAST_DIFFUSERS_DIR_NAME = "last"
-EPOCH_CHECKPOINT_NAME = "epoch-{:06d}.ckpt"
-EPOCH_STATE_NAME = "epoch-{:06d}-state"
-EPOCH_DIFFUSERS_DIR_NAME = "epoch-{:06d}"
-
-
-# region dataset
-
-
-def make_bucket_resolutions(max_reso, min_size=256, max_size=1024, divisible=64):
- max_width, max_height = max_reso
- max_area = (max_width // divisible) * (max_height // divisible)
-
- resos = set()
-
- size = int(math.sqrt(max_area)) * divisible
- resos.add((size, size))
-
- size = min_size
- while size <= max_size:
- width = size
- height = min(max_size, (max_area // (width // divisible)) * divisible)
- resos.add((width, height))
- resos.add((height, width))
- size += divisible
-
- resos = list(resos)
- resos.sort()
-
- aspect_ratios = [w / h for w, h in resos]
- return resos, aspect_ratios
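-# Example (values follow from the logic above): make_bucket_resolutions((512, 512)) keeps the
-# area near 512*512 and returns resolutions such as (256, 1024), (320, 768), (384, 640),
-# (448, 576), (512, 512), (576, 448), (640, 384), (768, 320), (1024, 256), together with
-# the corresponding width/height aspect ratios.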
-
-
-class DreamBoothOrFineTuningDataset(torch.utils.data.Dataset):
- def __init__(self, batch_size, fine_tuning, train_img_path_captions, reg_img_path_captions, tokenizer, resolution, prior_loss_weight, flip_aug, color_aug, face_crop_aug_range, random_crop, shuffle_caption, disable_padding, debug_dataset) -> None:
- super().__init__()
-
- self.batch_size = batch_size
- self.fine_tuning = fine_tuning
- self.train_img_path_captions = train_img_path_captions
- self.reg_img_path_captions = reg_img_path_captions
- self.tokenizer = tokenizer
- self.width, self.height = resolution
- self.size = min(self.width, self.height) # the shorter side
- self.prior_loss_weight = prior_loss_weight
- self.face_crop_aug_range = face_crop_aug_range
- self.random_crop = random_crop
- self.debug_dataset = debug_dataset
- self.shuffle_caption = shuffle_caption
- self.disable_padding = disable_padding
- self.latents_cache = None
- self.enable_bucket = False
-
- # augmentation
- flip_p = 0.5 if flip_aug else 0.0
- if color_aug:
- # fairly weak color augmentation: brightness/contrast would change the min/max pixel values of the image, which is assumed to be undesirable, so only gamma/hue/saturation are adjusted
- self.aug = albu.Compose([
- albu.OneOf([
- # albu.RandomBrightnessContrast(0.05, 0.05, p=.2),
- albu.HueSaturationValue(5, 8, 0, p=.2),
- # albu.RGBShift(5, 5, 5, p=.1),
- albu.RandomGamma((95, 105), p=.5),
- ], p=.33),
- albu.HorizontalFlip(p=flip_p)
- ], p=1.)
- elif flip_aug:
- self.aug = albu.Compose([
- albu.HorizontalFlip(p=flip_p)
- ], p=1.)
- else:
- self.aug = None
-
- self.num_train_images = len(self.train_img_path_captions)
- self.num_reg_images = len(self.reg_img_path_captions)
-
- self.enable_reg_images = self.num_reg_images > 0
-
- if self.enable_reg_images and self.num_train_images < self.num_reg_images:
- print("some of reg images are not used / 正則化画像の数が多いので、一部使用されない正則化画像があります")
-
- self.image_transforms = transforms.Compose(
- [
- transforms.ToTensor(),
- transforms.Normalize([0.5], [0.5]),
- ]
- )
-
- # must be called even when bucketing is not used (a single bucket is created)
- def make_buckets_with_caching(self, enable_bucket, vae, min_size, max_size):
- self.enable_bucket = enable_bucket
-
- cache_latents = vae is not None
- if cache_latents:
- if enable_bucket:
- print("cache latents with bucketing")
- else:
- print("cache latents")
- else:
- if enable_bucket:
- print("make buckets")
- else:
- print("prepare dataset")
-
- # prepare the buckets
- if enable_bucket:
- bucket_resos, bucket_aspect_ratios = make_bucket_resolutions((self.width, self.height), min_size, max_size)
- else:
- # only one bucket; all images share the same resolution
- bucket_resos = [(self.width, self.height)]
- bucket_aspect_ratios = [self.width / self.height]
- bucket_aspect_ratios = np.array(bucket_aspect_ratios)
-
- # get the resolution of each image (and its latents) in advance
- img_ar_errors = []
- self.size_lat_cache = {}
- for image_path, _ in tqdm(self.train_img_path_captions + self.reg_img_path_captions):
- if image_path in self.size_lat_cache:
- continue
-
- image = self.load_image(image_path)[0]
- image_height, image_width = image.shape[0:2]
-
- if not enable_bucket:
- # assert image_width == self.width and image_height == self.height, \
- # f"all images must have specific resolution when bucketing is disabled / bucketを使わない場合、すべての画像のサイズを統一してください: {image_path}"
- reso = (self.width, self.height)
- else:
- # pick the bucket
- aspect_ratio = image_width / image_height
- ar_errors = bucket_aspect_ratios - aspect_ratio
- bucket_id = np.abs(ar_errors).argmin()
- reso = bucket_resos[bucket_id]
- ar_error = ar_errors[bucket_id]
- img_ar_errors.append(ar_error)
-
- if cache_latents:
- image = self.resize_and_trim(image, reso)
-
- # compute the latents
- if cache_latents:
- img_tensor = self.image_transforms(image)
- img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype)
- latents = vae.encode(img_tensor).latent_dist.sample().squeeze(0).to("cpu")
- else:
- latents = None
-
- self.size_lat_cache[image_path] = (reso, latents)
-
- # distribute the images into buckets
- self.buckets = [[] for _ in range(len(bucket_resos))]
- reso_to_index = {}
- for i, reso in enumerate(bucket_resos):
- reso_to_index[reso] = i
-
- def split_to_buckets(is_reg, img_path_captions):
- for image_path, caption in img_path_captions:
- reso, _ = self.size_lat_cache[image_path]
- bucket_index = reso_to_index[reso]
- self.buckets[bucket_index].append((is_reg, image_path, caption))
-
- split_to_buckets(False, self.train_img_path_captions)
-
- if self.enable_reg_images:
- l = []
- while len(l) < len(self.train_img_path_captions):
- l += self.reg_img_path_captions
- l = l[:len(self.train_img_path_captions)]
- split_to_buckets(True, l)
-
- if enable_bucket:
- print("number of images with repeats / 繰り返し回数込みの各bucketの画像枚数")
- for i, (reso, imgs) in enumerate(zip(bucket_resos, self.buckets)):
- print(f"bucket {i}: resolution {reso}, count: {len(imgs)}")
- img_ar_errors = np.array(img_ar_errors)
- print(f"mean ar error: {np.mean(np.abs(img_ar_errors))}")
-
- # build the lookup index of (bucket, batch) pairs
- self.buckets_indices = []
- for bucket_index, bucket in enumerate(self.buckets):
- batch_count = int(math.ceil(len(bucket) / self.batch_size))
- for batch_index in range(batch_count):
- self.buckets_indices.append((bucket_index, batch_index))
-
- self.shuffle_buckets()
- self._length = len(self.buckets_indices)
-
- # resize to the target size, then trim the excess
- def resize_and_trim(self, image, reso):
- image_height, image_width = image.shape[0:2]
- ar_img = image_width / image_height
- ar_reso = reso[0] / reso[1]
- if ar_img > ar_reso: # image is wider -> fit the height
- scale = reso[1] / image_height
- else:
- scale = reso[0] / image_width
- resized_size = (int(image_width * scale + .5), int(image_height * scale + .5))
-
- image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA) # resize with cv2 so that INTER_AREA can be used
- if resized_size[0] > reso[0]:
- trim_size = resized_size[0] - reso[0]
- image = image[:, trim_size//2:trim_size//2 + reso[0]]
- elif resized_size[1] > reso[1]:
- trim_size = resized_size[1] - reso[1]
- image = image[trim_size//2:trim_size//2 + reso[1]]
- assert image.shape[0] == reso[1] and image.shape[1] == reso[0], \
- f"internal error, illegal trimmed size: {image.shape}, {reso}"
- return image
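- # Example (hypothetical numbers, following the logic above): a 1000x800 (w x h) image
- # assigned to reso=(512, 640) is scaled by 640/800 to 800x640 and then center-trimmed
- # horizontally to 512x640.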
-
- def shuffle_buckets(self):
- random.shuffle(self.buckets_indices)
- for bucket in self.buckets:
- random.shuffle(bucket)
-
- def load_image(self, image_path):
- image = Image.open(image_path)
- if not image.mode == "RGB":
- image = image.convert("RGB")
- img = np.array(image, np.uint8)
-
- face_cx = face_cy = face_w = face_h = 0
- if self.face_crop_aug_range is not None:
- tokens = os.path.splitext(os.path.basename(image_path))[0].split('_')
- if len(tokens) >= 5:
- face_cx = int(tokens[-4])
- face_cy = int(tokens[-3])
- face_w = int(tokens[-2])
- face_h = int(tokens[-1])
-
- return img, face_cx, face_cy, face_w, face_h
-
- # crop the image around the face in a reasonable way
- def crop_target(self, image, face_cx, face_cy, face_w, face_h):
- height, width = image.shape[0:2]
- if height == self.height and width == self.width:
- return image
-
- # the image is larger than the target size, so resize it
- face_size = max(face_w, face_h)
- min_scale = max(self.height / height, self.width / width) # scale at which the image exactly fits the model input size (the smallest allowed scale)
- min_scale = min(1.0, max(min_scale, self.size / (face_size * self.face_crop_aug_range[1]))) # scale for the specified minimum face size
- max_scale = min(1.0, max(min_scale, self.size / (face_size * self.face_crop_aug_range[0]))) # scale for the specified maximum face size
- if min_scale >= max_scale: # the specified range has min == max
- scale = min_scale
- else:
- scale = random.uniform(min_scale, max_scale)
-
- nh = int(height * scale + .5)
- nw = int(width * scale + .5)
- assert nh >= self.height and nw >= self.width, f"internal error. small scale {scale}, {width}*{height}"
- image = cv2.resize(image, (nw, nh), interpolation=cv2.INTER_AREA)
- face_cx = int(face_cx * scale + .5)
- face_cy = int(face_cy * scale + .5)
- height, width = nh, nw
-
- # crop to e.g. 448*640 centered on the face
- for axis, (target_size, length, face_p) in enumerate(zip((self.height, self.width), (height, width), (face_cy, face_cx))):
- p1 = face_p - target_size // 2 # crop position that puts the face at the center
-
- if self.random_crop:
- # shift the crop while biasing toward keeping the face near the center, so some background is also included
- range = max(length - face_p, face_p) # the longer of the distances from the face center to the image edges
- p1 = p1 + (random.randint(0, range) + random.randint(0, range)) - range # a reasonably shaped random offset in -range ~ +range
- else:
- # only when a range is specified, add a small random shift (rather ad hoc)
- if self.face_crop_aug_range[0] != self.face_crop_aug_range[1]:
- if face_size > self.size // 10 and face_size >= 40:
- p1 = p1 + random.randint(-face_size // 20, +face_size // 20)
-
- p1 = max(0, min(p1, length - target_size))
-
- if axis == 0:
- image = image[p1:p1 + target_size, :]
- else:
- image = image[:, p1:p1 + target_size]
-
- return image
-
- def __len__(self):
- return self._length
-
- def __getitem__(self, index):
- if index == 0:
- self.shuffle_buckets()
-
- bucket = self.buckets[self.buckets_indices[index][0]]
- image_index = self.buckets_indices[index][1] * self.batch_size
-
- latents_list = []
- images = []
- captions = []
- loss_weights = []
-
- for is_reg, image_path, caption in bucket[image_index:image_index + self.batch_size]:
- loss_weights.append(self.prior_loss_weight if is_reg else 1.0)
-
- # process the image / latents
- reso, latents = self.size_lat_cache[image_path]
-
- if latents is None:
- # load the image and crop it if necessary
- img, face_cx, face_cy, face_w, face_h = self.load_image(image_path)
- im_h, im_w = img.shape[0:2]
-
- if self.enable_bucket:
- img = self.resize_and_trim(img, reso)
- else:
- if face_cx > 0: # face position info is available
- img = self.crop_target(img, face_cx, face_cy, face_w, face_h)
- elif im_h > self.height or im_w > self.width:
- assert self.random_crop, f"image too large, and face_crop_aug_range and random_crop are disabled / 画像サイズが大きいのでface_crop_aug_rangeかrandom_cropを有効にしてください"
- if im_h > self.height:
- p = random.randint(0, im_h - self.height)
- img = img[p:p + self.height]
- if im_w > self.width:
- p = random.randint(0, im_w - self.width)
- img = img[:, p:p + self.width]
-
- im_h, im_w = img.shape[0:2]
- assert im_h == self.height and im_w == self.width, f"image size is small / 画像サイズが小さいようです: {image_path}"
-
- # augmentation
- if self.aug is not None:
- img = self.aug(image=img)['image']
-
- image = self.image_transforms(img) # becomes a torch.Tensor in the range -1.0 to 1.0
- else:
- image = None
-
- images.append(image)
- latents_list.append(latents)
-
- # process the caption
- if self.shuffle_caption: # shuffle the caption elements
- tokens = caption.strip().split(",")
- random.shuffle(tokens)
- caption = ",".join(tokens).strip()
- captions.append(caption)
-
- # pad input_ids and convert to Tensor
- if self.disable_padding:
- # do not pad to max_length: padding=True only pads to the longest item in the batch (still seems like a bug...?)
- input_ids = self.tokenizer(captions, padding=True, truncation=True, return_tensors="pt").input_ids
- else:
- # pad to max_length
- input_ids = self.tokenizer(captions, padding='max_length', truncation=True, return_tensors='pt').input_ids
-
- example = {}
- example['loss_weights'] = torch.FloatTensor(loss_weights)
- example['input_ids'] = input_ids
- if images[0] is not None:
- images = torch.stack(images)
- images = images.to(memory_format=torch.contiguous_format).float()
- else:
- images = None
- example['images'] = images
- example['latents'] = torch.stack(latents_list) if latents_list[0] is not None else None
- if self.debug_dataset:
- example['image_paths'] = [image_path for _, image_path, _ in bucket[image_index:image_index + self.batch_size]]
- example['captions'] = captions
- return example
-# endregion
-
-
-# region module replacement
-"""
-Module replacement for speed-up
-"""
-
-# CrossAttention using FlashAttention
-# based on https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/memory_efficient_attention_pytorch/flash_attention.py
-# LICENSE MIT https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/LICENSE
-
-# constants
-
-EPSILON = 1e-6
-
-# helper functions
-
-
-def exists(val):
- return val is not None
-
-
-def default(val, d):
- return val if exists(val) else d
-
-# flash attention forwards and backwards
-
-# https://arxiv.org/abs/2205.14135
-
-
-class FlashAttentionFunction(Function):
- @ staticmethod
- @ torch.no_grad()
- def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
- """ Algorithm 2 in the paper """
-
- device = q.device
- dtype = q.dtype
- max_neg_value = -torch.finfo(q.dtype).max
- qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
-
- o = torch.zeros_like(q)
- all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device)
- all_row_maxes = torch.full((*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device)
-
- scale = (q.shape[-1] ** -0.5)
-
- if not exists(mask):
- mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size)
- else:
- mask = rearrange(mask, 'b n -> b 1 1 n')
- mask = mask.split(q_bucket_size, dim=-1)
-
- row_splits = zip(
- q.split(q_bucket_size, dim=-2),
- o.split(q_bucket_size, dim=-2),
- mask,
- all_row_sums.split(q_bucket_size, dim=-2),
- all_row_maxes.split(q_bucket_size, dim=-2),
- )
-
- for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits):
- q_start_index = ind * q_bucket_size - qk_len_diff
-
- col_splits = zip(
- k.split(k_bucket_size, dim=-2),
- v.split(k_bucket_size, dim=-2),
- )
-
- for k_ind, (kc, vc) in enumerate(col_splits):
- k_start_index = k_ind * k_bucket_size
-
- attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc) * scale
-
- if exists(row_mask):
- attn_weights.masked_fill_(~row_mask, max_neg_value)
-
- if causal and q_start_index < (k_start_index + k_bucket_size - 1):
- causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool,
- device=device).triu(q_start_index - k_start_index + 1)
- attn_weights.masked_fill_(causal_mask, max_neg_value)
-
- block_row_maxes = attn_weights.amax(dim=-1, keepdims=True)
- attn_weights -= block_row_maxes
- exp_weights = torch.exp(attn_weights)
-
- if exists(row_mask):
- exp_weights.masked_fill_(~row_mask, 0.)
-
- block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp(min=EPSILON)
-
- new_row_maxes = torch.maximum(block_row_maxes, row_maxes)
-
- exp_values = einsum('... i j, ... j d -> ... i d', exp_weights, vc)
-
- exp_row_max_diff = torch.exp(row_maxes - new_row_maxes)
- exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes)
-
- new_row_sums = exp_row_max_diff * row_sums + exp_block_row_max_diff * block_row_sums
-
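- # Online-softmax update (FlashAttention, Algorithm 2): rescale the running normalized
- # output by exp(old_max - new_max) * (old_sum / new_sum) and add this block's contribution,
- # so that oc equals softmax(q k^T) v over all key/value blocks processed so far.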
- oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_((exp_block_row_max_diff / new_row_sums) * exp_values)
-
- row_maxes.copy_(new_row_maxes)
- row_sums.copy_(new_row_sums)
-
- ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size)
- ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes)
-
- return o
-
- @ staticmethod
- @ torch.no_grad()
- def backward(ctx, do):
- """ Algorithm 4 in the paper """
-
- causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args
- q, k, v, o, l, m = ctx.saved_tensors
-
- device = q.device
-
- max_neg_value = -torch.finfo(q.dtype).max
- qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
-
- dq = torch.zeros_like(q)
- dk = torch.zeros_like(k)
- dv = torch.zeros_like(v)
-
- row_splits = zip(
- q.split(q_bucket_size, dim=-2),
- o.split(q_bucket_size, dim=-2),
- do.split(q_bucket_size, dim=-2),
- mask,
- l.split(q_bucket_size, dim=-2),
- m.split(q_bucket_size, dim=-2),
- dq.split(q_bucket_size, dim=-2)
- )
-
- for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits):
- q_start_index = ind * q_bucket_size - qk_len_diff
-
- col_splits = zip(
- k.split(k_bucket_size, dim=-2),
- v.split(k_bucket_size, dim=-2),
- dk.split(k_bucket_size, dim=-2),
- dv.split(k_bucket_size, dim=-2),
- )
-
- for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits):
- k_start_index = k_ind * k_bucket_size
-
- attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc) * scale
-
- if causal and q_start_index < (k_start_index + k_bucket_size - 1):
- causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool,
- device=device).triu(q_start_index - k_start_index + 1)
- attn_weights.masked_fill_(causal_mask, max_neg_value)
-
- exp_attn_weights = torch.exp(attn_weights - mc)
-
- if exists(row_mask):
- exp_attn_weights.masked_fill_(~row_mask, 0.)
-
- p = exp_attn_weights / lc
-
- dv_chunk = einsum('... i j, ... i d -> ... j d', p, doc)
- dp = einsum('... i d, ... j d -> ... i j', doc, vc)
-
- D = (doc * oc).sum(dim=-1, keepdims=True)
- ds = p * scale * (dp - D)
-
- dq_chunk = einsum('... i j, ... j d -> ... i d', ds, kc)
- dk_chunk = einsum('... i j, ... i d -> ... j d', ds, qc)
-
- dqc.add_(dq_chunk)
- dkc.add_(dk_chunk)
- dvc.add_(dv_chunk)
-
- return dq, dk, dv, None, None, None, None
-
-
-def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers):
- if mem_eff_attn:
- replace_unet_cross_attn_to_memory_efficient()
- elif xformers:
- replace_unet_cross_attn_to_xformers()
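-# Usage sketch (assuming the surrounding training code has loaded the U-Net as `unet` and
-# parsed the --mem_eff_attn / --xformers flags):
-#   replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
-# called once before training starts, so that CrossAttention.forward is monkey-patched.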
-
-
-def replace_unet_cross_attn_to_memory_efficient():
- print("Replace CrossAttention.forward to use FlashAttention")
- flash_func = FlashAttentionFunction
-
- def forward_flash_attn(self, x, context=None, mask=None):
- q_bucket_size = 512
- k_bucket_size = 1024
-
- h = self.heads
- q = self.to_q(x)
-
- context = context if context is not None else x
- context = context.to(x.dtype)
- k = self.to_k(context)
- v = self.to_v(context)
- del context, x
-
- q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v))
-
- out = flash_func.apply(q, k, v, mask, False, q_bucket_size, k_bucket_size)
-
- out = rearrange(out, 'b h n d -> b n (h d)')
-
- # diffusers 0.6.0
- if type(self.to_out) is torch.nn.Sequential:
- return self.to_out(out)
-
- # diffusers 0.7.0~
- out = self.to_out[0](out)
- out = self.to_out[1](out)
- return out
-
- diffusers.models.attention.CrossAttention.forward = forward_flash_attn
-
-
-def replace_unet_cross_attn_to_xformers():
- print("Replace CrossAttention.forward to use xformers")
- try:
- import xformers.ops
- except ImportError:
- raise ImportError("No xformers / xformersがインストールされていないようです")
-
- def forward_xformers(self, x, context=None, mask=None):
- h = self.heads
- q_in = self.to_q(x)
-
- context = default(context, x)
- context = context.to(x.dtype)
-
- k_in = self.to_k(context)
- v_in = self.to_v(context)
-
- q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in)) # new format
- # q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in)) # legacy format
- del q_in, k_in, v_in
- out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) # automatically picks the best available kernel
-
- out = rearrange(out, 'b n h d -> b n (h d)', h=h)
- # out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
-
- # diffusers 0.6.0
- if type(self.to_out) is torch.nn.Sequential:
- return self.to_out(out)
-
- # diffusers 0.7.0~
- out = self.to_out[0](out)
- out = self.to_out[1](out)
- return out
-
- diffusers.models.attention.CrossAttention.forward = forward_xformers
-# endregion
-
-
-# region checkpoint conversion, loading and saving ###############################
-
-# region StableDiffusion->Diffusersの変換コード
-# convert_original_stable_diffusion_to_diffusers をコピーしている(ASL 2.0)
-
-
-def shave_segments(path, n_shave_prefix_segments=1):
- """
- Removes segments. Positive values shave the first segments, negative shave the last segments.
- """
- if n_shave_prefix_segments >= 0:
- return ".".join(path.split(".")[n_shave_prefix_segments:])
- else:
- return ".".join(path.split(".")[:n_shave_prefix_segments])
-
-
-def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
- """
- Updates paths inside resnets to the new naming scheme (local renaming)
- """
- mapping = []
- for old_item in old_list:
- new_item = old_item.replace("in_layers.0", "norm1")
- new_item = new_item.replace("in_layers.2", "conv1")
-
- new_item = new_item.replace("out_layers.0", "norm2")
- new_item = new_item.replace("out_layers.3", "conv2")
-
- new_item = new_item.replace("emb_layers.1", "time_emb_proj")
- new_item = new_item.replace("skip_connection", "conv_shortcut")
-
- new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
-
- mapping.append({"old": old_item, "new": new_item})
-
- return mapping
-
-
-def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
- """
- Updates paths inside resnets to the new naming scheme (local renaming)
- """
- mapping = []
- for old_item in old_list:
- new_item = old_item
-
- new_item = new_item.replace("nin_shortcut", "conv_shortcut")
- new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
-
- mapping.append({"old": old_item, "new": new_item})
-
- return mapping
-
-
-def renew_attention_paths(old_list, n_shave_prefix_segments=0):
- """
- Updates paths inside attentions to the new naming scheme (local renaming)
- """
- mapping = []
- for old_item in old_list:
- new_item = old_item
-
- # new_item = new_item.replace('norm.weight', 'group_norm.weight')
- # new_item = new_item.replace('norm.bias', 'group_norm.bias')
-
- # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
- # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
-
- # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
-
- mapping.append({"old": old_item, "new": new_item})
-
- return mapping
-
-
-def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
- """
- Updates paths inside attentions to the new naming scheme (local renaming)
- """
- mapping = []
- for old_item in old_list:
- new_item = old_item
-
- new_item = new_item.replace("norm.weight", "group_norm.weight")
- new_item = new_item.replace("norm.bias", "group_norm.bias")
-
- new_item = new_item.replace("q.weight", "query.weight")
- new_item = new_item.replace("q.bias", "query.bias")
-
- new_item = new_item.replace("k.weight", "key.weight")
- new_item = new_item.replace("k.bias", "key.bias")
-
- new_item = new_item.replace("v.weight", "value.weight")
- new_item = new_item.replace("v.bias", "value.bias")
-
- new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
- new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
-
- new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
-
- mapping.append({"old": old_item, "new": new_item})
-
- return mapping
-
-
-def assign_to_checkpoint(
- paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
-):
- """
- This does the final conversion step: take locally converted weights and apply a global renaming
- to them. It splits attention layers, and takes into account additional replacements
- that may arise.
-
- Assigns the weights to the new checkpoint.
- """
- assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
-
- # Splits the attention layers into three variables.
- if attention_paths_to_split is not None:
- for path, path_map in attention_paths_to_split.items():
- old_tensor = old_checkpoint[path]
- channels = old_tensor.shape[0] // 3
-
- target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
-
- num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
-
- old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
- query, key, value = old_tensor.split(channels // num_heads, dim=1)
-
- checkpoint[path_map["query"]] = query.reshape(target_shape)
- checkpoint[path_map["key"]] = key.reshape(target_shape)
- checkpoint[path_map["value"]] = value.reshape(target_shape)
-
- for path in paths:
- new_path = path["new"]
-
- # These have already been assigned
- if attention_paths_to_split is not None and new_path in attention_paths_to_split:
- continue
-
- # Global renaming happens here
- new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
- new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
- new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")
-
- if additional_replacements is not None:
- for replacement in additional_replacements:
- new_path = new_path.replace(replacement["old"], replacement["new"])
-
- # proj_attn.weight has to be converted from conv 1D to linear
- if "proj_attn.weight" in new_path:
- checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
- else:
- checkpoint[new_path] = old_checkpoint[path["old"]]
-
-
-def conv_attn_to_linear(checkpoint):
- keys = list(checkpoint.keys())
- attn_keys = ["query.weight", "key.weight", "value.weight"]
- for key in keys:
- if ".".join(key.split(".")[-2:]) in attn_keys:
- if checkpoint[key].ndim > 2:
- checkpoint[key] = checkpoint[key][:, :, 0, 0]
- elif "proj_attn.weight" in key:
- if checkpoint[key].ndim > 2:
- checkpoint[key] = checkpoint[key][:, :, 0]
-
-
-def linear_transformer_to_conv(checkpoint):
- keys = list(checkpoint.keys())
- tf_keys = ["proj_in.weight", "proj_out.weight"]
- for key in keys:
- if ".".join(key.split(".")[-2:]) in tf_keys:
- if checkpoint[key].ndim == 2:
- checkpoint[key] = checkpoint[key].unsqueeze(2).unsqueeze(2)
-
-
-def convert_ldm_unet_checkpoint(v2, checkpoint, config):
- """
- Takes a state dict and a config, and returns a converted checkpoint.
- """
-
- # extract state_dict for UNet
- unet_state_dict = {}
- unet_key = "model.diffusion_model."
- keys = list(checkpoint.keys())
- for key in keys:
- if key.startswith(unet_key):
- unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
-
- new_checkpoint = {}
-
- new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
- new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
- new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
- new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
-
- new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
- new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
-
- new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
- new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
- new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
- new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
-
- # Retrieves the keys for the input blocks only
- num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
- input_blocks = {
- layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
- for layer_id in range(num_input_blocks)
- }
-
- # Retrieves the keys for the middle blocks only
- num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
- middle_blocks = {
- layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
- for layer_id in range(num_middle_blocks)
- }
-
- # Retrieves the keys for the output blocks only
- num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
- output_blocks = {
- layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
- for layer_id in range(num_output_blocks)
- }
-
- for i in range(1, num_input_blocks):
- block_id = (i - 1) // (config["layers_per_block"] + 1)
- layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
-
- resnets = [
- key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
- ]
- attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
-
- if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
- new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
- f"input_blocks.{i}.0.op.weight"
- )
- new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
- f"input_blocks.{i}.0.op.bias"
- )
-
- paths = renew_resnet_paths(resnets)
- meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
- assign_to_checkpoint(
- paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
- )
-
- if len(attentions):
- paths = renew_attention_paths(attentions)
- meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
- assign_to_checkpoint(
- paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
- )
-
- resnet_0 = middle_blocks[0]
- attentions = middle_blocks[1]
- resnet_1 = middle_blocks[2]
-
- resnet_0_paths = renew_resnet_paths(resnet_0)
- assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)
-
- resnet_1_paths = renew_resnet_paths(resnet_1)
- assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)
-
- attentions_paths = renew_attention_paths(attentions)
- meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
- assign_to_checkpoint(
- attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
- )
-
- for i in range(num_output_blocks):
- block_id = i // (config["layers_per_block"] + 1)
- layer_in_block_id = i % (config["layers_per_block"] + 1)
- output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
- output_block_list = {}
-
- for layer in output_block_layers:
- layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
- if layer_id in output_block_list:
- output_block_list[layer_id].append(layer_name)
- else:
- output_block_list[layer_id] = [layer_name]
-
- if len(output_block_list) > 1:
- resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
- attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
-
- resnet_0_paths = renew_resnet_paths(resnets)
- paths = renew_resnet_paths(resnets)
-
- meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
- assign_to_checkpoint(
- paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
- )
-
- if ["conv.weight", "conv.bias"] in output_block_list.values():
- index = list(output_block_list.values()).index(["conv.weight", "conv.bias"])
- new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
- f"output_blocks.{i}.{index}.conv.weight"
- ]
- new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
- f"output_blocks.{i}.{index}.conv.bias"
- ]
-
- # Clear attentions as they have been attributed above.
- if len(attentions) == 2:
- attentions = []
-
- if len(attentions):
- paths = renew_attention_paths(attentions)
- meta_path = {
- "old": f"output_blocks.{i}.1",
- "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
- }
- assign_to_checkpoint(
- paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
- )
- else:
- resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
- for path in resnet_0_paths:
- old_path = ".".join(["output_blocks", str(i), path["old"]])
- new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
-
- new_checkpoint[new_path] = unet_state_dict[old_path]
-
- # in SD v2 the 1x1 conv2d layers were changed to linear, so convert linear -> conv here
- if v2:
- linear_transformer_to_conv(new_checkpoint)
-
- return new_checkpoint
-
-
-def convert_ldm_vae_checkpoint(checkpoint, config):
- # extract state dict for VAE
- vae_state_dict = {}
- vae_key = "first_stage_model."
- keys = list(checkpoint.keys())
- for key in keys:
- if key.startswith(vae_key):
- vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
-
- new_checkpoint = {}
-
- new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
- new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
- new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
- new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
- new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
- new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]
-
- new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
- new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
- new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
- new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
- new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
- new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]
-
- new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
- new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
- new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
- new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
-
- # Retrieves the keys for the encoder down blocks only
- num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
- down_blocks = {
- layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
- }
-
- # Retrieves the keys for the decoder up blocks only
- num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
- up_blocks = {
- layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
- }
-
- for i in range(num_down_blocks):
- resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
-
- if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
- new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
- f"encoder.down.{i}.downsample.conv.weight"
- )
- new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
- f"encoder.down.{i}.downsample.conv.bias"
- )
-
- paths = renew_vae_resnet_paths(resnets)
- meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
- assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
-
- mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
- num_mid_res_blocks = 2
- for i in range(1, num_mid_res_blocks + 1):
- resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
-
- paths = renew_vae_resnet_paths(resnets)
- meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
- assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
-
- mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
- paths = renew_vae_attention_paths(mid_attentions)
- meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
- assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
- conv_attn_to_linear(new_checkpoint)
-
- for i in range(num_up_blocks):
- block_id = num_up_blocks - 1 - i
- resnets = [
- key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
- ]
-
- if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
- new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
- f"decoder.up.{block_id}.upsample.conv.weight"
- ]
- new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
- f"decoder.up.{block_id}.upsample.conv.bias"
- ]
-
- paths = renew_vae_resnet_paths(resnets)
- meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
- assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
-
- mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
- num_mid_res_blocks = 2
- for i in range(1, num_mid_res_blocks + 1):
- resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
-
- paths = renew_vae_resnet_paths(resnets)
- meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
- assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
-
- mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
- paths = renew_vae_attention_paths(mid_attentions)
- meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
- assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
- conv_attn_to_linear(new_checkpoint)
- return new_checkpoint
-
-
-def create_unet_diffusers_config(v2):
- """
- Creates a config for the diffusers based on the config of the LDM model.
- """
- # unet_params = original_config.model.params.unet_config.params
-
- block_out_channels = [UNET_PARAMS_MODEL_CHANNELS * mult for mult in UNET_PARAMS_CHANNEL_MULT]
-
- down_block_types = []
- resolution = 1
- for i in range(len(block_out_channels)):
- block_type = "CrossAttnDownBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "DownBlock2D"
- down_block_types.append(block_type)
- if i != len(block_out_channels) - 1:
- resolution *= 2
-
- up_block_types = []
- for i in range(len(block_out_channels)):
- block_type = "CrossAttnUpBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "UpBlock2D"
- up_block_types.append(block_type)
- resolution //= 2
-
- config = dict(
- sample_size=UNET_PARAMS_IMAGE_SIZE,
- in_channels=UNET_PARAMS_IN_CHANNELS,
- out_channels=UNET_PARAMS_OUT_CHANNELS,
- down_block_types=tuple(down_block_types),
- up_block_types=tuple(up_block_types),
- block_out_channels=tuple(block_out_channels),
- layers_per_block=UNET_PARAMS_NUM_RES_BLOCKS,
- cross_attention_dim=UNET_PARAMS_CONTEXT_DIM if not v2 else V2_UNET_PARAMS_CONTEXT_DIM,
- attention_head_dim=UNET_PARAMS_NUM_HEADS if not v2 else V2_UNET_PARAMS_ATTENTION_HEAD_DIM,
- )
-
- return config
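-# For reference (derived from the constants above): in the v1 case this yields
-# block_out_channels=(320, 640, 1280, 1280), down_block_types of three CrossAttnDownBlock2D
-# followed by one DownBlock2D, up_block_types of one UpBlock2D followed by three
-# CrossAttnUpBlock2D, cross_attention_dim=768 and attention_head_dim=8.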
-
-
-def create_vae_diffusers_config():
- """
- Creates a config for the diffusers based on the config of the LDM model.
- """
- # vae_params = original_config.model.params.first_stage_config.params.ddconfig
- # _ = original_config.model.params.first_stage_config.params.embed_dim
- block_out_channels = [VAE_PARAMS_CH * mult for mult in VAE_PARAMS_CH_MULT]
- down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
- up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
-
- config = dict(
- sample_size=VAE_PARAMS_RESOLUTION,
- in_channels=VAE_PARAMS_IN_CHANNELS,
- out_channels=VAE_PARAMS_OUT_CH,
- down_block_types=tuple(down_block_types),
- up_block_types=tuple(up_block_types),
- block_out_channels=tuple(block_out_channels),
- latent_channels=VAE_PARAMS_Z_CHANNELS,
- layers_per_block=VAE_PARAMS_NUM_RES_BLOCKS,
- )
- return config
-
-
-def convert_ldm_clip_checkpoint_v1(checkpoint):
- keys = list(checkpoint.keys())
- text_model_dict = {}
- for key in keys:
- if key.startswith("cond_stage_model.transformer"):
- text_model_dict[key[len("cond_stage_model.transformer."):]] = checkpoint[key]
- return text_model_dict
-
-
-def convert_ldm_clip_checkpoint_v2(checkpoint, max_length):
- # the layouts differ so much it is painful!
- def convert_key(key):
- if not key.startswith("cond_stage_model"):
- return None
-
- # common conversion
- key = key.replace("cond_stage_model.model.transformer.", "text_model.encoder.")
- key = key.replace("cond_stage_model.model.", "text_model.")
-
- if "resblocks" in key:
- # resblocks conversion
- key = key.replace(".resblocks.", ".layers.")
- if ".ln_" in key:
- key = key.replace(".ln_", ".layer_norm")
- elif ".mlp." in key:
- key = key.replace(".c_fc.", ".fc1.")
- key = key.replace(".c_proj.", ".fc2.")
- elif '.attn.out_proj' in key:
- key = key.replace(".attn.out_proj.", ".self_attn.out_proj.")
- elif '.attn.in_proj' in key:
- key = None # special case, handled later
- else:
- raise ValueError(f"unexpected key in SD: {key}")
- elif '.positional_embedding' in key:
- key = key.replace(".positional_embedding", ".embeddings.position_embedding.weight")
- elif '.text_projection' in key:
- key = None # apparently unused???
- elif '.logit_scale' in key:
- key = None # apparently unused???
- elif '.token_embedding' in key:
- key = key.replace(".token_embedding.weight", ".embeddings.token_embedding.weight")
- elif '.ln_final' in key:
- key = key.replace(".ln_final", ".final_layer_norm")
- return key
-
- keys = list(checkpoint.keys())
- new_sd = {}
- for key in keys:
- # remove resblocks 23
- if '.resblocks.23.' in key:
- continue
- new_key = convert_key(key)
- if new_key is None:
- continue
- new_sd[new_key] = checkpoint[key]
-
- # convert the attention weights
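- # e.g. "cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight" (a single fused
- # tensor) is split below into "text_model.encoder.layers.0.self_attn.q_proj.weight",
- # "text_model.encoder.layers.0.self_attn.k_proj.weight" and the matching v_proj key.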
- for key in keys:
- if '.resblocks.23.' in key:
- continue
- if '.resblocks' in key and '.attn.in_proj_' in key:
- # split into three
- values = torch.chunk(checkpoint[key], 3)
-
- key_suffix = ".weight" if "weight" in key else ".bias"
- key_pfx = key.replace("cond_stage_model.model.transformer.resblocks.", "text_model.encoder.layers.")
- key_pfx = key_pfx.replace("_weight", "")
- key_pfx = key_pfx.replace("_bias", "")
- key_pfx = key_pfx.replace(".attn.in_proj", ".self_attn.")
- new_sd[key_pfx + "q_proj" + key_suffix] = values[0]
- new_sd[key_pfx + "k_proj" + key_suffix] = values[1]
- new_sd[key_pfx + "v_proj" + key_suffix] = values[2]
-
- # add position_ids
- new_sd["text_model.embeddings.position_ids"] = torch.Tensor([list(range(max_length))]).to(torch.int64)
- return new_sd
-
-# endregion
-
-
-# region Diffusers->StableDiffusion の変換コード
-# convert_diffusers_to_original_stable_diffusion をコピーしている(ASL 2.0)
-
-def conv_transformer_to_linear(checkpoint):
- keys = list(checkpoint.keys())
- tf_keys = ["proj_in.weight", "proj_out.weight"]
- for key in keys:
- if ".".join(key.split(".")[-2:]) in tf_keys:
- if checkpoint[key].ndim > 2:
- checkpoint[key] = checkpoint[key][:, :, 0, 0]
-
-
-def convert_unet_state_dict_to_sd(v2, unet_state_dict):
- unet_conversion_map = [
- # (stable-diffusion, HF Diffusers)
- ("time_embed.0.weight", "time_embedding.linear_1.weight"),
- ("time_embed.0.bias", "time_embedding.linear_1.bias"),
- ("time_embed.2.weight", "time_embedding.linear_2.weight"),
- ("time_embed.2.bias", "time_embedding.linear_2.bias"),
- ("input_blocks.0.0.weight", "conv_in.weight"),
- ("input_blocks.0.0.bias", "conv_in.bias"),
- ("out.0.weight", "conv_norm_out.weight"),
- ("out.0.bias", "conv_norm_out.bias"),
- ("out.2.weight", "conv_out.weight"),
- ("out.2.bias", "conv_out.bias"),
- ]
-
- unet_conversion_map_resnet = [
- # (stable-diffusion, HF Diffusers)
- ("in_layers.0", "norm1"),
- ("in_layers.2", "conv1"),
- ("out_layers.0", "norm2"),
- ("out_layers.3", "conv2"),
- ("emb_layers.1", "time_emb_proj"),
- ("skip_connection", "conv_shortcut"),
- ]
-
- unet_conversion_map_layer = []
- for i in range(4):
- # loop over downblocks/upblocks
-
- for j in range(2):
- # loop over resnets/attentions for downblocks
- hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
- sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
- unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
-
- if i < 3:
- # no attention layers in down_blocks.3
- hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
- sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
- unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
-
- for j in range(3):
- # loop over resnets/attentions for upblocks
- hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
- sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
- unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
-
- if i > 0:
- # no attention layers in up_blocks.0
- hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
- sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
- unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
-
- if i < 3:
- # no downsample in down_blocks.3
- hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
- sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
- unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
-
- # no upsample in up_blocks.3
- hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
- sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
- unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
-
- hf_mid_atn_prefix = "mid_block.attentions.0."
- sd_mid_atn_prefix = "middle_block.1."
- unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
-
- for j in range(2):
- hf_mid_res_prefix = f"mid_block.resnets.{j}."
- sd_mid_res_prefix = f"middle_block.{2*j}."
- unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
-
- # buyer beware: this is a *brittle* function,
- # and correct output requires that all of these pieces interact in
- # the exact order in which I have arranged them.
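- # For example, following the tables above, the HF Diffusers key
- # "down_blocks.0.resnets.0.norm1.weight" ends up mapped to the SD key
- # "input_blocks.1.0.in_layers.0.weight".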
- mapping = {k: k for k in unet_state_dict.keys()}
- for sd_name, hf_name in unet_conversion_map:
- mapping[hf_name] = sd_name
- for k, v in mapping.items():
- if "resnets" in k:
- for sd_part, hf_part in unet_conversion_map_resnet:
- v = v.replace(hf_part, sd_part)
- mapping[k] = v
- for k, v in mapping.items():
- for sd_part, hf_part in unet_conversion_map_layer:
- v = v.replace(hf_part, sd_part)
- mapping[k] = v
- new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
-
- if v2:
- conv_transformer_to_linear(new_state_dict)
-
- return new_state_dict
-
-# endregion
-
-
-def load_checkpoint_with_text_encoder_conversion(ckpt_path):
- # handle models that store the text encoder in a different layout (missing 'text_model')
- TEXT_ENCODER_KEY_REPLACEMENTS = [
- ('cond_stage_model.transformer.embeddings.', 'cond_stage_model.transformer.text_model.embeddings.'),
- ('cond_stage_model.transformer.encoder.', 'cond_stage_model.transformer.text_model.encoder.'),
- ('cond_stage_model.transformer.final_layer_norm.', 'cond_stage_model.transformer.text_model.final_layer_norm.')
- ]
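- # e.g. "cond_stage_model.transformer.encoder.layers.0.mlp.fc1.weight" is renamed to
- # "cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc1.weight"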
-
- checkpoint = torch.load(ckpt_path, map_location="cpu")
- state_dict = checkpoint["state_dict"]
-
- key_reps = []
- for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS:
- for key in state_dict.keys():
- if key.startswith(rep_from):
- new_key = rep_to + key[len(rep_from):]
- key_reps.append((key, new_key))
-
- for key, new_key in key_reps:
- state_dict[new_key] = state_dict[key]
- del state_dict[key]
-
- return checkpoint
-
-
-def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path):
- checkpoint = load_checkpoint_with_text_encoder_conversion(ckpt_path)
- state_dict = checkpoint["state_dict"]
-
- # Convert the UNet2DConditionModel model.
- unet_config = create_unet_diffusers_config(v2)
- converted_unet_checkpoint = convert_ldm_unet_checkpoint(v2, state_dict, unet_config)
-
- unet = UNet2DConditionModel(**unet_config)
- info = unet.load_state_dict(converted_unet_checkpoint)
- print("loading u-net:", info)
-
- # Convert the VAE model.
- vae_config = create_vae_diffusers_config()
- converted_vae_checkpoint = convert_ldm_vae_checkpoint(state_dict, vae_config)
-
- vae = AutoencoderKL(**vae_config)
- info = vae.load_state_dict(converted_vae_checkpoint)
- print("loadint vae:", info)
-
- # convert text_model
- if v2:
- converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v2(state_dict, 77)
- cfg = CLIPTextConfig(
- vocab_size=49408,
- hidden_size=1024,
- intermediate_size=4096,
- num_hidden_layers=23,
- num_attention_heads=16,
- max_position_embeddings=77,
- hidden_act="gelu",
- layer_norm_eps=1e-05,
- dropout=0.0,
- attention_dropout=0.0,
- initializer_range=0.02,
- initializer_factor=1.0,
- pad_token_id=1,
- bos_token_id=0,
- eos_token_id=2,
- model_type="clip_text_model",
- projection_dim=512,
- torch_dtype="float32",
- transformers_version="4.25.0.dev0",
- )
- text_model = CLIPTextModel._from_config(cfg)
- info = text_model.load_state_dict(converted_text_encoder_checkpoint)
- else:
- converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v1(state_dict)
- text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
- info = text_model.load_state_dict(converted_text_encoder_checkpoint)
- print("loading text encoder:", info)
-
- return text_model, vae, unet
-
-
-def convert_text_encoder_state_dict_to_sd_v2(checkpoint):
- def convert_key(key):
- # remove position_ids
- if ".position_ids" in key:
- return None
-
- # common
- key = key.replace("text_model.encoder.", "transformer.")
- key = key.replace("text_model.", "")
- if "layers" in key:
- # resblocks conversion
- key = key.replace(".layers.", ".resblocks.")
- if ".layer_norm" in key:
- key = key.replace(".layer_norm", ".ln_")
- elif ".mlp." in key:
- key = key.replace(".fc1.", ".c_fc.")
- key = key.replace(".fc2.", ".c_proj.")
- elif '.self_attn.out_proj' in key:
- key = key.replace(".self_attn.out_proj.", ".attn.out_proj.")
- elif '.self_attn.' in key:
- key = None # special case, handled later
- else:
- raise ValueError(f"unexpected key in DiffUsers model: {key}")
- elif '.position_embedding' in key:
- key = key.replace("embeddings.position_embedding.weight", "positional_embedding")
- elif '.token_embedding' in key:
- key = key.replace("embeddings.token_embedding.weight", "token_embedding.weight")
- elif 'final_layer_norm' in key:
- key = key.replace("final_layer_norm", "ln_final")
- return key
-
- keys = list(checkpoint.keys())
- new_sd = {}
- for key in keys:
- new_key = convert_key(key)
- if new_key is None:
- continue
- new_sd[new_key] = checkpoint[key]
-
- # convert the attention weights
- for key in keys:
- if 'layers' in key and 'q_proj' in key:
- # concatenate the three projections (q, k, v)
- key_q = key
- key_k = key.replace("q_proj", "k_proj")
- key_v = key.replace("q_proj", "v_proj")
-
- value_q = checkpoint[key_q]
- value_k = checkpoint[key_k]
- value_v = checkpoint[key_v]
- value = torch.cat([value_q, value_k, value_v])
-
- new_key = key.replace("text_model.encoder.layers.", "transformer.resblocks.")
- new_key = new_key.replace(".self_attn.q_proj.", ".attn.in_proj_")
- new_sd[new_key] = value
-
- return new_sd
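-# Note (illustrative examples): the function above maps Hugging Face CLIPTextModel keys to the
-# open_clip-style layout stored in SD v2 checkpoints, roughly as follows:
-#   text_model.encoder.layers.0.mlp.fc1.weight      -> transformer.resblocks.0.mlp.c_fc.weight
-#   text_model.encoder.layers.0.layer_norm1.weight  -> transformer.resblocks.0.ln_1.weight
-#   text_model.final_layer_norm.weight              -> ln_final.weight
-# and the per-layer q/k/v projections are concatenated into a single
-# transformer.resblocks.N.attn.in_proj_weight / in_proj_bias tensor.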
-
-
-def save_stable_diffusion_checkpoint(v2, output_file, text_encoder, unet, ckpt_path, epochs, steps, save_dtype=None):
- # the VAE is not kept in memory, so reload the checkpoint (including the VAE) once more
- checkpoint = load_checkpoint_with_text_encoder_conversion(ckpt_path)
- state_dict = checkpoint["state_dict"]
-
- def assign_new_sd(prefix, sd):
- for k, v in sd.items():
- key = prefix + k
- assert key in state_dict, f"Illegal key in save SD: {key}"
- if save_dtype is not None:
- v = v.detach().clone().to("cpu").to(save_dtype)
- state_dict[key] = v
-
- # Convert the UNet model
- unet_state_dict = convert_unet_state_dict_to_sd(v2, unet.state_dict())
- assign_new_sd("model.diffusion_model.", unet_state_dict)
-
- # Convert the text encoder model
- if v2:
- text_enc_dict = convert_text_encoder_state_dict_to_sd_v2(text_encoder.state_dict())
- assign_new_sd("cond_stage_model.model.", text_enc_dict)
- else:
- text_enc_dict = text_encoder.state_dict()
- assign_new_sd("cond_stage_model.transformer.", text_enc_dict)
-
- # Put together new checkpoint
- new_ckpt = {'state_dict': state_dict}
-
- if 'epoch' in checkpoint:
- epochs += checkpoint['epoch']
- if 'global_step' in checkpoint:
- steps += checkpoint['global_step']
-
- new_ckpt['epoch'] = epochs
- new_ckpt['global_step'] = steps
-
- torch.save(new_ckpt, output_file)
-
-
-def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, pretrained_model_name_or_path, save_dtype):
- pipeline = StableDiffusionPipeline(
- unet=unet,
- text_encoder=text_encoder,
- vae=AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae"),
- scheduler=DDIMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler"),
- tokenizer=CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer"),
- safety_checker=None,
- feature_extractor=None,
- requires_safety_checker=None,
- )
- pipeline.save_pretrained(output_dir)
-
-# endregion
-
-
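-# Note: batching is done inside the dataset itself (each __getitem__ already returns a full
-# batch dict), so with DataLoader(batch_size=1) this collate_fn simply unwraps the one-element list.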
-def collate_fn(examples):
- return examples[0]
-
-
-def train(args):
- if args.caption_extention is not None:
- args.caption_extension = args.caption_extention
- args.caption_extention = None
-
- fine_tuning = args.fine_tuning
- cache_latents = args.cache_latents
-
- # check the option settings when caching latents
- if cache_latents:
- assert not args.flip_aug and not args.color_aug, "when caching latents, augmentation cannot be used / latentをキャッシュするときはaugmentationは使えません"
-
- # check other option settings
- if args.v_parameterization and not args.v2:
- print("v_parameterization should be with v2 / v1でv_parameterizationを使用することは想定されていません")
- if args.v2 and args.clip_skip is not None:
- print("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません")
-
- # check the model format options:
- # since v11, downloading directly from Diffusers (Hugging Face) is also OK (models that require authentication are not supported); v11 also added intermediate saving in Diffusers format
- use_stable_diffusion_format = os.path.isfile(args.pretrained_model_name_or_path)
-
- # initialize the random seed
- if args.seed is not None:
- set_seed(args.seed)
-
- # prepare the training data
- def read_caption(img_path):
- # build candidate caption file names
- base_name = os.path.splitext(img_path)[0]
- base_name_face_det = base_name
- tokens = base_name.split("_")
- if len(tokens) >= 5:
- base_name_face_det = "_".join(tokens[:-4])
- cap_paths = [base_name + args.caption_extension, base_name_face_det + args.caption_extension]
-
- caption = None
- for cap_path in cap_paths:
- if os.path.isfile(cap_path):
- with open(cap_path, "rt", encoding='utf-8') as f:
- caption = f.readlines()[0].strip()
- break
- return caption
-
- def load_dreambooth_dir(dir):
- tokens = os.path.basename(dir).split('_')
- try:
- n_repeats = int(tokens[0])
- except ValueError as e:
- return 0, []
-
- caption_by_folder = '_'.join(tokens[1:])
-
- print(f"found directory {n_repeats}_{caption_by_folder}")
-
- img_paths = glob.glob(os.path.join(dir, "*.png")) + glob.glob(os.path.join(dir, "*.jpg")) + \
- glob.glob(os.path.join(dir, "*.webp"))
-
- # read a per-image caption and prefer it when present (behavior changed in v11)
- captions = []
- for img_path in img_paths:
- cap_for_img = read_caption(img_path)
- captions.append(caption_by_folder if cap_for_img is None else cap_for_img)
-
- return n_repeats, list(zip(img_paths, captions))
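- # Note (illustrative layout; the folder and file names below are made up): DreamBooth-style data
- # is expected as subfolders named "<repeats>_<caption>", e.g. train_data_dir/10_sks dog/img.png is
- # repeated 10 times with the folder caption "sks dog". A file name may also end in four
- # "_cx_cy_w_h" tokens carrying face-crop info; read_caption() additionally tries the name with
- # those tokens stripped when searching for a matching caption file.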
-
- print("prepare train images.")
- train_img_path_captions = []
-
- if fine_tuning:
- img_paths = glob.glob(os.path.join(args.train_data_dir, "*.png")) + \
- glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + glob.glob(os.path.join(args.train_data_dir, "*.webp"))
- for img_path in tqdm(img_paths):
- caption = read_caption(img_path)
- assert caption is not None and len(
- caption) > 0, f"no caption for image. check caption_extension option / キャプションファイルが見つからないかcaptionが空です。caption_extensionオプションを確認してください: {img_path}"
-
- train_img_path_captions.append((img_path, caption))
-
- if args.dataset_repeats is not None:
- l = []
- for _ in range(args.dataset_repeats):
- l.extend(train_img_path_captions)
- train_img_path_captions = l
- else:
- train_dirs = os.listdir(args.train_data_dir)
- for dir in train_dirs:
- n_repeats, img_caps = load_dreambooth_dir(os.path.join(args.train_data_dir, dir))
- for _ in range(n_repeats):
- train_img_path_captions.extend(img_caps)
- print(f"{len(train_img_path_captions)} train images with repeating.")
-
- reg_img_path_captions = []
- if args.reg_data_dir:
- print("prepare reg images.")
- reg_dirs = os.listdir(args.reg_data_dir)
- for dir in reg_dirs:
- n_repeats, img_caps = load_dreambooth_dir(os.path.join(args.reg_data_dir, dir))
- for _ in range(n_repeats):
- reg_img_path_captions.extend(img_caps)
- print(f"{len(reg_img_path_captions)} reg images.")
-
- # prepare the dataset
- resolution = tuple([int(r) for r in args.resolution.split(',')])
- if len(resolution) == 1:
- resolution = (resolution[0], resolution[0])
- assert len(resolution) == 2, \
- f"resolution must be 'size' or 'width,height' / resolutionは'サイズ'または'幅','高さ'で指定してください: {args.resolution}"
-
- if args.enable_bucket:
- assert min(resolution) >= args.min_bucket_reso, f"min_bucket_reso must be equal to or less than resolution / min_bucket_resoは解像度の数値以下で指定してください"
- assert max(resolution) <= args.max_bucket_reso, f"max_bucket_reso must be equal to or greater than resolution / max_bucket_resoは解像度の数値以上で指定してください"
-
- if args.face_crop_aug_range is not None:
- face_crop_aug_range = tuple([float(r) for r in args.face_crop_aug_range.split(',')])
- assert len(
- face_crop_aug_range) == 2, f"face_crop_aug_range must be two floats / face_crop_aug_rangeは'下限,上限'で指定してください: {args.face_crop_aug_range}"
- else:
- face_crop_aug_range = None
-
- # load the tokenizer
- print("prepare tokenizer")
- if args.v2:
- tokenizer = CLIPTokenizer.from_pretrained(V2_STABLE_DIFFUSION_PATH, subfolder="tokenizer")
- else:
- tokenizer = CLIPTokenizer.from_pretrained(TOKENIZER_PATH)
-
- print("prepare dataset")
- train_dataset = DreamBoothOrFineTuningDataset(args.train_batch_size, fine_tuning, train_img_path_captions, reg_img_path_captions, tokenizer, resolution,
- args.prior_loss_weight, args.flip_aug, args.color_aug, face_crop_aug_range, args.random_crop,
- args.shuffle_caption, args.no_token_padding, args.debug_dataset)
-
- if args.debug_dataset:
- train_dataset.make_buckets_with_caching(args.enable_bucket, None, args.min_bucket_reso,
- args.max_bucket_reso) # build without latent caching, for debugging
- print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}")
- print("Escape for exit. / Escキーで中断、終了します")
- for example in train_dataset:
- for im, cap, lw in zip(example['images'], example['captions'], example['loss_weights']):
- im = ((im.numpy() + 1.0) * 127.5).astype(np.uint8)
- im = np.transpose(im, (1, 2, 0)) # c,H,W -> H,W,c
- im = im[:, :, ::-1] # RGB -> BGR (OpenCV)
- print(f'size: {im.shape[1]}*{im.shape[0]}, caption: "{cap}", loss weight: {lw}')
- cv2.imshow("img", im)
- k = cv2.waitKey()
- cv2.destroyAllWindows()
- if k == 27:
- break
- if k == 27:
- break
- return
-
- # prepare the accelerator
- # gradient accumulation reportedly does not support training multiple models, so it is fixed at 1
- print("prepare accelerator")
- if args.logging_dir is None:
- log_with = None
- logging_dir = None
- else:
- log_with = "tensorboard"
- logging_dir = args.logging_dir + "/" + time.strftime('%Y%m%d%H%M%S', time.localtime())
- accelerator = Accelerator(gradient_accumulation_steps=1, mixed_precision=args.mixed_precision,
- log_with=log_with, logging_dir=logging_dir)
-
- # prepare a dtype that matches the mixed-precision setting and cast to it where needed
- weight_dtype = torch.float32
- if args.mixed_precision == "fp16":
- weight_dtype = torch.float16
- elif args.mixed_precision == "bf16":
- weight_dtype = torch.bfloat16
-
- save_dtype = None
- if args.save_precision == "fp16":
- save_dtype = torch.float16
- elif args.save_precision == "bf16":
- save_dtype = torch.bfloat16
- elif args.save_precision == "float":
- save_dtype = torch.float32
-
- # load the models
- if use_stable_diffusion_format:
- print("load StableDiffusion checkpoint")
- text_encoder, vae, unet = load_models_from_stable_diffusion_checkpoint(args.v2, args.pretrained_model_name_or_path)
- else:
- print("load Diffusers pretrained models")
- pipe = StableDiffusionPipeline.from_pretrained(args.pretrained_model_name_or_path, tokenizer=None, safety_checker=None)
- # , torch_dtype=weight_dtype) specifying torch_dtype here causes errors during training
- text_encoder = pipe.text_encoder
- vae = pipe.vae
- unet = pipe.unet
- del pipe
-
- # hook xformers or memory-efficient attention into the model
- replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
-
- # prepare for training
- if cache_latents:
- vae.to(accelerator.device, dtype=weight_dtype)
- vae.requires_grad_(False)
- vae.eval()
- with torch.no_grad():
- train_dataset.make_buckets_with_caching(args.enable_bucket, vae, args.min_bucket_reso, args.max_bucket_reso)
- del vae
- if torch.cuda.is_available():
- torch.cuda.empty_cache()
- else:
- train_dataset.make_buckets_with_caching(args.enable_bucket, None, args.min_bucket_reso, args.max_bucket_reso)
- vae.requires_grad_(False)
- vae.eval()
-
- unet.requires_grad_(True) # set explicitly, just in case
- text_encoder.requires_grad_(True)
-
- if args.gradient_checkpointing:
- unet.enable_gradient_checkpointing()
- text_encoder.gradient_checkpointing_enable()
-
- # prepare the classes needed for training
- print("prepare optimizer, data loader etc.")
-
- # use 8-bit Adam
- if args.use_8bit_adam:
- try:
- import bitsandbytes as bnb
- except ImportError:
- raise ImportError("No bitsand bytes / bitsandbytesがインストールされていないようです")
- print("use 8-bit Adam optimizer")
- optimizer_class = bnb.optim.AdamW8bit
- else:
- optimizer_class = torch.optim.AdamW
-
- trainable_params = (itertools.chain(unet.parameters(), text_encoder.parameters()))
-
- # beta and weight decay appear to use the defaults in both the diffusers DreamBooth script and DreamBooth SD, so those options are omitted for now
- optimizer = optimizer_class(trainable_params, lr=args.learning_rate)
-
- # prepare the dataloader
- # number of DataLoader worker processes: 0 means everything runs in the main process
- n_workers = min(8, os.cpu_count() - 1) # cpu_count - 1, capped at 8
- train_dataloader = torch.utils.data.DataLoader(
- train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers)
-
- # prepare the lr scheduler
- lr_scheduler = diffusers.optimization.get_scheduler(
- args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps)
-
- # the accelerator apparently takes care of wrapping all of these
- unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
- unet, text_encoder, optimizer, train_dataloader, lr_scheduler)
-
- if not cache_latents:
- vae.to(accelerator.device, dtype=weight_dtype)
-
- # resume from a saved state if requested
- if args.resume is not None:
- print(f"resume training from state: {args.resume}")
- accelerator.load_state(args.resume)
-
- # compute the number of epochs
- num_train_epochs = math.ceil(args.max_train_steps / len(train_dataloader))
-
- # run training
- total_batch_size = args.train_batch_size # * accelerator.num_processes
- print("running training / 学習開始")
- print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset.num_train_images}")
- print(f" num reg images / 正則化画像の数: {train_dataset.num_reg_images}")
- print(f" num examples / サンプル数: {train_dataset.num_train_images * (2 if train_dataset.enable_reg_images else 1)}")
- print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
- print(f" num epochs / epoch数: {num_train_epochs}")
- print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
- print(f" total train batch size (with parallel & distributed) / 総バッチサイズ(並列学習含む): {total_batch_size}")
- print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
-
- progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process, desc="steps")
- global_step = 0
-
- noise_scheduler = DDPMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)
-
- if accelerator.is_main_process:
- accelerator.init_trackers("dreambooth")
-
- # the code below is mostly copy-pasted from train_dreambooth.py
- for epoch in range(num_train_epochs):
- print(f"epoch {epoch+1}/{num_train_epochs}")
- unet.train()
- text_encoder.train()
-
- loss_total = 0
- for step, batch in enumerate(train_dataloader):
- with accelerator.accumulate(unet):
- with torch.no_grad():
- # convert images to latents
- if cache_latents:
- latents = batch["latents"].to(accelerator.device)
- else:
- latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
- latents = latents * 0.18215
-
- # Sample noise that we'll add to the latents
- noise = torch.randn_like(latents, device=latents.device)
- b_size = latents.shape[0]
-
- # Sample a random timestep for each image
- timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device)
- timesteps = timesteps.long()
-
- # Add noise to the latents according to the noise magnitude at each timestep
- # (this is the forward diffusion process)
- noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
-
- # Get the text embedding for conditioning
- if args.clip_skip is None:
- encoder_hidden_states = text_encoder(batch["input_ids"])[0]
- else:
- enc_out = text_encoder(batch["input_ids"], output_hidden_states=True, return_dict=True)
- encoder_hidden_states = enc_out['hidden_states'][-args.clip_skip]
- encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states)
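- # Note: hidden_states[-n] is the output of the n-th transformer layer counted from the end,
- # taken before the final LayerNorm, so clip_skip=2 uses the second-to-last layer; clip_skip=1
- # should reproduce the default branch above once final_layer_norm is re-applied.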
-
- # Predict the noise residual
- noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
-
- if args.v_parameterization:
- # v-parameterization training
- # what we would like to do:
- # target = noise_scheduler.get_v(latents, noise, timesteps)
-
- # code from StabilityAI's ddpm.py:
- # elif self.parameterization == "v":
- # target = self.get_v(x_start, noise, t)
- # ...
- # def get_v(self, x, noise, t):
- # return (
- # extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise -
- # extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x
- # )
-
- # code from scheduling_ddim.py:
- # elif self.config.prediction_type == "v_prediction":
- # pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
- # # predict V
- # model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
-
- # this should be equivalent:
- alpha_prod_t = noise_scheduler.alphas_cumprod[timesteps]
- beta_prod_t = 1 - alpha_prod_t
- alpha_prod_t = torch.reshape(alpha_prod_t, (len(alpha_prod_t), 1, 1, 1)) # reshape because broadcasting apparently does not happen here
- beta_prod_t = torch.reshape(beta_prod_t, (len(beta_prod_t), 1, 1, 1))
- target = (alpha_prod_t ** 0.5) * noise - (beta_prod_t ** 0.5) * latents
- else:
- target = noise
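- # Note: the v-parameterization target above is the closed form
- #   v_t = sqrt(alpha_bar_t) * noise - sqrt(1 - alpha_bar_t) * latents
- # In diffusers releases that provide it, the same value should be obtainable via
- # noise_scheduler.get_velocity(latents, noise, timesteps); it is not used here, presumably
- # because the Diffusers version this script targets predates that API.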
-
- loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
- loss = loss.mean([1, 2, 3])
-
- loss_weights = batch["loss_weights"] # per-sample weights
- loss = loss * loss_weights
-
- loss = loss.mean() # already averaged, so no need to divide by batch_size
-
- accelerator.backward(loss)
- if accelerator.sync_gradients:
- params_to_clip = (itertools.chain(unet.parameters(), text_encoder.parameters()))
- accelerator.clip_grad_norm_(params_to_clip, 1.0) # args.max_grad_norm)
-
- optimizer.step()
- lr_scheduler.step()
- optimizer.zero_grad(set_to_none=True)
-
- # Checks if the accelerator has performed an optimization step behind the scenes
- if accelerator.sync_gradients:
- progress_bar.update(1)
- global_step += 1
-
- current_loss = loss.detach().item()
- if args.logging_dir is not None:
- logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
- accelerator.log(logs, step=global_step)
-
- loss_total += current_loss
- avr_loss = loss_total / (step+1)
- logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
- progress_bar.set_postfix(**logs)
-
- if global_step >= args.max_train_steps:
- break
-
- if args.logging_dir is not None:
- logs = {"epoch_loss": loss_total / len(train_dataloader)}
- accelerator.log(logs, step=epoch+1)
-
- accelerator.wait_for_everyone()
-
- if args.save_every_n_epochs is not None:
- if (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs:
- print("saving checkpoint.")
- if use_stable_diffusion_format:
- os.makedirs(args.output_dir, exist_ok=True)
- ckpt_file = os.path.join(args.output_dir, EPOCH_CHECKPOINT_NAME.format(epoch + 1))
- save_stable_diffusion_checkpoint(args.v2, ckpt_file, accelerator.unwrap_model(text_encoder), accelerator.unwrap_model(unet),
- args.pretrained_model_name_or_path, epoch + 1, global_step, save_dtype)
- else:
- out_dir = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(epoch + 1))
- os.makedirs(out_dir, exist_ok=True)
- save_diffusers_checkpoint(args.v2, out_dir, accelerator.unwrap_model(text_encoder),
- accelerator.unwrap_model(unet), args.pretrained_model_name_or_path, save_dtype)
-
- if args.save_state:
- print("saving state.")
- accelerator.save_state(os.path.join(args.output_dir, EPOCH_STATE_NAME.format(epoch + 1)))
-
- is_main_process = accelerator.is_main_process
- if is_main_process:
- unet = accelerator.unwrap_model(unet)
- text_encoder = accelerator.unwrap_model(text_encoder)
-
- accelerator.end_training()
-
- if args.save_state:
- print("saving last state.")
- accelerator.save_state(os.path.join(args.output_dir, LAST_STATE_NAME))
-
- del accelerator # free it here because memory is needed below
-
- if is_main_process:
- os.makedirs(args.output_dir, exist_ok=True)
- if use_stable_diffusion_format:
- ckpt_file = os.path.join(args.output_dir, LAST_CHECKPOINT_NAME)
- print(f"save trained model as StableDiffusion checkpoint to {ckpt_file}")
- save_stable_diffusion_checkpoint(args.v2, ckpt_file, text_encoder, unet,
- args.pretrained_model_name_or_path, epoch, global_step, save_dtype)
- else:
- # Create the pipeline using the trained modules and save it.
- print(f"save trained model as Diffusers to {args.output_dir}")
- out_dir = os.path.join(args.output_dir, LAST_DIFFUSERS_DIR_NAME)
- os.makedirs(out_dir, exist_ok=True)
- save_diffusers_checkpoint(args.v2, out_dir, text_encoder, unet, args.pretrained_model_name_or_path, save_dtype)
- print("model saved.")
-
-
-if __name__ == '__main__':
- # torch.cuda.set_per_process_memory_fraction(0.48)
- parser = argparse.ArgumentParser()
- parser.add_argument("--v2", action='store_true',
- help='load Stable Diffusion v2.0 model / Stable Diffusion 2.0のモデルを読み込む')
- parser.add_argument("--v_parameterization", action='store_true',
- help='enable v-parameterization training / v-parameterization学習を有効にする')
- parser.add_argument("--pretrained_model_name_or_path", type=str, default=None,
- help="pretrained model to train, directory to Diffusers model or StableDiffusion checkpoint / 学習元モデル、Diffusers形式モデルのディレクトリまたはStableDiffusionのckptファイル")
- parser.add_argument("--fine_tuning", action="store_true",
- help="fine tune the model instead of DreamBooth / DreamBoothではなくfine tuningする")
- parser.add_argument("--shuffle_caption", action="store_true",
- help="shuffle comma-separated caption / コンマで区切られたcaptionの各要素をshuffleする")
- parser.add_argument("--caption_extention", type=str, default=None,
- help="extension of caption files (backward compatiblity) / 読み込むcaptionファイルの拡張子(スペルミスを残してあります)")
- parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption files / 読み込むcaptionファイルの拡張子")
- parser.add_argument("--train_data_dir", type=str, default=None, help="directory for train images / 学習画像データのディレクトリ")
- parser.add_argument("--reg_data_dir", type=str, default=None, help="directory for regularization images / 正則化画像データのディレクトリ")
- parser.add_argument("--dataset_repeats", type=int, default=None,
- help="repeat dataset in fine tuning / fine tuning時にデータセットを繰り返す回数")
- parser.add_argument("--output_dir", type=str, default=None,
- help="directory to output trained model (default format is same to input) / 学習後のモデル出力先ディレクトリ(デフォルトの保存形式は読み込んだ形式と同じ)")
- # parser.add_argument("--save_as_sd", action='store_true',
- # help="save the model as StableDiffusion checkpoint / 学習後のモデルをStableDiffusionのcheckpointとして保存する")
- parser.add_argument("--save_every_n_epochs", type=int, default=None,
- help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存します")
- parser.add_argument("--save_state", action="store_true",
- help="save training state additionally (including optimizer states etc.) / optimizerなど学習状態も含めたstateを追加で保存する")
- parser.add_argument("--resume", type=str, default=None, help="saved state to resume training / 学習再開するモデルのstate")
- parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="loss weight for regularization images / 正則化画像のlossの重み")
- parser.add_argument("--no_token_padding", action="store_true",
- help="disable token padding (same as Diffuser's DreamBooth) / トークンのpaddingを無効にする(Diffusers版DreamBoothと同じ動作)")
- parser.add_argument("--color_aug", action="store_true", help="enable weak color augmentation / 学習時に色合いのaugmentationを有効にする")
- parser.add_argument("--flip_aug", action="store_true", help="enable horizontal flip augmentation / 学習時に左右反転のaugmentationを有効にする")
- parser.add_argument("--face_crop_aug_range", type=str, default=None,
- help="enable face-centered crop augmentation and its range (e.g. 2.0,4.0) / 学習時に顔を中心とした切り出しaugmentationを有効にするときは倍率を指定する(例:2.0,4.0)")
- parser.add_argument("--random_crop", action="store_true",
- help="enable random crop (for style training in face-centered crop augmentation) / ランダムな切り出しを有効にする(顔を中心としたaugmentationを行うときに画風の学習用に指定する)")
- parser.add_argument("--debug_dataset", action="store_true",
- help="show images for debugging (do not train) / デバッグ用に学習データを画面表示する(学習は行わない)")
- parser.add_argument("--resolution", type=str, default=None,
- help="resolution in training ('size' or 'width,height') / 学習時の画像解像度('サイズ'指定、または'幅,高さ'指定)")
- parser.add_argument("--train_batch_size", type=int, default=1,
- help="batch size for training (1 means one train or reg data, not train/reg pair) / 学習時のバッチサイズ(1でtrain/regをそれぞれ1件ずつ学習)")
- parser.add_argument("--use_8bit_adam", action="store_true",
- help="use 8bit Adam optimizer (requires bitsandbytes) / 8bit Adamオプティマイザを使う(bitsandbytesのインストールが必要)")
- parser.add_argument("--mem_eff_attn", action="store_true",
- help="use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う")
- parser.add_argument("--xformers", action="store_true",
- help="use xformers for CrossAttention / CrossAttentionにxformersを使う")
- parser.add_argument("--cache_latents", action="store_true",
- help="cache latents to reduce memory (augmentations must be disabled) / メモリ削減のためにlatentをcacheする(augmentationは使用不可)")
- parser.add_argument("--enable_bucket", action="store_true",
- help="enable buckets for multi aspect ratio training / 複数解像度学習のためのbucketを有効にする")
- parser.add_argument("--min_bucket_reso", type=int, default=256, help="minimum resolution for buckets / bucketの最小解像度")
- parser.add_argument("--max_bucket_reso", type=int, default=1024, help="maximum resolution for buckets / bucketの最小解像度")
- parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率")
- parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数")
- parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed")
- parser.add_argument("--gradient_checkpointing", action="store_true",
- help="enable gradient checkpointing / grandient checkpointingを有効にする")
- parser.add_argument("--mixed_precision", type=str, default="no",
- choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度")
- parser.add_argument("--save_precision", type=str, default=None,
- choices=[None, "float", "fp16", "bf16"], help="precision in saving (available in StableDiffusion checkpoint) / 保存時に精度を変更して保存する(StableDiffusion形式での保存時のみ有効)")
- parser.add_argument("--clip_skip", type=int, default=None,
- help="use output of nth layer from back of text encoder (n>=1) / text encoderの後ろからn番目の層の出力を用いる(nは1以上)")
- parser.add_argument("--logging_dir", type=str, default=None,
- help="enable logging and output TensorBoard log to this directory / ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する")
- parser.add_argument("--lr_scheduler", type=str, default="constant",
- help="scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts, polynomial, constant (default), constant_with_warmup")
- parser.add_argument("--lr_warmup_steps", type=int, default=0,
- help="Number of steps for the warmup in the lr scheduler (default is 0) / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)")
-
- args = parser.parse_args()
- train(args)
diff --git a/train_db_fixed/train_db_fixed_v12.py b/train_db_fixed/train_db_fixed_v12.py
deleted file mode 100644
index ef6fc63d..00000000
--- a/train_db_fixed/train_db_fixed_v12.py
+++ /dev/null
@@ -1,2116 +0,0 @@
-# This script is licensed under the Apache License 2.0, the same as train_dreambooth.py
-# (c) 2022 Kohya S. @kohya_ss
-
-# v7: another text encoder ckpt format, average loss, save epochs/global steps, show num of train/reg images,
-# enable reg images in fine-tuning, add dataset_repeats option
-# v8: supports Diffusers 0.7.2
-# v9: add bucketing option
-# v10: add min_bucket_reso/max_bucket_reso options, read captions for train/reg images in DreamBooth
-# v11: Diffusers 0.9.0 is required. support for Stable Diffusion 2.0/v-parameterization
-# add lr scheduler options, change how folder/file captions are handled, support loading a Diffusers model from Hugging Face
-# support save_every_n_epochs/save_state for the Diffusers model
-# fix the issue that prior_loss_weight was applied to train images
-# v12: stop text encoder training, tqdm smoothing
-
-import time
-from torch.autograd.function import Function
-import argparse
-import glob
-import itertools
-import math
-import os
-import random
-
-from tqdm import tqdm
-import torch
-from torchvision import transforms
-from accelerate import Accelerator
-from accelerate.utils import set_seed
-from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig
-import diffusers
-from diffusers import AutoencoderKL, DDPMScheduler, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel
-import albumentations as albu
-import numpy as np
-from PIL import Image
-import cv2
-from einops import rearrange
-from torch import einsum
-
-# Tokenizer: use the pre-provided tokenizer instead of loading it from the checkpoint
-TOKENIZER_PATH = "openai/clip-vit-large-patch14"
-V2_STABLE_DIFFUSION_PATH = "stabilityai/stable-diffusion-2" # only the tokenizer is taken from this repo
-
-# checkpoint file names
-LAST_CHECKPOINT_NAME = "last.ckpt"
-LAST_STATE_NAME = "last-state"
-LAST_DIFFUSERS_DIR_NAME = "last"
-EPOCH_CHECKPOINT_NAME = "epoch-{:06d}.ckpt"
-EPOCH_STATE_NAME = "epoch-{:06d}-state"
-EPOCH_DIFFUSERS_DIR_NAME = "epoch-{:06d}"
-
-
-# region dataset
-
-
-def make_bucket_resolutions(max_reso, min_size=256, max_size=1024, divisible=64):
- max_width, max_height = max_reso
- max_area = (max_width // divisible) * (max_height // divisible)
-
- resos = set()
-
- size = int(math.sqrt(max_area)) * divisible
- resos.add((size, size))
-
- size = min_size
- while size <= max_size:
- width = size
- height = min(max_size, (max_area // (width // divisible)) * divisible)
- resos.add((width, height))
- resos.add((height, width))
- size += divisible
-
- resos = list(resos)
- resos.sort()
-
- aspect_ratios = [w / h for w, h in resos]
- return resos, aspect_ratios
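-# Note (illustrative, using the default arguments): for max_reso=(512, 512) this returns the
-# square bucket (512, 512) together with aspect-ratio variants such as (384, 640), (448, 576)
-# and their transposes; every bucket side is a multiple of `divisible` and the pixel area stays
-# at or below 512 * 512.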
-
-
-class DreamBoothOrFineTuningDataset(torch.utils.data.Dataset):
- def __init__(self, batch_size, fine_tuning, train_img_path_captions, reg_img_path_captions, tokenizer, resolution, prior_loss_weight, flip_aug, color_aug, face_crop_aug_range, random_crop, shuffle_caption, disable_padding, debug_dataset) -> None:
- super().__init__()
-
- self.batch_size = batch_size
- self.fine_tuning = fine_tuning
- self.train_img_path_captions = train_img_path_captions
- self.reg_img_path_captions = reg_img_path_captions
- self.tokenizer = tokenizer
- self.width, self.height = resolution
- self.size = min(self.width, self.height) # the shorter side
- self.prior_loss_weight = prior_loss_weight
- self.face_crop_aug_range = face_crop_aug_range
- self.random_crop = random_crop
- self.debug_dataset = debug_dataset
- self.shuffle_caption = shuffle_caption
- self.disable_padding = disable_padding
- self.latents_cache = None
- self.enable_bucket = False
-
- # augmentation
- flip_p = 0.5 if flip_aug else 0.0
- if color_aug:
- # fairly weak color augmentation: brightness/contrast would change the image's min/max pixel values, which is assumed to be undesirable, so only gamma/hue/saturation are adjusted
- self.aug = albu.Compose([
- albu.OneOf([
- # albu.RandomBrightnessContrast(0.05, 0.05, p=.2),
- albu.HueSaturationValue(5, 8, 0, p=.2),
- # albu.RGBShift(5, 5, 5, p=.1),
- albu.RandomGamma((95, 105), p=.5),
- ], p=.33),
- albu.HorizontalFlip(p=flip_p)
- ], p=1.)
- elif flip_aug:
- self.aug = albu.Compose([
- albu.HorizontalFlip(p=flip_p)
- ], p=1.)
- else:
- self.aug = None
-
- self.num_train_images = len(self.train_img_path_captions)
- self.num_reg_images = len(self.reg_img_path_captions)
-
- self.enable_reg_images = self.num_reg_images > 0
-
- if self.enable_reg_images and self.num_train_images < self.num_reg_images:
- print("some of reg images are not used / 正則化画像の数が多いので、一部使用されない正則化画像があります")
-
- self.image_transforms = transforms.Compose(
- [
- transforms.ToTensor(),
- transforms.Normalize([0.5], [0.5]),
- ]
- )
-
- # must be called even when bucketing is disabled (a single bucket is created)
- def make_buckets_with_caching(self, enable_bucket, vae, min_size, max_size):
- self.enable_bucket = enable_bucket
-
- cache_latents = vae is not None
- if cache_latents:
- if enable_bucket:
- print("cache latents with bucketing")
- else:
- print("cache latents")
- else:
- if enable_bucket:
- print("make buckets")
- else:
- print("prepare dataset")
-
- # prepare the buckets
- if enable_bucket:
- bucket_resos, bucket_aspect_ratios = make_bucket_resolutions((self.width, self.height), min_size, max_size)
- else:
- # only one bucket; all images share the same resolution
- bucket_resos = [(self.width, self.height)]
- bucket_aspect_ratios = [self.width / self.height]
- bucket_aspect_ratios = np.array(bucket_aspect_ratios)
-
- # obtain the image resolutions and latents in advance
- img_ar_errors = []
- self.size_lat_cache = {}
- for image_path, _ in tqdm(self.train_img_path_captions + self.reg_img_path_captions):
- if image_path in self.size_lat_cache:
- continue
-
- image = self.load_image(image_path)[0]
- image_height, image_width = image.shape[0:2]
-
- if not enable_bucket:
- # assert image_width == self.width and image_height == self.height, \
- # f"all images must have specific resolution when bucketing is disabled / bucketを使わない場合、すべての画像のサイズを統一してください: {image_path}"
- reso = (self.width, self.height)
- else:
- # choose a bucket
- aspect_ratio = image_width / image_height
- ar_errors = bucket_aspect_ratios - aspect_ratio
- bucket_id = np.abs(ar_errors).argmin()
- reso = bucket_resos[bucket_id]
- ar_error = ar_errors[bucket_id]
- img_ar_errors.append(ar_error)
-
- if cache_latents:
- image = self.resize_and_trim(image, reso)
-
- # get the latent
- if cache_latents:
- img_tensor = self.image_transforms(image)
- img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype)
- latents = vae.encode(img_tensor).latent_dist.sample().squeeze(0).to("cpu")
- else:
- latents = None
-
- self.size_lat_cache[image_path] = (reso, latents)
-
- # split the images into buckets
- self.buckets = [[] for _ in range(len(bucket_resos))]
- reso_to_index = {}
- for i, reso in enumerate(bucket_resos):
- reso_to_index[reso] = i
-
- def split_to_buckets(is_reg, img_path_captions):
- for image_path, caption in img_path_captions:
- reso, _ = self.size_lat_cache[image_path]
- bucket_index = reso_to_index[reso]
- self.buckets[bucket_index].append((is_reg, image_path, caption))
-
- split_to_buckets(False, self.train_img_path_captions)
-
- if self.enable_reg_images:
- l = []
- while len(l) < len(self.train_img_path_captions):
- l += self.reg_img_path_captions
- l = l[:len(self.train_img_path_captions)]
- split_to_buckets(True, l)
-
- if enable_bucket:
- print("number of images with repeats / 繰り返し回数込みの各bucketの画像枚数")
- for i, (reso, imgs) in enumerate(zip(bucket_resos, self.buckets)):
- print(f"bucket {i}: resolution {reso}, count: {len(imgs)}")
- img_ar_errors = np.array(img_ar_errors)
- print(f"mean ar error: {np.mean(np.abs(img_ar_errors))}")
-
- # build the lookup indices
- self.buckets_indices = []
- for bucket_index, bucket in enumerate(self.buckets):
- batch_count = int(math.ceil(len(bucket) / self.batch_size))
- for batch_index in range(batch_count):
- self.buckets_indices.append((bucket_index, batch_index))
-
- self.shuffle_buckets()
- self._length = len(self.buckets_indices)
-
- # decide the resize target size; any excess is handled by trimming
- def resize_and_trim(self, image, reso):
- image_height, image_width = image.shape[0:2]
- ar_img = image_width / image_height
- ar_reso = reso[0] / reso[1]
- if ar_img > ar_reso: # wider than the target -> match the height
- scale = reso[1] / image_height
- else:
- scale = reso[0] / image_width
- resized_size = (int(image_width * scale + .5), int(image_height * scale + .5))
-
- image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA) # resize with cv2 so INTER_AREA can be used
- if resized_size[0] > reso[0]:
- trim_size = resized_size[0] - reso[0]
- image = image[:, trim_size//2:trim_size//2 + reso[0]]
- elif resized_size[1] > reso[1]:
- trim_size = resized_size[1] - reso[1]
- image = image[trim_size//2:trim_size//2 + reso[1]]
- assert image.shape[0] == reso[1] and image.shape[1] == reso[0], \
- f"internal error, illegal trimmed size: {image.shape}, {reso}"
- return image
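- # Note (worked example with made-up sizes): for a 1000x500 (w x h) image and reso=(512, 384),
- # ar_img = 2.0 > ar_reso ~ 1.33, so the image is scaled by 384/500 to 768x384 and then trimmed
- # symmetrically in width to 512x384: the shorter side is matched exactly, the longer side is
- # center-cropped.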
-
- def shuffle_buckets(self):
- random.shuffle(self.buckets_indices)
- for bucket in self.buckets:
- random.shuffle(bucket)
-
- def load_image(self, image_path):
- image = Image.open(image_path)
- if not image.mode == "RGB":
- image = image.convert("RGB")
- img = np.array(image, np.uint8)
-
- face_cx = face_cy = face_w = face_h = 0
- if self.face_crop_aug_range is not None:
- tokens = os.path.splitext(os.path.basename(image_path))[0].split('_')
- if len(tokens) >= 5:
- face_cx = int(tokens[-4])
- face_cy = int(tokens[-3])
- face_w = int(tokens[-2])
- face_h = int(tokens[-1])
-
- return img, face_cx, face_cy, face_w, face_h
-
- # crop around the face sensibly
- def crop_target(self, image, face_cx, face_cy, face_w, face_h):
- height, width = image.shape[0:2]
- if height == self.height and width == self.width:
- return image
-
- # the image is larger than the target size, so resize it
- face_size = max(face_w, face_h)
- min_scale = max(self.height / height, self.width / width) # scale at which the image exactly matches the model input size (minimum scale)
- min_scale = min(1.0, max(min_scale, self.size / (face_size * self.face_crop_aug_range[1]))) # specified minimum face size
- max_scale = min(1.0, max(min_scale, self.size / (face_size * self.face_crop_aug_range[0]))) # specified maximum face size
- if min_scale >= max_scale: # the specified range has min == max
- scale = min_scale
- else:
- scale = random.uniform(min_scale, max_scale)
-
- nh = int(height * scale + .5)
- nw = int(width * scale + .5)
- assert nh >= self.height and nw >= self.width, f"internal error. small scale {scale}, {width}*{height}"
- image = cv2.resize(image, (nw, nh), interpolation=cv2.INTER_AREA)
- face_cx = int(face_cx * scale + .5)
- face_cy = int(face_cy * scale + .5)
- height, width = nh, nw
-
- # crop to e.g. 448*640 with the face at the center
- for axis, (target_size, length, face_p) in enumerate(zip((self.height, self.width), (height, width), (face_cy, face_cx))):
- p1 = face_p - target_size // 2 # crop position that puts the face at the center
-
- if self.random_crop:
- # shift the crop while keeping the face near the center most of the time, so some background is included
- range = max(length - face_p, face_p) # the longer of the distances from the image edges to the face center
- p1 = p1 + (random.randint(0, range) + random.randint(0, range)) - range # a roughly triangular random offset in [-range, +range]
- else:
- # only when a range is specified, add a small random shift (fairly ad hoc)
- if self.face_crop_aug_range[0] != self.face_crop_aug_range[1]:
- if face_size > self.size // 10 and face_size >= 40:
- p1 = p1 + random.randint(-face_size // 20, +face_size // 20)
-
- p1 = max(0, min(p1, length - target_size))
-
- if axis == 0:
- image = image[p1:p1 + target_size, :]
- else:
- image = image[:, p1:p1 + target_size]
-
- return image
-
- def __len__(self):
- return self._length
-
- def __getitem__(self, index):
- if index == 0:
- self.shuffle_buckets()
-
- bucket = self.buckets[self.buckets_indices[index][0]]
- image_index = self.buckets_indices[index][1] * self.batch_size
-
- latents_list = []
- images = []
- captions = []
- loss_weights = []
-
- for is_reg, image_path, caption in bucket[image_index:image_index + self.batch_size]:
- loss_weights.append(self.prior_loss_weight if is_reg else 1.0)
-
- # process the image / latents
- reso, latents = self.size_lat_cache[image_path]
-
- if latents is None:
- # load the image and crop it if necessary
- img, face_cx, face_cy, face_w, face_h = self.load_image(image_path)
- im_h, im_w = img.shape[0:2]
-
- if self.enable_bucket:
- img = self.resize_and_trim(img, reso)
- else:
- if face_cx > 0: # face position info is available
- img = self.crop_target(img, face_cx, face_cy, face_w, face_h)
- elif im_h > self.height or im_w > self.width:
- assert self.random_crop, f"image too large, and face_crop_aug_range and random_crop are disabled / 画像サイズが大きいのでface_crop_aug_rangeかrandom_cropを有効にしてください"
- if im_h > self.height:
- p = random.randint(0, im_h - self.height)
- img = img[p:p + self.height]
- if im_w > self.width:
- p = random.randint(0, im_w - self.width)
- img = img[:, p:p + self.width]
-
- im_h, im_w = img.shape[0:2]
- assert im_h == self.height and im_w == self.width, f"image size is small / 画像サイズが小さいようです: {image_path}"
-
- # augmentation
- if self.aug is not None:
- img = self.aug(image=img)['image']
-
- image = self.image_transforms(img) # becomes a torch.Tensor in the range -1.0 to 1.0
- else:
- image = None
-
- images.append(image)
- latents_list.append(latents)
-
- # process the caption
- if self.shuffle_caption: # shuffle the comma-separated caption elements
- tokens = caption.strip().split(",")
- random.shuffle(tokens)
- caption = ",".join(tokens).strip()
- captions.append(caption)
-
- # pad input_ids and convert them to a Tensor
- if self.disable_padding:
- # no fixed-length padding: padding==True only pads to the longest sequence in the batch (arguably a bug...?)
- input_ids = self.tokenizer(captions, padding=True, truncation=True, return_tensors="pt").input_ids
- else:
- # pad to max_length
- input_ids = self.tokenizer(captions, padding='max_length', truncation=True, return_tensors='pt').input_ids
-
- example = {}
- example['loss_weights'] = torch.FloatTensor(loss_weights)
- example['input_ids'] = input_ids
- if images[0] is not None:
- images = torch.stack(images)
- images = images.to(memory_format=torch.contiguous_format).float()
- else:
- images = None
- example['images'] = images
- example['latents'] = torch.stack(latents_list) if latents_list[0] is not None else None
- if self.debug_dataset:
- example['image_paths'] = [image_path for _, image_path, _ in bucket[image_index:image_index + self.batch_size]]
- example['captions'] = captions
- return example
-# endregion
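-# Note (usage sketch; the variable names and argument values are assumptions, not taken from
-# this script): the dataset batches internally, so a typical wiring looks like
-#   dataset = DreamBoothOrFineTuningDataset(batch_size, fine_tuning, train_caps, reg_caps,
-#                                           tokenizer, (512, 512), 1.0, False, False,
-#                                           None, False, False, False, False)
-#   dataset.make_buckets_with_caching(enable_bucket=True, vae=None, min_size=256, max_size=1024)
-#   loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False,
-#                                         collate_fn=lambda x: x[0])
-# which mirrors how the train() function in this script wires things up.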
-
-
-# region module replacement
-"""
-Module replacements for speedup
-"""
-
-# CrossAttention using FlashAttention
-# based on https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/memory_efficient_attention_pytorch/flash_attention.py
-# LICENSE MIT https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/LICENSE
-
-# constants
-
-EPSILON = 1e-6
-
-# helper functions
-
-
-def exists(val):
- return val is not None
-
-
-def default(val, d):
- return val if exists(val) else d
-
-# flash attention forwards and backwards
-
-# https://arxiv.org/abs/2205.14135
-
-
-class FlashAttentionFunction(Function):
- @ staticmethod
- @ torch.no_grad()
- def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
- """ Algorithm 2 in the paper """
-
- device = q.device
- dtype = q.dtype
- max_neg_value = -torch.finfo(q.dtype).max
- qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
-
- o = torch.zeros_like(q)
- all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device)
- all_row_maxes = torch.full((*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device)
-
- scale = (q.shape[-1] ** -0.5)
-
- if not exists(mask):
- mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size)
- else:
- mask = rearrange(mask, 'b n -> b 1 1 n')
- mask = mask.split(q_bucket_size, dim=-1)
-
- row_splits = zip(
- q.split(q_bucket_size, dim=-2),
- o.split(q_bucket_size, dim=-2),
- mask,
- all_row_sums.split(q_bucket_size, dim=-2),
- all_row_maxes.split(q_bucket_size, dim=-2),
- )
-
- for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits):
- q_start_index = ind * q_bucket_size - qk_len_diff
-
- col_splits = zip(
- k.split(k_bucket_size, dim=-2),
- v.split(k_bucket_size, dim=-2),
- )
-
- for k_ind, (kc, vc) in enumerate(col_splits):
- k_start_index = k_ind * k_bucket_size
-
- attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc) * scale
-
- if exists(row_mask):
- attn_weights.masked_fill_(~row_mask, max_neg_value)
-
- if causal and q_start_index < (k_start_index + k_bucket_size - 1):
- causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool,
- device=device).triu(q_start_index - k_start_index + 1)
- attn_weights.masked_fill_(causal_mask, max_neg_value)
-
- block_row_maxes = attn_weights.amax(dim=-1, keepdims=True)
- attn_weights -= block_row_maxes
- exp_weights = torch.exp(attn_weights)
-
- if exists(row_mask):
- exp_weights.masked_fill_(~row_mask, 0.)
-
- block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp(min=EPSILON)
-
- new_row_maxes = torch.maximum(block_row_maxes, row_maxes)
-
- exp_values = einsum('... i j, ... j d -> ... i d', exp_weights, vc)
-
- exp_row_max_diff = torch.exp(row_maxes - new_row_maxes)
- exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes)
-
- new_row_sums = exp_row_max_diff * row_sums + exp_block_row_max_diff * block_row_sums
-
- oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_((exp_block_row_max_diff / new_row_sums) * exp_values)
-
- row_maxes.copy_(new_row_maxes)
- row_sums.copy_(new_row_sums)
-
- ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size)
- ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes)
-
- return o
-
- @ staticmethod
- @ torch.no_grad()
- def backward(ctx, do):
- """ Algorithm 4 in the paper """
-
- causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args
- q, k, v, o, l, m = ctx.saved_tensors
-
- device = q.device
-
- max_neg_value = -torch.finfo(q.dtype).max
- qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
-
- dq = torch.zeros_like(q)
- dk = torch.zeros_like(k)
- dv = torch.zeros_like(v)
-
- row_splits = zip(
- q.split(q_bucket_size, dim=-2),
- o.split(q_bucket_size, dim=-2),
- do.split(q_bucket_size, dim=-2),
- mask,
- l.split(q_bucket_size, dim=-2),
- m.split(q_bucket_size, dim=-2),
- dq.split(q_bucket_size, dim=-2)
- )
-
- for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits):
- q_start_index = ind * q_bucket_size - qk_len_diff
-
- col_splits = zip(
- k.split(k_bucket_size, dim=-2),
- v.split(k_bucket_size, dim=-2),
- dk.split(k_bucket_size, dim=-2),
- dv.split(k_bucket_size, dim=-2),
- )
-
- for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits):
- k_start_index = k_ind * k_bucket_size
-
- attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc) * scale
-
- if causal and q_start_index < (k_start_index + k_bucket_size - 1):
- causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool,
- device=device).triu(q_start_index - k_start_index + 1)
- attn_weights.masked_fill_(causal_mask, max_neg_value)
-
- exp_attn_weights = torch.exp(attn_weights - mc)
-
- if exists(row_mask):
- exp_attn_weights.masked_fill_(~row_mask, 0.)
-
- p = exp_attn_weights / lc
-
- dv_chunk = einsum('... i j, ... i d -> ... j d', p, doc)
- dp = einsum('... i d, ... j d -> ... i j', doc, vc)
-
- D = (doc * oc).sum(dim=-1, keepdims=True)
- ds = p * scale * (dp - D)
-
- dq_chunk = einsum('... i j, ... j d -> ... i d', ds, kc)
- dk_chunk = einsum('... i j, ... i d -> ... j d', ds, qc)
-
- dqc.add_(dq_chunk)
- dkc.add_(dk_chunk)
- dvc.add_(dv_chunk)
-
- return dq, dk, dv, None, None, None, None
-
-
-def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers):
- if mem_eff_attn:
- replace_unet_cross_attn_to_memory_efficient()
- elif xformers:
- replace_unet_cross_attn_to_xformers()
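-# Note: both helpers below monkey-patch diffusers.models.attention.CrossAttention.forward at the
-# class level, so the replacement affects every CrossAttention instance, not only the given UNet;
-# the `unet` argument is effectively unused here.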
-
-
-def replace_unet_cross_attn_to_memory_efficient():
- print("Replace CrossAttention.forward to use FlashAttention")
- flash_func = FlashAttentionFunction
-
- def forward_flash_attn(self, x, context=None, mask=None):
- q_bucket_size = 512
- k_bucket_size = 1024
-
- h = self.heads
- q = self.to_q(x)
-
- context = context if context is not None else x
- context = context.to(x.dtype)
- k = self.to_k(context)
- v = self.to_v(context)
- del context, x
-
- q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v))
-
- out = flash_func.apply(q, k, v, mask, False, q_bucket_size, k_bucket_size)
-
- out = rearrange(out, 'b h n d -> b n (h d)')
-
- # diffusers 0.6.0
- if type(self.to_out) is torch.nn.Sequential:
- return self.to_out(out)
-
- # diffusers 0.7.0~
- out = self.to_out[0](out)
- out = self.to_out[1](out)
- return out
-
- diffusers.models.attention.CrossAttention.forward = forward_flash_attn
-
-
-def replace_unet_cross_attn_to_xformers():
- print("Replace CrossAttention.forward to use xformers")
- try:
- import xformers.ops
- except ImportError:
- raise ImportError("No xformers / xformersがインストールされていないようです")
-
- def forward_xformers(self, x, context=None, mask=None):
- h = self.heads
- q_in = self.to_q(x)
-
- context = default(context, x)
- context = context.to(x.dtype)
-
- k_in = self.to_k(context)
- v_in = self.to_v(context)
-
- q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in)) # new format
- # q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in)) # legacy format
- del q_in, k_in, v_in
- out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) # xformers picks the best kernel automatically
-
- out = rearrange(out, 'b n h d -> b n (h d)', h=h)
- # out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
-
- # diffusers 0.6.0
- if type(self.to_out) is torch.nn.Sequential:
- return self.to_out(out)
-
- # diffusers 0.7.0~
- out = self.to_out[0](out)
- out = self.to_out[1](out)
- return out
-
- diffusers.models.attention.CrossAttention.forward = forward_xformers
-# endregion
-
-
-# region checkpoint conversion, loading and saving ###############################
-
-# model parameters of the Diffusers version of Stable Diffusion
-NUM_TRAIN_TIMESTEPS = 1000
-BETA_START = 0.00085
-BETA_END = 0.0120
-
-UNET_PARAMS_MODEL_CHANNELS = 320
-UNET_PARAMS_CHANNEL_MULT = [1, 2, 4, 4]
-UNET_PARAMS_ATTENTION_RESOLUTIONS = [4, 2, 1]
-UNET_PARAMS_IMAGE_SIZE = 32 # unused
-UNET_PARAMS_IN_CHANNELS = 4
-UNET_PARAMS_OUT_CHANNELS = 4
-UNET_PARAMS_NUM_RES_BLOCKS = 2
-UNET_PARAMS_CONTEXT_DIM = 768
-UNET_PARAMS_NUM_HEADS = 8
-
-VAE_PARAMS_Z_CHANNELS = 4
-VAE_PARAMS_RESOLUTION = 256
-VAE_PARAMS_IN_CHANNELS = 3
-VAE_PARAMS_OUT_CH = 3
-VAE_PARAMS_CH = 128
-VAE_PARAMS_CH_MULT = [1, 2, 4, 4]
-VAE_PARAMS_NUM_RES_BLOCKS = 2
-
-# V2
-V2_UNET_PARAMS_ATTENTION_HEAD_DIM = [5, 10, 20, 20]
-V2_UNET_PARAMS_CONTEXT_DIM = 1024
-
-
-# region StableDiffusion -> Diffusers conversion code
-# copied from convert_original_stable_diffusion_to_diffusers (ASL 2.0)
-
-
-def shave_segments(path, n_shave_prefix_segments=1):
- """
- Removes segments. Positive values shave the first segments, negative shave the last segments.
- """
- if n_shave_prefix_segments >= 0:
- return ".".join(path.split(".")[n_shave_prefix_segments:])
- else:
- return ".".join(path.split(".")[:n_shave_prefix_segments])
-
-
-def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
- """
- Updates paths inside resnets to the new naming scheme (local renaming)
- """
- mapping = []
- for old_item in old_list:
- new_item = old_item.replace("in_layers.0", "norm1")
- new_item = new_item.replace("in_layers.2", "conv1")
-
- new_item = new_item.replace("out_layers.0", "norm2")
- new_item = new_item.replace("out_layers.3", "conv2")
-
- new_item = new_item.replace("emb_layers.1", "time_emb_proj")
- new_item = new_item.replace("skip_connection", "conv_shortcut")
-
- new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
-
- mapping.append({"old": old_item, "new": new_item})
-
- return mapping
-
-
-def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
- """
- Updates paths inside resnets to the new naming scheme (local renaming)
- """
- mapping = []
- for old_item in old_list:
- new_item = old_item
-
- new_item = new_item.replace("nin_shortcut", "conv_shortcut")
- new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
-
- mapping.append({"old": old_item, "new": new_item})
-
- return mapping
-
-
-def renew_attention_paths(old_list, n_shave_prefix_segments=0):
- """
- Updates paths inside attentions to the new naming scheme (local renaming)
- """
- mapping = []
- for old_item in old_list:
- new_item = old_item
-
- # new_item = new_item.replace('norm.weight', 'group_norm.weight')
- # new_item = new_item.replace('norm.bias', 'group_norm.bias')
-
- # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
- # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
-
- # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
-
- mapping.append({"old": old_item, "new": new_item})
-
- return mapping
-
-
-def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
- """
- Updates paths inside attentions to the new naming scheme (local renaming)
- """
- mapping = []
- for old_item in old_list:
- new_item = old_item
-
- new_item = new_item.replace("norm.weight", "group_norm.weight")
- new_item = new_item.replace("norm.bias", "group_norm.bias")
-
- new_item = new_item.replace("q.weight", "query.weight")
- new_item = new_item.replace("q.bias", "query.bias")
-
- new_item = new_item.replace("k.weight", "key.weight")
- new_item = new_item.replace("k.bias", "key.bias")
-
- new_item = new_item.replace("v.weight", "value.weight")
- new_item = new_item.replace("v.bias", "value.bias")
-
- new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
- new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
-
- new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
-
- mapping.append({"old": old_item, "new": new_item})
-
- return mapping
-
-
-def assign_to_checkpoint(
- paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
-):
- """
- This does the final conversion step: take locally converted weights and apply a global renaming
- to them. It splits attention layers, and takes into account additional replacements
- that may arise.
-
- Assigns the weights to the new checkpoint.
- """
- assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
-
- # Splits the attention layers into three variables.
- if attention_paths_to_split is not None:
- for path, path_map in attention_paths_to_split.items():
- old_tensor = old_checkpoint[path]
- channels = old_tensor.shape[0] // 3
-
- target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
-
- num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
-
- old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
- query, key, value = old_tensor.split(channels // num_heads, dim=1)
-
- checkpoint[path_map["query"]] = query.reshape(target_shape)
- checkpoint[path_map["key"]] = key.reshape(target_shape)
- checkpoint[path_map["value"]] = value.reshape(target_shape)
-
- for path in paths:
- new_path = path["new"]
-
- # These have already been assigned
- if attention_paths_to_split is not None and new_path in attention_paths_to_split:
- continue
-
- # Global renaming happens here
- new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
- new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
- new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")
-
- if additional_replacements is not None:
- for replacement in additional_replacements:
- new_path = new_path.replace(replacement["old"], replacement["new"])
-
- # proj_attn.weight has to be converted from conv 1D to linear
- if "proj_attn.weight" in new_path:
- checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
- else:
- checkpoint[new_path] = old_checkpoint[path["old"]]
-
-
-def conv_attn_to_linear(checkpoint):
- keys = list(checkpoint.keys())
- attn_keys = ["query.weight", "key.weight", "value.weight"]
- for key in keys:
- if ".".join(key.split(".")[-2:]) in attn_keys:
- if checkpoint[key].ndim > 2:
- checkpoint[key] = checkpoint[key][:, :, 0, 0]
- elif "proj_attn.weight" in key:
- if checkpoint[key].ndim > 2:
- checkpoint[key] = checkpoint[key][:, :, 0]
-
-
-def linear_transformer_to_conv(checkpoint):
- keys = list(checkpoint.keys())
- tf_keys = ["proj_in.weight", "proj_out.weight"]
- for key in keys:
- if ".".join(key.split(".")[-2:]) in tf_keys:
- if checkpoint[key].ndim == 2:
- checkpoint[key] = checkpoint[key].unsqueeze(2).unsqueeze(2)
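-# Note: for SD v2 checkpoints this turns a linear proj_in/proj_out weight of shape [out, in]
-# into a 1x1 conv weight of shape [out, in, 1, 1], matching the convolutional projection layout
-# that the UNet configuration built by this script appears to expect.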
-
-
-def convert_ldm_unet_checkpoint(v2, checkpoint, config):
- """
- Takes a state dict and a config, and returns a converted checkpoint.
- """
-
- # extract state_dict for UNet
- unet_state_dict = {}
- unet_key = "model.diffusion_model."
- keys = list(checkpoint.keys())
- for key in keys:
- if key.startswith(unet_key):
- unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
-
- new_checkpoint = {}
-
- new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
- new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
- new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
- new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
-
- new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
- new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
-
- new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
- new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
- new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
- new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
-
- # Retrieves the keys for the input blocks only
- num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
- input_blocks = {
- layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
- for layer_id in range(num_input_blocks)
- }
-
- # Retrieves the keys for the middle blocks only
- num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
- middle_blocks = {
- layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
- for layer_id in range(num_middle_blocks)
- }
-
- # Retrieves the keys for the output blocks only
- num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
- output_blocks = {
- layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
- for layer_id in range(num_output_blocks)
- }
-
- for i in range(1, num_input_blocks):
- block_id = (i - 1) // (config["layers_per_block"] + 1)
- layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
-
- resnets = [
- key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
- ]
- attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
-
- if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
- new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
- f"input_blocks.{i}.0.op.weight"
- )
- new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
- f"input_blocks.{i}.0.op.bias"
- )
-
- paths = renew_resnet_paths(resnets)
- meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
- assign_to_checkpoint(
- paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
- )
-
- if len(attentions):
- paths = renew_attention_paths(attentions)
- meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
- assign_to_checkpoint(
- paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
- )
-
- resnet_0 = middle_blocks[0]
- attentions = middle_blocks[1]
- resnet_1 = middle_blocks[2]
-
- resnet_0_paths = renew_resnet_paths(resnet_0)
- assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)
-
- resnet_1_paths = renew_resnet_paths(resnet_1)
- assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)
-
- attentions_paths = renew_attention_paths(attentions)
- meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
- assign_to_checkpoint(
- attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
- )
-
- for i in range(num_output_blocks):
- block_id = i // (config["layers_per_block"] + 1)
- layer_in_block_id = i % (config["layers_per_block"] + 1)
- output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
- output_block_list = {}
-
- for layer in output_block_layers:
- layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
- if layer_id in output_block_list:
- output_block_list[layer_id].append(layer_name)
- else:
- output_block_list[layer_id] = [layer_name]
-
- if len(output_block_list) > 1:
- resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
- attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
-
- resnet_0_paths = renew_resnet_paths(resnets)
- paths = renew_resnet_paths(resnets)
-
- meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
- assign_to_checkpoint(
- paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
- )
-
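- # an upsampler in this output block shows up as a bare ['conv.weight', 'conv.bias'] entry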
- if ["conv.weight", "conv.bias"] in output_block_list.values():
- index = list(output_block_list.values()).index(["conv.weight", "conv.bias"])
- new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
- f"output_blocks.{i}.{index}.conv.weight"
- ]
- new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
- f"output_blocks.{i}.{index}.conv.bias"
- ]
-
- # Clear attentions as they have been attributed above.
- if len(attentions) == 2:
- attentions = []
-
- if len(attentions):
- paths = renew_attention_paths(attentions)
- meta_path = {
- "old": f"output_blocks.{i}.1",
- "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
- }
- assign_to_checkpoint(
- paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
- )
- else:
- resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
- for path in resnet_0_paths:
- old_path = ".".join(["output_blocks", str(i), path["old"]])
- new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
-
- new_checkpoint[new_path] = unet_state_dict[old_path]
-
- # in SD v2 the 1x1 conv2d layers were replaced with linear layers, so convert linear -> conv here
- if v2:
- linear_transformer_to_conv(new_checkpoint)
-
- return new_checkpoint
-
-
-def convert_ldm_vae_checkpoint(checkpoint, config):
- # extract state dict for VAE
- vae_state_dict = {}
- vae_key = "first_stage_model."
- keys = list(checkpoint.keys())
- for key in keys:
- if key.startswith(vae_key):
- vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
-
- new_checkpoint = {}
-
- new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
- new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
- new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
- new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
- new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
- new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]
-
- new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
- new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
- new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
- new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
- new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
- new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]
-
- new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
- new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
- new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
- new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
-
- # Retrieves the keys for the encoder down blocks only
- num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
- down_blocks = {
- layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
- }
-
- # Retrieves the keys for the decoder up blocks only
- num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
- up_blocks = {
- layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
- }
-
- for i in range(num_down_blocks):
- resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
-
- if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
- new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
- f"encoder.down.{i}.downsample.conv.weight"
- )
- new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
- f"encoder.down.{i}.downsample.conv.bias"
- )
-
- paths = renew_vae_resnet_paths(resnets)
- meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
- assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
-
- mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
- num_mid_res_blocks = 2
- for i in range(1, num_mid_res_blocks + 1):
- resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
-
- paths = renew_vae_resnet_paths(resnets)
- meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
- assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
-
- mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
- paths = renew_vae_attention_paths(mid_attentions)
- meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
- assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
- conv_attn_to_linear(new_checkpoint)
-
- for i in range(num_up_blocks):
- block_id = num_up_blocks - 1 - i
- resnets = [
- key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
- ]
-
- if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
- new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
- f"decoder.up.{block_id}.upsample.conv.weight"
- ]
- new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
- f"decoder.up.{block_id}.upsample.conv.bias"
- ]
-
- paths = renew_vae_resnet_paths(resnets)
- meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
- assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
-
- mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
- num_mid_res_blocks = 2
- for i in range(1, num_mid_res_blocks + 1):
- resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
-
- paths = renew_vae_resnet_paths(resnets)
- meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
- assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
-
- mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
- paths = renew_vae_attention_paths(mid_attentions)
- meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
- assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
- conv_attn_to_linear(new_checkpoint)
- return new_checkpoint
-
-
-def create_unet_diffusers_config(v2):
- """
- Creates a config for the diffusers based on the config of the LDM model.
- """
- # unet_params = original_config.model.params.unet_config.params
-
- block_out_channels = [UNET_PARAMS_MODEL_CHANNELS * mult for mult in UNET_PARAMS_CHANNEL_MULT]
-
- down_block_types = []
- resolution = 1
- for i in range(len(block_out_channels)):
- block_type = "CrossAttnDownBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "DownBlock2D"
- down_block_types.append(block_type)
- if i != len(block_out_channels) - 1:
- resolution *= 2
-
- up_block_types = []
- for i in range(len(block_out_channels)):
- block_type = "CrossAttnUpBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "UpBlock2D"
- up_block_types.append(block_type)
- resolution //= 2
-
- config = dict(
- sample_size=UNET_PARAMS_IMAGE_SIZE,
- in_channels=UNET_PARAMS_IN_CHANNELS,
- out_channels=UNET_PARAMS_OUT_CHANNELS,
- down_block_types=tuple(down_block_types),
- up_block_types=tuple(up_block_types),
- block_out_channels=tuple(block_out_channels),
- layers_per_block=UNET_PARAMS_NUM_RES_BLOCKS,
- cross_attention_dim=UNET_PARAMS_CONTEXT_DIM if not v2 else V2_UNET_PARAMS_CONTEXT_DIM,
- attention_head_dim=UNET_PARAMS_NUM_HEADS if not v2 else V2_UNET_PARAMS_ATTENTION_HEAD_DIM,
- )
-
- return config
-
-
-def create_vae_diffusers_config():
- """
- Creates a config for the diffusers based on the config of the LDM model.
- """
- # vae_params = original_config.model.params.first_stage_config.params.ddconfig
- # _ = original_config.model.params.first_stage_config.params.embed_dim
- block_out_channels = [VAE_PARAMS_CH * mult for mult in VAE_PARAMS_CH_MULT]
- down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
- up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
-
- config = dict(
- sample_size=VAE_PARAMS_RESOLUTION,
- in_channels=VAE_PARAMS_IN_CHANNELS,
- out_channels=VAE_PARAMS_OUT_CH,
- down_block_types=tuple(down_block_types),
- up_block_types=tuple(up_block_types),
- block_out_channels=tuple(block_out_channels),
- latent_channels=VAE_PARAMS_Z_CHANNELS,
- layers_per_block=VAE_PARAMS_NUM_RES_BLOCKS,
- )
- return config
-
-
-def convert_ldm_clip_checkpoint_v1(checkpoint):
- keys = list(checkpoint.keys())
- text_model_dict = {}
- for key in keys:
- if key.startswith("cond_stage_model.transformer"):
- text_model_dict[key[len("cond_stage_model.transformer."):]] = checkpoint[key]
- return text_model_dict
-
-
-def convert_ldm_clip_checkpoint_v2(checkpoint, max_length):
- # the v2 text encoder layout is annoyingly different from v1
- def convert_key(key):
- if not key.startswith("cond_stage_model"):
- return None
-
- # common conversion
- key = key.replace("cond_stage_model.model.transformer.", "text_model.encoder.")
- key = key.replace("cond_stage_model.model.", "text_model.")
-
- if "resblocks" in key:
- # resblocks conversion
- key = key.replace(".resblocks.", ".layers.")
- if ".ln_" in key:
- key = key.replace(".ln_", ".layer_norm")
- elif ".mlp." in key:
- key = key.replace(".c_fc.", ".fc1.")
- key = key.replace(".c_proj.", ".fc2.")
- elif '.attn.out_proj' in key:
- key = key.replace(".attn.out_proj.", ".self_attn.out_proj.")
- elif '.attn.in_proj' in key:
- key = None # special case, handled separately below
- else:
- raise ValueError(f"unexpected key in SD: {key}")
- elif '.positional_embedding' in key:
- key = key.replace(".positional_embedding", ".embeddings.position_embedding.weight")
- elif '.text_projection' in key:
- key = None # apparently unused
- elif '.logit_scale' in key:
- key = None # apparently unused
- elif '.token_embedding' in key:
- key = key.replace(".token_embedding.weight", ".embeddings.token_embedding.weight")
- elif '.ln_final' in key:
- key = key.replace(".ln_final", ".final_layer_norm")
- return key
-
- keys = list(checkpoint.keys())
- new_sd = {}
- for key in keys:
- # skip resblocks.23: SD v2 uses the penultimate layer output, so the last transformer block is dropped
- if '.resblocks.23.' in key:
- continue
- new_key = convert_key(key)
- if new_key is None:
- continue
- new_sd[new_key] = checkpoint[key]
-
- # convert the attention in_proj weights
- for key in keys:
- if '.resblocks.23.' in key:
- continue
- if '.resblocks' in key and '.attn.in_proj_' in key:
- # split the fused in_proj tensor into q, k, v
- values = torch.chunk(checkpoint[key], 3)
-
- key_suffix = ".weight" if "weight" in key else ".bias"
- key_pfx = key.replace("cond_stage_model.model.transformer.resblocks.", "text_model.encoder.layers.")
- key_pfx = key_pfx.replace("_weight", "")
- key_pfx = key_pfx.replace("_bias", "")
- key_pfx = key_pfx.replace(".attn.in_proj", ".self_attn.")
- new_sd[key_pfx + "q_proj" + key_suffix] = values[0]
- new_sd[key_pfx + "k_proj" + key_suffix] = values[1]
- new_sd[key_pfx + "v_proj" + key_suffix] = values[2]
-
- # add position_ids (kept as a buffer in this transformers version's CLIPTextModel state dict)
- new_sd["text_model.embeddings.position_ids"] = torch.Tensor([list(range(max_length))]).to(torch.int64)
- return new_sd
-
-# endregion
-
-
-# region Diffusers->StableDiffusion conversion code
-# copied from convert_diffusers_to_original_stable_diffusion (ASL 2.0)
-
-def conv_transformer_to_linear(checkpoint):
- keys = list(checkpoint.keys())
- tf_keys = ["proj_in.weight", "proj_out.weight"]
- for key in keys:
- if ".".join(key.split(".")[-2:]) in tf_keys:
- if checkpoint[key].ndim > 2:
- checkpoint[key] = checkpoint[key][:, :, 0, 0]
-
-
-def convert_unet_state_dict_to_sd(v2, unet_state_dict):
- unet_conversion_map = [
- # (stable-diffusion, HF Diffusers)
- ("time_embed.0.weight", "time_embedding.linear_1.weight"),
- ("time_embed.0.bias", "time_embedding.linear_1.bias"),
- ("time_embed.2.weight", "time_embedding.linear_2.weight"),
- ("time_embed.2.bias", "time_embedding.linear_2.bias"),
- ("input_blocks.0.0.weight", "conv_in.weight"),
- ("input_blocks.0.0.bias", "conv_in.bias"),
- ("out.0.weight", "conv_norm_out.weight"),
- ("out.0.bias", "conv_norm_out.bias"),
- ("out.2.weight", "conv_out.weight"),
- ("out.2.bias", "conv_out.bias"),
- ]
-
- unet_conversion_map_resnet = [
- # (stable-diffusion, HF Diffusers)
- ("in_layers.0", "norm1"),
- ("in_layers.2", "conv1"),
- ("out_layers.0", "norm2"),
- ("out_layers.3", "conv2"),
- ("emb_layers.1", "time_emb_proj"),
- ("skip_connection", "conv_shortcut"),
- ]
-
- unet_conversion_map_layer = []
- for i in range(4):
- # loop over downblocks/upblocks
-
- for j in range(2):
- # loop over resnets/attentions for downblocks
- hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
- sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
- unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
-
- if i < 3:
- # no attention layers in down_blocks.3
- hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
- sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
- unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
-
- for j in range(3):
- # loop over resnets/attentions for upblocks
- hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
- sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
- unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
-
- if i > 0:
- # no attention layers in up_blocks.0
- hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
- sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
- unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
-
- if i < 3:
- # no downsample in down_blocks.3
- hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
- sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
- unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
-
- # no upsample in up_blocks.3
- hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
- sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
- unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
-
- hf_mid_atn_prefix = "mid_block.attentions.0."
- sd_mid_atn_prefix = "middle_block.1."
- unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
-
- for j in range(2):
- hf_mid_res_prefix = f"mid_block.resnets.{j}."
- sd_mid_res_prefix = f"middle_block.{2*j}."
- unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
-
- # buyer beware: this is a *brittle* function,
- # and correct output requires that all of these pieces interact in
- # the exact order in which I have arranged them.
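- # strategy: start from an identity mapping over the HF keys, then rewrite each value to its SD name using the replacement tables above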
- mapping = {k: k for k in unet_state_dict.keys()}
- for sd_name, hf_name in unet_conversion_map:
- mapping[hf_name] = sd_name
- for k, v in mapping.items():
- if "resnets" in k:
- for sd_part, hf_part in unet_conversion_map_resnet:
- v = v.replace(hf_part, sd_part)
- mapping[k] = v
- for k, v in mapping.items():
- for sd_part, hf_part in unet_conversion_map_layer:
- v = v.replace(hf_part, sd_part)
- mapping[k] = v
- new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
-
- if v2:
- conv_transformer_to_linear(new_state_dict)
-
- return new_state_dict
-
-# endregion
-
-
-def load_checkpoint_with_text_encoder_conversion(ckpt_path):
- # handle models whose text encoder keys use a different layout (no 'text_model' prefix)
- TEXT_ENCODER_KEY_REPLACEMENTS = [
- ('cond_stage_model.transformer.embeddings.', 'cond_stage_model.transformer.text_model.embeddings.'),
- ('cond_stage_model.transformer.encoder.', 'cond_stage_model.transformer.text_model.encoder.'),
- ('cond_stage_model.transformer.final_layer_norm.', 'cond_stage_model.transformer.text_model.final_layer_norm.')
- ]
-
- checkpoint = torch.load(ckpt_path, map_location="cpu")
- state_dict = checkpoint["state_dict"]
-
- key_reps = []
- for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS:
- for key in state_dict.keys():
- if key.startswith(rep_from):
- new_key = rep_to + key[len(rep_from):]
- key_reps.append((key, new_key))
-
- for key, new_key in key_reps:
- state_dict[new_key] = state_dict[key]
- del state_dict[key]
-
- return checkpoint
-
-
-def load_models_from_stable_diffusion_checkpoint(v2, ckpt_path, dtype=None):
- checkpoint = load_checkpoint_with_text_encoder_conversion(ckpt_path)
- state_dict = checkpoint["state_dict"]
- if dtype is not None:
- for k, v in state_dict.items():
- if type(v) is torch.Tensor:
- state_dict[k] = v.to(dtype)
-
- # Convert the UNet2DConditionModel model.
- unet_config = create_unet_diffusers_config(v2)
- converted_unet_checkpoint = convert_ldm_unet_checkpoint(v2, state_dict, unet_config)
-
- unet = UNet2DConditionModel(**unet_config)
- info = unet.load_state_dict(converted_unet_checkpoint)
- print("loading u-net:", info)
-
- # Convert the VAE model.
- vae_config = create_vae_diffusers_config()
- converted_vae_checkpoint = convert_ldm_vae_checkpoint(state_dict, vae_config)
-
- vae = AutoencoderKL(**vae_config)
- info = vae.load_state_dict(converted_vae_checkpoint)
- print("loadint vae:", info)
-
- # convert text_model
- if v2:
- converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v2(state_dict, 77)
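- # hard-coded config for the SD 2.x text encoder (OpenCLIP ViT-H/14 text tower); only 23 hidden layers are used because the output is taken from the penultimate layer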
- cfg = CLIPTextConfig(
- vocab_size=49408,
- hidden_size=1024,
- intermediate_size=4096,
- num_hidden_layers=23,
- num_attention_heads=16,
- max_position_embeddings=77,
- hidden_act="gelu",
- layer_norm_eps=1e-05,
- dropout=0.0,
- attention_dropout=0.0,
- initializer_range=0.02,
- initializer_factor=1.0,
- pad_token_id=1,
- bos_token_id=0,
- eos_token_id=2,
- model_type="clip_text_model",
- projection_dim=512,
- torch_dtype="float32",
- transformers_version="4.25.0.dev0",
- )
- text_model = CLIPTextModel._from_config(cfg)
- info = text_model.load_state_dict(converted_text_encoder_checkpoint)
- else:
- converted_text_encoder_checkpoint = convert_ldm_clip_checkpoint_v1(state_dict)
- text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
- info = text_model.load_state_dict(converted_text_encoder_checkpoint)
- print("loading text encoder:", info)
-
- return text_model, vae, unet
-
-
-def convert_text_encoder_state_dict_to_sd_v2(checkpoint):
- def convert_key(key):
- # remove position_ids
- if ".position_ids" in key:
- return None
-
- # common
- key = key.replace("text_model.encoder.", "transformer.")
- key = key.replace("text_model.", "")
- if "layers" in key:
- # resblocks conversion
- key = key.replace(".layers.", ".resblocks.")
- if ".layer_norm" in key:
- key = key.replace(".layer_norm", ".ln_")
- elif ".mlp." in key:
- key = key.replace(".fc1.", ".c_fc.")
- key = key.replace(".fc2.", ".c_proj.")
- elif '.self_attn.out_proj' in key:
- key = key.replace(".self_attn.out_proj.", ".attn.out_proj.")
- elif '.self_attn.' in key:
- key = None # special case, handled separately below
- else:
- raise ValueError(f"unexpected key in Diffusers model: {key}")
- elif '.position_embedding' in key:
- key = key.replace("embeddings.position_embedding.weight", "positional_embedding")
- elif '.token_embedding' in key:
- key = key.replace("embeddings.token_embedding.weight", "token_embedding.weight")
- elif 'final_layer_norm' in key:
- key = key.replace("final_layer_norm", "ln_final")
- return key
-
- keys = list(checkpoint.keys())
- new_sd = {}
- for key in keys:
- new_key = convert_key(key)
- if new_key is None:
- continue
- new_sd[new_key] = checkpoint[key]
-
- # convert the attention projections
- for key in keys:
- if 'layers' in key and 'q_proj' in key:
- # concatenate q, k, v into a single in_proj tensor
- key_q = key
- key_k = key.replace("q_proj", "k_proj")
- key_v = key.replace("q_proj", "v_proj")
-
- value_q = checkpoint[key_q]
- value_k = checkpoint[key_k]
- value_v = checkpoint[key_v]
- value = torch.cat([value_q, value_k, value_v])
-
- new_key = key.replace("text_model.encoder.layers.", "transformer.resblocks.")
- new_key = new_key.replace(".self_attn.q_proj.", ".attn.in_proj_")
- new_sd[new_key] = value
-
- return new_sd
-
-
-def save_stable_diffusion_checkpoint(v2, output_file, text_encoder, unet, ckpt_path, epochs, steps, save_dtype=None):
- # the VAE is not kept in memory, so load the checkpoint again including the VAE
- checkpoint = load_checkpoint_with_text_encoder_conversion(ckpt_path)
- state_dict = checkpoint["state_dict"]
-
- def assign_new_sd(prefix, sd):
- for k, v in sd.items():
- key = prefix + k
- assert key in state_dict, f"Illegal key in save SD: {key}"
- if save_dtype is not None:
- v = v.detach().clone().to("cpu").to(save_dtype)
- state_dict[key] = v
-
- # Convert the UNet model
- unet_state_dict = convert_unet_state_dict_to_sd(v2, unet.state_dict())
- assign_new_sd("model.diffusion_model.", unet_state_dict)
-
- # Convert the text encoder model
- if v2:
- text_enc_dict = convert_text_encoder_state_dict_to_sd_v2(text_encoder.state_dict())
- assign_new_sd("cond_stage_model.model.", text_enc_dict)
- else:
- text_enc_dict = text_encoder.state_dict()
- assign_new_sd("cond_stage_model.transformer.", text_enc_dict)
-
- # Put together new checkpoint
- new_ckpt = {'state_dict': state_dict}
-
- if 'epoch' in checkpoint:
- epochs += checkpoint['epoch']
- if 'global_step' in checkpoint:
- steps += checkpoint['global_step']
-
- new_ckpt['epoch'] = epochs
- new_ckpt['global_step'] = steps
-
- torch.save(new_ckpt, output_file)
-
-
-def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, pretrained_model_name_or_path, save_dtype):
- pipeline = StableDiffusionPipeline(
- unet=unet,
- text_encoder=text_encoder,
- vae=AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae"),
- scheduler=DDIMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler"),
- tokenizer=CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer"),
- safety_checker=None,
- feature_extractor=None,
- requires_safety_checker=None,
- )
- pipeline.save_pretrained(output_dir)
-
-# endregion
-
-
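-# the dataset already returns one full batch per index, so the DataLoader runs with batch_size=1 and this collate simply unwraps it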
-def collate_fn(examples):
- return examples[0]
-
-
-def train(args):
- if args.caption_extention is not None:
- args.caption_extension = args.caption_extention
- args.caption_extention = None
-
- fine_tuning = args.fine_tuning
- cache_latents = args.cache_latents
-
- # check option consistency when caching latents
- if cache_latents:
- assert not args.flip_aug and not args.color_aug, "when caching latents, augmentation cannot be used / latentをキャッシュするときはaugmentationは使えません"
-
- # check other option combinations
- if args.v_parameterization and not args.v2:
- print("v_parameterization should be with v2 / v1でv_parameterizationを使用することは想定されていません")
- if args.v2 and args.clip_skip is not None:
- print("v2 with clip_skip will be unexpected / v2でclip_skipを使用することは想定されていません")
-
- # check the model-format options:
- # since v11, downloading directly from Diffusers (Hub) is also OK (models requiring authentication are not supported), and intermediate saving in Diffusers format is supported as well
- use_stable_diffusion_format = os.path.isfile(args.pretrained_model_name_or_path)
-
- # initialize the random seed
- if args.seed is not None:
- set_seed(args.seed)
-
- # prepare the training data
- def read_caption(img_path):
- # build candidate caption file names (the plain name, and the name with the face-crop suffix stripped)
- base_name = os.path.splitext(img_path)[0]
- base_name_face_det = base_name
- tokens = base_name.split("_")
- if len(tokens) >= 5:
- base_name_face_det = "_".join(tokens[:-4])
- cap_paths = [base_name + args.caption_extension, base_name_face_det + args.caption_extension]
-
- caption = None
- for cap_path in cap_paths:
- if os.path.isfile(cap_path):
- with open(cap_path, "rt", encoding='utf-8') as f:
- caption = f.readlines()[0].strip()
- break
- return caption
-
- def load_dreambooth_dir(dir):
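- # DreamBooth-style folders are named '<repeat count>_<caption>'; folders that do not match are skipped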
- tokens = os.path.basename(dir).split('_')
- try:
- n_repeats = int(tokens[0])
- except ValueError as e:
- return 0, []
-
- caption_by_folder = '_'.join(tokens[1:])
-
- print(f"found directory {n_repeats}_{caption_by_folder}")
-
- img_paths = glob.glob(os.path.join(dir, "*.png")) + glob.glob(os.path.join(dir, "*.jpg")) + \
- glob.glob(os.path.join(dir, "*.webp"))
-
- # read a per-image caption and use it if present (behavior changed in v11)
- captions = []
- for img_path in img_paths:
- cap_for_img = read_caption(img_path)
- captions.append(caption_by_folder if cap_for_img is None else cap_for_img)
-
- return n_repeats, list(zip(img_paths, captions))
-
- print("prepare train images.")
- train_img_path_captions = []
-
- if fine_tuning:
- img_paths = glob.glob(os.path.join(args.train_data_dir, "*.png")) + \
- glob.glob(os.path.join(args.train_data_dir, "*.jpg")) + glob.glob(os.path.join(args.train_data_dir, "*.webp"))
- for img_path in tqdm(img_paths):
- caption = read_caption(img_path)
- assert caption is not None and len(
- caption) > 0, f"no caption for image. check caption_extension option / キャプションファイルが見つからないかcaptionが空です。caption_extensionオプションを確認してください: {img_path}"
-
- train_img_path_captions.append((img_path, caption))
-
- if args.dataset_repeats is not None:
- l = []
- for _ in range(args.dataset_repeats):
- l.extend(train_img_path_captions)
- train_img_path_captions = l
- else:
- train_dirs = os.listdir(args.train_data_dir)
- for dir in train_dirs:
- n_repeats, img_caps = load_dreambooth_dir(os.path.join(args.train_data_dir, dir))
- for _ in range(n_repeats):
- train_img_path_captions.extend(img_caps)
- print(f"{len(train_img_path_captions)} train images with repeating.")
-
- reg_img_path_captions = []
- if args.reg_data_dir:
- print("prepare reg images.")
- reg_dirs = os.listdir(args.reg_data_dir)
- for dir in reg_dirs:
- n_repeats, img_caps = load_dreambooth_dir(os.path.join(args.reg_data_dir, dir))
- for _ in range(n_repeats):
- reg_img_path_captions.extend(img_caps)
- print(f"{len(reg_img_path_captions)} reg images.")
-
- # prepare the dataset
- resolution = tuple([int(r) for r in args.resolution.split(',')])
- if len(resolution) == 1:
- resolution = (resolution[0], resolution[0])
- assert len(resolution) == 2, \
- f"resolution must be 'size' or 'width,height' / resolutionは'サイズ'または'幅','高さ'で指定してください: {args.resolution}"
-
- if args.enable_bucket:
- assert min(resolution) >= args.min_bucket_reso, f"min_bucket_reso must be equal to or less than resolution / min_bucket_resoは解像度の数値以下で指定してください"
- assert max(resolution) <= args.max_bucket_reso, f"max_bucket_reso must be equal to or greater than resolution / max_bucket_resoは解像度の数値以上で指定してください"
-
- if args.face_crop_aug_range is not None:
- face_crop_aug_range = tuple([float(r) for r in args.face_crop_aug_range.split(',')])
- assert len(
- face_crop_aug_range) == 2, f"face_crop_aug_range must be two floats / face_crop_aug_rangeは'下限,上限'で指定してください: {args.face_crop_aug_range}"
- else:
- face_crop_aug_range = None
-
- # load the tokenizer
- print("prepare tokenizer")
- if args.v2:
- tokenizer = CLIPTokenizer.from_pretrained(V2_STABLE_DIFFUSION_PATH, subfolder="tokenizer")
- else:
- tokenizer = CLIPTokenizer.from_pretrained(TOKENIZER_PATH)
-
- print("prepare dataset")
- train_dataset = DreamBoothOrFineTuningDataset(args.train_batch_size, fine_tuning, train_img_path_captions, reg_img_path_captions, tokenizer, resolution,
- args.prior_loss_weight, args.flip_aug, args.color_aug, face_crop_aug_range, args.random_crop,
- args.shuffle_caption, args.no_token_padding, args.debug_dataset)
-
- if args.debug_dataset:
- train_dataset.make_buckets_with_caching(args.enable_bucket, None, args.min_bucket_reso,
- args.max_bucket_reso) # build buckets without caching, for debugging
- print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}")
- print("Escape for exit. / Escキーで中断、終了します")
- for example in train_dataset:
- for im, cap, lw in zip(example['images'], example['captions'], example['loss_weights']):
- im = ((im.numpy() + 1.0) * 127.5).astype(np.uint8)
- im = np.transpose(im, (1, 2, 0)) # c,H,W -> H,W,c
- im = im[:, :, ::-1] # RGB -> BGR (OpenCV)
- print(f'size: {im.shape[1]}*{im.shape[0]}, caption: "{cap}", loss weight: {lw}')
- cv2.imshow("img", im)
- k = cv2.waitKey()
- cv2.destroyAllWindows()
- if k == 27:
- break
- if k == 27:
- break
- return
-
- # prepare the accelerator
- # gradient accumulation is reportedly not supported when training multiple models, so it is fixed at 1
- print("prepare accelerator")
- if args.logging_dir is None:
- log_with = None
- logging_dir = None
- else:
- log_with = "tensorboard"
- logging_dir = args.logging_dir + "/" + time.strftime('%Y%m%d%H%M%S', time.localtime())
- accelerator = Accelerator(gradient_accumulation_steps=1, mixed_precision=args.mixed_precision,
- log_with=log_with, logging_dir=logging_dir)
-
- # prepare the dtype matching the mixed precision setting and cast as needed
- weight_dtype = torch.float32
- if args.mixed_precision == "fp16":
- weight_dtype = torch.float16
- elif args.mixed_precision == "bf16":
- weight_dtype = torch.bfloat16
-
- save_dtype = None
- if args.save_precision == "fp16":
- save_dtype = torch.float16
- elif args.save_precision == "bf16":
- save_dtype = torch.bfloat16
- elif args.save_precision == "float":
- save_dtype = torch.float32
-
- # load the model
- if use_stable_diffusion_format:
- print("load StableDiffusion checkpoint")
- text_encoder, vae, unet = load_models_from_stable_diffusion_checkpoint(args.v2, args.pretrained_model_name_or_path)
- else:
- print("load Diffusers pretrained models")
- pipe = StableDiffusionPipeline.from_pretrained(args.pretrained_model_name_or_path, tokenizer=None, safety_checker=None)
- # , torch_dtype=weight_dtype) specifying torch_dtype here causes an error during training
- text_encoder = pipe.text_encoder
- vae = pipe.vae
- unet = pipe.unet
- del pipe
-
- # hook xformers or memory efficient attention into the model
- replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
-
- # prepare for training
- if cache_latents:
- vae.to(accelerator.device, dtype=weight_dtype)
- vae.requires_grad_(False)
- vae.eval()
- with torch.no_grad():
- train_dataset.make_buckets_with_caching(args.enable_bucket, vae, args.min_bucket_reso, args.max_bucket_reso)
- del vae
- if torch.cuda.is_available():
- torch.cuda.empty_cache()
- else:
- train_dataset.make_buckets_with_caching(args.enable_bucket, None, args.min_bucket_reso, args.max_bucket_reso)
- vae.requires_grad_(False)
- vae.eval()
-
- unet.requires_grad_(True) # set explicitly, just in case
- text_encoder.requires_grad_(True)
-
- if args.gradient_checkpointing:
- unet.enable_gradient_checkpointing()
- text_encoder.gradient_checkpointing_enable()
-
- # prepare the classes needed for training
- print("prepare optimizer, data loader etc.")
-
- # use 8-bit Adam
- if args.use_8bit_adam:
- try:
- import bitsandbytes as bnb
- except ImportError:
- raise ImportError("No bitsandbytes / bitsandbytesがインストールされていないようです")
- print("use 8-bit Adam optimizer")
- optimizer_class = bnb.optim.AdamW8bit
- else:
- optimizer_class = torch.optim.AdamW
-
- trainable_params = (itertools.chain(unet.parameters(), text_encoder.parameters()))
-
- # both diffusers DreamBooth and DreamBooth SD seem to use the default betas and weight decay, so those options are omitted for now
- optimizer = optimizer_class(trainable_params, lr=args.learning_rate)
-
- # prepare the dataloader
- # number of DataLoader worker processes: 0 means the main process
- n_workers = min(8, os.cpu_count() - 1) # cpu_count - 1, capped at 8
- train_dataloader = torch.utils.data.DataLoader(
- train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers)
-
- # prepare the lr scheduler
- lr_scheduler = diffusers.optimization.get_scheduler(
- args.lr_scheduler, optimizer, num_warmup_steps=args.lr_warmup_steps, num_training_steps=args.max_train_steps)
-
- # accelerator takes care of device placement and wrapping for us
- unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
- unet, text_encoder, optimizer, train_dataloader, lr_scheduler)
-
- if not cache_latents:
- vae.to(accelerator.device, dtype=weight_dtype)
-
- # resume training if requested
- if args.resume is not None:
- print(f"resume training from state: {args.resume}")
- accelerator.load_state(args.resume)
-
- # compute the number of epochs
- num_train_epochs = math.ceil(args.max_train_steps / len(train_dataloader))
-
- # run training
- total_batch_size = args.train_batch_size # * accelerator.num_processes
- print("running training / 学習開始")
- print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset.num_train_images}")
- print(f" num reg images / 正則化画像の数: {train_dataset.num_reg_images}")
- print(f" num examples / サンプル数: {train_dataset.num_train_images * (2 if train_dataset.enable_reg_images else 1)}")
- print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
- print(f" num epochs / epoch数: {num_train_epochs}")
- print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
- print(f" total train batch size (with parallel & distributed) / 総バッチサイズ(並列学習含む): {total_batch_size}")
- print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
-
- progress_bar = tqdm(range(args.max_train_steps), smoothing=0, disable=not accelerator.is_local_main_process, desc="steps")
- global_step = 0
-
- # updated in v12: clip_sample=False
- # Diffusers' train_dreambooth.py now takes the scheduler settings from the config, which yields clip_sample=False, so match that
- # the existing 1.4/1.5/2.0 checkpoints all share the same scheduler config (apart from the class name)
- # on closer inspection of the source, this does not matter during training anyway
- noise_scheduler = DDPMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear",
- num_train_timesteps=1000, clip_sample=False)
-
- if accelerator.is_main_process:
- accelerator.init_trackers("dreambooth")
-
- # the following is mostly copied from train_dreambooth.py
- for epoch in range(num_train_epochs):
- print(f"epoch {epoch+1}/{num_train_epochs}")
- unet.train()
- text_encoder.train()
-
- loss_total = 0
- for step, batch in enumerate(train_dataloader):
- with accelerator.accumulate(unet):
- with torch.no_grad():
- # convert to latents
- if cache_latents:
- latents = batch["latents"].to(accelerator.device)
- else:
- latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
- latents = latents * 0.18215
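- # 0.18215 is the standard Stable Diffusion latent scaling factor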
-
- # Sample noise that we'll add to the latents
- noise = torch.randn_like(latents, device=latents.device)
- b_size = latents.shape[0]
-
- # Sample a random timestep for each image
- timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device)
- timesteps = timesteps.long()
-
- # Add noise to the latents according to the noise magnitude at each timestep
- # (this is the forward diffusion process)
- noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
-
- # train the text encoder only up to the specified number of steps
- train_text_encoder = args.stop_text_encoder_training is None or global_step < args.stop_text_encoder_training
- with torch.set_grad_enabled(train_text_encoder):
- # Get the text embedding for conditioning
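- # when clip_skip is set, take the hidden state clip_skip layers from the end and re-apply the final layer norm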
- if args.clip_skip is None:
- encoder_hidden_states = text_encoder(batch["input_ids"])[0]
- else:
- enc_out = text_encoder(batch["input_ids"], output_hidden_states=True, return_dict=True)
- encoder_hidden_states = enc_out['hidden_states'][-args.clip_skip]
- encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states)
-
- # Predict the noise residual
- noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
-
- if args.v_parameterization:
- # v-parameterization training
- # ideally this would be:
- # target = noise_scheduler.get_v(latents, noise, timesteps)
-
- # code from Stability AI's ddpm.py:
- # elif self.parameterization == "v":
- # target = self.get_v(x_start, noise, t)
- # ...
- # def get_v(self, x, noise, t):
- # return (
- # extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise -
- # extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x
- # )
-
- # code from scheduling_ddim.py:
- # elif self.config.prediction_type == "v_prediction":
- # pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
- # # predict V
- # model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
-
- # so this should be equivalent:
- alpha_prod_t = noise_scheduler.alphas_cumprod[timesteps]
- beta_prod_t = 1 - alpha_prod_t
- alpha_prod_t = torch.reshape(alpha_prod_t, (len(alpha_prod_t), 1, 1, 1)) # reshape because broadcasting over the latent dims does not happen automatically
- beta_prod_t = torch.reshape(beta_prod_t, (len(beta_prod_t), 1, 1, 1))
- target = (alpha_prod_t ** 0.5) * noise - (beta_prod_t ** 0.5) * latents
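- # i.e. the v-prediction target: v = sqrt(alpha_bar_t) * noise - sqrt(1 - alpha_bar_t) * x0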
- else:
- target = noise
-
- loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
- loss = loss.mean([1, 2, 3])
-
- loss_weights = batch["loss_weights"] # per-sample weights
- loss = loss * loss_weights
-
- loss = loss.mean() # already a mean, so no need to divide by batch_size
-
- accelerator.backward(loss)
- if accelerator.sync_gradients:
- params_to_clip = (itertools.chain(unet.parameters(), text_encoder.parameters()))
- accelerator.clip_grad_norm_(params_to_clip, 1.0) # args.max_grad_norm)
-
- optimizer.step()
- lr_scheduler.step()
- optimizer.zero_grad(set_to_none=True)
-
- # Checks if the accelerator has performed an optimization step behind the scenes
- if accelerator.sync_gradients:
- progress_bar.update(1)
- global_step += 1
-
- if global_step == args.stop_text_encoder_training:
- print(f"stop text encoder training at step {global_step}")
-
- current_loss = loss.detach().item()
- if args.logging_dir is not None:
- logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
- accelerator.log(logs, step=global_step)
-
- loss_total += current_loss
- avr_loss = loss_total / (step+1)
- logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
- progress_bar.set_postfix(**logs)
-
- if global_step >= args.max_train_steps:
- break
-
- if args.logging_dir is not None:
- logs = {"epoch_loss": loss_total / len(train_dataloader)}
- accelerator.log(logs, step=epoch+1)
-
- accelerator.wait_for_everyone()
-
- if args.save_every_n_epochs is not None:
- if (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs:
- print("saving checkpoint.")
- if use_stable_diffusion_format:
- os.makedirs(args.output_dir, exist_ok=True)
- ckpt_file = os.path.join(args.output_dir, EPOCH_CHECKPOINT_NAME.format(epoch + 1))
- save_stable_diffusion_checkpoint(args.v2, ckpt_file, accelerator.unwrap_model(text_encoder), accelerator.unwrap_model(unet),
- args.pretrained_model_name_or_path, epoch + 1, global_step, save_dtype)
- else:
- out_dir = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(epoch + 1))
- os.makedirs(out_dir, exist_ok=True)
- save_diffusers_checkpoint(args.v2, out_dir, accelerator.unwrap_model(text_encoder),
- accelerator.unwrap_model(unet), args.pretrained_model_name_or_path, save_dtype)
-
- if args.save_state:
- print("saving state.")
- accelerator.save_state(os.path.join(args.output_dir, EPOCH_STATE_NAME.format(epoch + 1)))
-
- is_main_process = accelerator.is_main_process
- if is_main_process:
- unet = accelerator.unwrap_model(unet)
- text_encoder = accelerator.unwrap_model(text_encoder)
-
- accelerator.end_training()
-
- if args.save_state:
- print("saving last state.")
- accelerator.save_state(os.path.join(args.output_dir, LAST_STATE_NAME))
-
- del accelerator # free it, since memory is needed for saving below
-
- if is_main_process:
- os.makedirs(args.output_dir, exist_ok=True)
- if use_stable_diffusion_format:
- ckpt_file = os.path.join(args.output_dir, LAST_CHECKPOINT_NAME)
- print(f"save trained model as StableDiffusion checkpoint to {ckpt_file}")
- save_stable_diffusion_checkpoint(args.v2, ckpt_file, text_encoder, unet,
- args.pretrained_model_name_or_path, epoch, global_step, save_dtype)
- else:
- # Create the pipeline using the trained modules and save it.
- print(f"save trained model as Diffusers to {args.output_dir}")
- out_dir = os.path.join(args.output_dir, LAST_DIFFUSERS_DIR_NAME)
- os.makedirs(out_dir, exist_ok=True)
- save_diffusers_checkpoint(args.v2, out_dir, text_encoder, unet, args.pretrained_model_name_or_path, save_dtype)
- print("model saved.")
-
-
-if __name__ == '__main__':
- # torch.cuda.set_per_process_memory_fraction(0.48)
- parser = argparse.ArgumentParser()
- parser.add_argument("--v2", action='store_true',
- help='load Stable Diffusion v2.0 model / Stable Diffusion 2.0のモデルを読み込む')
- parser.add_argument("--v_parameterization", action='store_true',
- help='enable v-parameterization training / v-parameterization学習を有効にする')
- parser.add_argument("--pretrained_model_name_or_path", type=str, default=None,
- help="pretrained model to train, directory to Diffusers model or StableDiffusion checkpoint / 学習元モデル、Diffusers形式モデルのディレクトリまたはStableDiffusionのckptファイル")
- parser.add_argument("--fine_tuning", action="store_true",
- help="fine tune the model instead of DreamBooth / DreamBoothではなくfine tuningする")
- parser.add_argument("--shuffle_caption", action="store_true",
- help="shuffle comma-separated caption / コンマで区切られたcaptionの各要素をshuffleする")
- parser.add_argument("--caption_extention", type=str, default=None,
- help="extension of caption files (backward compatiblity) / 読み込むcaptionファイルの拡張子(スペルミスを残してあります)")
- parser.add_argument("--caption_extension", type=str, default=".caption", help="extension of caption files / 読み込むcaptionファイルの拡張子")
- parser.add_argument("--train_data_dir", type=str, default=None, help="directory for train images / 学習画像データのディレクトリ")
- parser.add_argument("--reg_data_dir", type=str, default=None, help="directory for regularization images / 正則化画像データのディレクトリ")
- parser.add_argument("--dataset_repeats", type=int, default=None,
- help="repeat dataset in fine tuning / fine tuning時にデータセットを繰り返す回数")
- parser.add_argument("--output_dir", type=str, default=None,
- help="directory to output trained model (default format is same to input) / 学習後のモデル出力先ディレクトリ(デフォルトの保存形式は読み込んだ形式と同じ)")
- # parser.add_argument("--save_as_sd", action='store_true',
- # help="save the model as StableDiffusion checkpoint / 学習後のモデルをStableDiffusionのcheckpointとして保存する")
- parser.add_argument("--save_every_n_epochs", type=int, default=None,
- help="save checkpoint every N epochs / 学習中のモデルを指定エポックごとに保存します")
- parser.add_argument("--save_state", action="store_true",
- help="save training state additionally (including optimizer states etc.) / optimizerなど学習状態も含めたstateを追加で保存する")
- parser.add_argument("--resume", type=str, default=None, help="saved state to resume training / 学習再開するモデルのstate")
- parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="loss weight for regularization images / 正則化画像のlossの重み")
- parser.add_argument("--no_token_padding", action="store_true",
- help="disable token padding (same as Diffuser's DreamBooth) / トークンのpaddingを無効にする(Diffusers版DreamBoothと同じ動作)")
- parser.add_argument("--stop_text_encoder_training", type=int, default=None, help="steps to stop text encoder training / Text Encoderの学習を止めるステップ数")
- parser.add_argument("--color_aug", action="store_true", help="enable weak color augmentation / 学習時に色合いのaugmentationを有効にする")
- parser.add_argument("--flip_aug", action="store_true", help="enable horizontal flip augmentation / 学習時に左右反転のaugmentationを有効にする")
- parser.add_argument("--face_crop_aug_range", type=str, default=None,
- help="enable face-centered crop augmentation and its range (e.g. 2.0,4.0) / 学習時に顔を中心とした切り出しaugmentationを有効にするときは倍率を指定する(例:2.0,4.0)")
- parser.add_argument("--random_crop", action="store_true",
- help="enable random crop (for style training in face-centered crop augmentation) / ランダムな切り出しを有効にする(顔を中心としたaugmentationを行うときに画風の学習用に指定する)")
- parser.add_argument("--debug_dataset", action="store_true",
- help="show images for debugging (do not train) / デバッグ用に学習データを画面表示する(学習は行わない)")
- parser.add_argument("--resolution", type=str, default=None,
- help="resolution in training ('size' or 'width,height') / 学習時の画像解像度('サイズ'指定、または'幅,高さ'指定)")
- parser.add_argument("--train_batch_size", type=int, default=1,
- help="batch size for training (1 means one train or reg data, not train/reg pair) / 学習時のバッチサイズ(1でtrain/regをそれぞれ1件ずつ学習)")
- parser.add_argument("--use_8bit_adam", action="store_true",
- help="use 8bit Adam optimizer (requires bitsandbytes) / 8bit Adamオプティマイザを使う(bitsandbytesのインストールが必要)")
- parser.add_argument("--mem_eff_attn", action="store_true",
- help="use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う")
- parser.add_argument("--xformers", action="store_true",
- help="use xformers for CrossAttention / CrossAttentionにxformersを使う")
- parser.add_argument("--cache_latents", action="store_true",
- help="cache latents to reduce memory (augmentations must be disabled) / メモリ削減のためにlatentをcacheする(augmentationは使用不可)")
- parser.add_argument("--enable_bucket", action="store_true",
- help="enable buckets for multi aspect ratio training / 複数解像度学習のためのbucketを有効にする")
- parser.add_argument("--min_bucket_reso", type=int, default=256, help="minimum resolution for buckets / bucketの最小解像度")
- parser.add_argument("--max_bucket_reso", type=int, default=1024, help="maximum resolution for buckets / bucketの最小解像度")
- parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率")
- parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数")
- parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed")
- parser.add_argument("--gradient_checkpointing", action="store_true",
- help="enable gradient checkpointing / grandient checkpointingを有効にする")
- parser.add_argument("--mixed_precision", type=str, default="no",
- choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度")
- parser.add_argument("--save_precision", type=str, default=None,
- choices=[None, "float", "fp16", "bf16"], help="precision in saving (available in StableDiffusion checkpoint) / 保存時に精度を変更して保存する(StableDiffusion形式での保存時のみ有効)")
- parser.add_argument("--clip_skip", type=int, default=None,
- help="use output of nth layer from back of text encoder (n>=1) / text encoderの後ろからn番目の層の出力を用いる(nは1以上)")
- parser.add_argument("--logging_dir", type=str, default=None,
- help="enable logging and output TensorBoard log to this directory / ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する")
- parser.add_argument("--lr_scheduler", type=str, default="constant",
- help="scheduler to use for learning rate / 学習率のスケジューラ: linear, cosine, cosine_with_restarts, polynomial, constant (default), constant_with_warmup")
- parser.add_argument("--lr_warmup_steps", type=int, default=0,
- help="Number of steps for the warmup in the lr scheduler (default is 0) / 学習率のスケジューラをウォームアップするステップ数(デフォルト0)")
-
- args = parser.parse_args()
- train(args)
diff --git a/train_db_fixed/train_db_fixed_v7.py b/train_db_fixed/train_db_fixed_v7.py
deleted file mode 100644
index b19b6abc..00000000
--- a/train_db_fixed/train_db_fixed_v7.py
+++ /dev/null
@@ -1,1609 +0,0 @@
-# This script is licensed under Apache License 2.0, the same as train_dreambooth.py
-# (c) 2022 Kohya S. @kohya_ss
-
-# v7: another text encoder ckpt format, average loss, save epochs/global steps, show num of train/reg images,
-# enable reg images in fine-tuning, add dataset_repeats option
-
-from torch.autograd.function import Function
-import argparse
-import glob
-import itertools
-import math
-import os
-import random
-
-from tqdm import tqdm
-import torch
-from torchvision import transforms
-from accelerate import Accelerator
-from accelerate.utils import set_seed
-from transformers import CLIPTextModel, CLIPTokenizer
-import diffusers
-from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
-import albumentations as albu
-import numpy as np
-from PIL import Image
-import cv2
-from einops import rearrange
-from torch import einsum
-
-# Tokenizer: use the pretrained one provided on the Hub instead of loading it from the checkpoint
-TOKENIZER_PATH = "openai/clip-vit-large-patch14"
-
-# StableDiffusion model parameters
-NUM_TRAIN_TIMESTEPS = 1000
-BETA_START = 0.00085
-BETA_END = 0.0120
-
-UNET_PARAMS_MODEL_CHANNELS = 320
-UNET_PARAMS_CHANNEL_MULT = [1, 2, 4, 4]
-UNET_PARAMS_ATTENTION_RESOLUTIONS = [4, 2, 1]
-UNET_PARAMS_IMAGE_SIZE = 32 # unused
-UNET_PARAMS_IN_CHANNELS = 4
-UNET_PARAMS_OUT_CHANNELS = 4
-UNET_PARAMS_NUM_RES_BLOCKS = 2
-UNET_PARAMS_CONTEXT_DIM = 768
-UNET_PARAMS_NUM_HEADS = 8
-
-VAE_PARAMS_Z_CHANNELS = 4
-VAE_PARAMS_RESOLUTION = 256
-VAE_PARAMS_IN_CHANNELS = 3
-VAE_PARAMS_OUT_CH = 3
-VAE_PARAMS_CH = 128
-VAE_PARAMS_CH_MULT = [1, 2, 4, 4]
-VAE_PARAMS_NUM_RES_BLOCKS = 2
-
-# checkpoint file names
-LAST_CHECKPOINT_NAME = "last.ckpt"
-EPOCH_CHECKPOINT_NAME = "epoch-{:06d}.ckpt"
-
-
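-# Dataset that interleaves training and regularization images (even/odd indices) and can cache VAE latents per image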
-class DreamBoothOrFineTuningDataset(torch.utils.data.Dataset):
- def __init__(self, fine_tuning, train_img_path_captions, reg_img_path_captions, tokenizer, resolution, prior_loss_weight, flip_aug, color_aug, face_crop_aug_range, random_crop, shuffle_caption, disable_padding, debug_dataset) -> None:
- super().__init__()
-
- self.fine_tuning = fine_tuning
- self.train_img_path_captions = train_img_path_captions
- self.reg_img_path_captions = reg_img_path_captions
- self.tokenizer = tokenizer
- self.width, self.height = resolution
- self.size = min(self.width, self.height) # the shorter side
- self.prior_loss_weight = prior_loss_weight
- self.face_crop_aug_range = face_crop_aug_range
- self.random_crop = random_crop
- self.debug_dataset = debug_dataset
- self.shuffle_caption = shuffle_caption
- self.disable_padding = disable_padding
- self.latents_cache = None
-
- # augmentation
- flip_p = 0.5 if flip_aug else 0.0
- if color_aug:
- # fairly weak color augmentation: brightness/contrast would change the min/max pixel values, which is probably undesirable, so only gamma/hue/saturation are touched
- self.aug = albu.Compose([
- albu.OneOf([
- # albu.RandomBrightnessContrast(0.05, 0.05, p=.2),
- albu.HueSaturationValue(5, 8, 0, p=.2),
- # albu.RGBShift(5, 5, 5, p=.1),
- albu.RandomGamma((95, 105), p=.5),
- ], p=.33),
- albu.HorizontalFlip(p=flip_p)
- ], p=1.)
- elif flip_aug:
- self.aug = albu.Compose([
- albu.HorizontalFlip(p=flip_p)
- ], p=1.)
- else:
- self.aug = None
-
- self.num_train_images = len(self.train_img_path_captions)
- self.num_reg_images = len(self.reg_img_path_captions)
-
- self.enable_reg_images = self.num_reg_images > 0
-
- if not self.enable_reg_images:
- self._length = self.num_train_images
- else:
- # set the length to twice the number of train images; train and reg images are interleaved (see __getitem__)
- self._length = self.num_train_images * 2
- if self._length // 2 < self.num_reg_images:
- print("some of reg images are not used / 正則化画像の数が多いので、一部使用されない正則化画像があります")
-
- self.image_transforms = transforms.Compose(
- [
- transforms.ToTensor(),
- transforms.Normalize([0.5], [0.5]),
- ]
- )
-
- def load_image(self, image_path):
- image = Image.open(image_path)
- if not image.mode == "RGB":
- image = image.convert("RGB")
- img = np.array(image, np.uint8)
-
- face_cx = face_cy = face_w = face_h = 0
- if self.face_crop_aug_range is not None:
- tokens = os.path.splitext(os.path.basename(image_path))[0].split('_')
- if len(tokens) >= 5:
- face_cx = int(tokens[-4])
- face_cy = int(tokens[-3])
- face_w = int(tokens[-2])
- face_h = int(tokens[-1])
-
- return img, face_cx, face_cy, face_w, face_h
-
- # crop the image sensibly around the face
- def crop_target(self, image, face_cx, face_cy, face_w, face_h):
- height, width = image.shape[0:2]
- if height == self.height and width == self.width:
- return image
-
- # the image is larger than the target size, so resize it
- face_size = max(face_w, face_h)
- min_scale = max(self.height / height, self.width / width) # the scale at which the image exactly fits the model input size (the minimum scale)
- min_scale = min(1.0, max(min_scale, self.size / (face_size * self.face_crop_aug_range[1]))) # specified minimum face size
- max_scale = min(1.0, max(min_scale, self.size / (face_size * self.face_crop_aug_range[0]))) # specified maximum face size
- if min_scale >= max_scale: # the specified range has min == max
- scale = min_scale
- else:
- scale = random.uniform(min_scale, max_scale)
-
- nh = int(height * scale + .5)
- nw = int(width * scale + .5)
- assert nh >= self.height and nw >= self.width, f"internal error. small scale {scale}, {width}*{height}"
- image = cv2.resize(image, (nw, nh), interpolation=cv2.INTER_AREA)
- face_cx = int(face_cx * scale + .5)
- face_cy = int(face_cy * scale + .5)
- height, width = nh, nw
-
- # crop to the target size (e.g. 448*640) centered on the face
- for axis, (target_size, length, face_p) in enumerate(zip((self.height, self.width), (height, width), (face_cy, face_cx))):
- p1 = face_p - target_size // 2 # crop position that puts the face at the center
-
- if self.random_crop:
- # shift the crop while still keeping the face near the center most of the time, so that some background is included
- range = max(length - face_p, face_p) # the longer distance from the face center to an image edge
- p1 = p1 + (random.randint(0, range) + random.randint(0, range)) - range # a random offset in [-range, +range], biased toward 0
- else:
- # only when a range is specified, jitter slightly (fairly ad hoc)
- if self.face_crop_aug_range[0] != self.face_crop_aug_range[1]:
- if face_size > self.size // 10 and face_size >= 40:
- p1 = p1 + random.randint(-face_size // 20, +face_size // 20)
-
- p1 = max(0, min(p1, length - target_size))
-
- if axis == 0:
- image = image[p1:p1 + target_size, :]
- else:
- image = image[:, p1:p1 + target_size]
-
- return image
-
- def __len__(self):
- return self._length
-
- def set_cached_latents(self, image_path, latents):
- if self.latents_cache is None:
- self.latents_cache = {}
- self.latents_cache[image_path] = latents
-
- def __getitem__(self, index_arg):
- example = {}
-
- if not self.enable_reg_images:
- index = index_arg
- img_path_captions = self.train_img_path_captions
- reg = False
- else:
- # even indices return a train image, odd indices a reg image
- if index_arg % 2 == 0:
- img_path_captions = self.train_img_path_captions
- reg = False
- else:
- img_path_captions = self.reg_img_path_captions
- reg = True
- index = index_arg // 2
- example['loss_weight'] = 1.0 if (not reg or self.fine_tuning) else self.prior_loss_weight
-
- index = index % len(img_path_captions)
- image_path, caption = img_path_captions[index]
- example['image_path'] = image_path
-
- # process image / latents
- if self.latents_cache is not None and image_path in self.latents_cache:
- # latents already cached
- example['latents'] = self.latents_cache[image_path]
- else:
- # load the image and crop it if necessary
- img, face_cx, face_cy, face_w, face_h = self.load_image(image_path)
- im_h, im_w = img.shape[0:2]
- if face_cx > 0: # face position info is available
- img = self.crop_target(img, face_cx, face_cy, face_w, face_h)
- elif im_h > self.height or im_w > self.width:
- assert self.random_crop, f"image too large, and face_crop_aug_range and random_crop are disabled / 画像サイズが大きいのでface_crop_aug_rangeかrandom_cropを有効にしてください"
- if im_h > self.height:
- p = random.randint(0, im_h - self.height)
- img = img[p:p + self.height]
- if im_w > self.width:
- p = random.randint(0, im_w - self.width)
- img = img[:, p:p + self.width]
-
- im_h, im_w = img.shape[0:2]
- assert im_h == self.height and im_w == self.width, f"image too small / 画像サイズが小さいようです: {image_path}"
-
- # augmentation
- if self.aug is not None:
- img = self.aug(image=img)['image']
-
- example['image'] = self.image_transforms(img) # becomes a torch.Tensor in [-1.0, 1.0]
-
- # process the caption
- if self.fine_tuning and self.shuffle_caption: # shuffle the caption elements during fine tuning
- tokens = caption.strip().split(",")
- random.shuffle(tokens)
- caption = ",".join(tokens).strip()
-
- input_ids = self.tokenizer(caption, padding="do_not_pad", truncation=True,
- max_length=self.tokenizer.model_max_length).input_ids
-
- # pad and convert to Tensor
- if self.disable_padding:
- # do not pad: padding=True only pads to the longest sequence in the batch (which still seems like a bug...?)
- input_ids = self.tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids
- else:
- # pad to the model max length
- input_ids = self.tokenizer.pad({"input_ids": input_ids}, padding='max_length', max_length=self.tokenizer.model_max_length,
- return_tensors='pt').input_ids
-
- example['input_ids'] = input_ids
-
- if self.debug_dataset:
- example['caption'] = caption
- return example
-
-
-# region checkpoint conversion, loading and saving ###############################
-
-# region StableDiffusion -> Diffusers conversion code
-# copied from convert_original_stable_diffusion_to_diffusers (ASL 2.0)
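-# The functions below rename the LDM state-dict keys into the Diffusers layout and
-# build the matching UNet/VAE configs, so the checkpoint can be loaded without the
-# original yaml config.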
-
-def shave_segments(path, n_shave_prefix_segments=1):
- """
- Removes segments. Positive values shave the first segments, negative shave the last segments.
- """
- if n_shave_prefix_segments >= 0:
- return ".".join(path.split(".")[n_shave_prefix_segments:])
- else:
- return ".".join(path.split(".")[:n_shave_prefix_segments])
-
-
-def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
- """
- Updates paths inside resnets to the new naming scheme (local renaming)
- """
- mapping = []
- for old_item in old_list:
- new_item = old_item.replace("in_layers.0", "norm1")
- new_item = new_item.replace("in_layers.2", "conv1")
-
- new_item = new_item.replace("out_layers.0", "norm2")
- new_item = new_item.replace("out_layers.3", "conv2")
-
- new_item = new_item.replace("emb_layers.1", "time_emb_proj")
- new_item = new_item.replace("skip_connection", "conv_shortcut")
-
- new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
-
- mapping.append({"old": old_item, "new": new_item})
-
- return mapping
-
-
-def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
- """
- Updates paths inside resnets to the new naming scheme (local renaming)
- """
- mapping = []
- for old_item in old_list:
- new_item = old_item
-
- new_item = new_item.replace("nin_shortcut", "conv_shortcut")
- new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
-
- mapping.append({"old": old_item, "new": new_item})
-
- return mapping
-
-
-def renew_attention_paths(old_list, n_shave_prefix_segments=0):
- """
- Updates paths inside attentions to the new naming scheme (local renaming)
- """
- mapping = []
- for old_item in old_list:
- new_item = old_item
-
- # new_item = new_item.replace('norm.weight', 'group_norm.weight')
- # new_item = new_item.replace('norm.bias', 'group_norm.bias')
-
- # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
- # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
-
- # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
-
- mapping.append({"old": old_item, "new": new_item})
-
- return mapping
-
-
-def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
- """
- Updates paths inside attentions to the new naming scheme (local renaming)
- """
- mapping = []
- for old_item in old_list:
- new_item = old_item
-
- new_item = new_item.replace("norm.weight", "group_norm.weight")
- new_item = new_item.replace("norm.bias", "group_norm.bias")
-
- new_item = new_item.replace("q.weight", "query.weight")
- new_item = new_item.replace("q.bias", "query.bias")
-
- new_item = new_item.replace("k.weight", "key.weight")
- new_item = new_item.replace("k.bias", "key.bias")
-
- new_item = new_item.replace("v.weight", "value.weight")
- new_item = new_item.replace("v.bias", "value.bias")
-
- new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
- new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
-
- new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
-
- mapping.append({"old": old_item, "new": new_item})
-
- return mapping
-
-
-def assign_to_checkpoint(
- paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
-):
- """
- This does the final conversion step: take locally converted weights and apply a global renaming
- to them. It splits attention layers, and takes into account additional replacements
- that may arise.
-
- Assigns the weights to the new checkpoint.
- """
- assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
-
- # Splits the attention layers into three variables.
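- # The original checkpoint stores attention qkv as one fused tensor; it is reshaped
- # per head and split into equal thirds to obtain separate query/key/value weights.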
- if attention_paths_to_split is not None:
- for path, path_map in attention_paths_to_split.items():
- old_tensor = old_checkpoint[path]
- channels = old_tensor.shape[0] // 3
-
- target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
-
- num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
-
- old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
- query, key, value = old_tensor.split(channels // num_heads, dim=1)
-
- checkpoint[path_map["query"]] = query.reshape(target_shape)
- checkpoint[path_map["key"]] = key.reshape(target_shape)
- checkpoint[path_map["value"]] = value.reshape(target_shape)
-
- for path in paths:
- new_path = path["new"]
-
- # These have already been assigned
- if attention_paths_to_split is not None and new_path in attention_paths_to_split:
- continue
-
- # Global renaming happens here
- new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
- new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
- new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")
-
- if additional_replacements is not None:
- for replacement in additional_replacements:
- new_path = new_path.replace(replacement["old"], replacement["new"])
-
- # proj_attn.weight has to be converted from conv 1D to linear
- if "proj_attn.weight" in new_path:
- checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
- else:
- checkpoint[new_path] = old_checkpoint[path["old"]]
-
-
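-# Attention projection weights are stored as 1x1 convolutions in the original model;
-# dropping the trailing spatial dimensions turns them into the linear weights Diffusers expects.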
-def conv_attn_to_linear(checkpoint):
- keys = list(checkpoint.keys())
- attn_keys = ["query.weight", "key.weight", "value.weight"]
- for key in keys:
- if ".".join(key.split(".")[-2:]) in attn_keys:
- if checkpoint[key].ndim > 2:
- checkpoint[key] = checkpoint[key][:, :, 0, 0]
- elif "proj_attn.weight" in key:
- if checkpoint[key].ndim > 2:
- checkpoint[key] = checkpoint[key][:, :, 0]
-
-
-def convert_ldm_unet_checkpoint(checkpoint, config):
- """
- Takes a state dict and a config, and returns a converted checkpoint.
- """
-
- # extract state_dict for UNet
- unet_state_dict = {}
- unet_key = "model.diffusion_model."
- keys = list(checkpoint.keys())
- for key in keys:
- if key.startswith(unet_key):
- unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
-
- new_checkpoint = {}
-
- new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
- new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
- new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
- new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
-
- new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
- new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
-
- new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
- new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
- new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
- new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
-
- # Retrieves the keys for the input blocks only
- num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
- input_blocks = {
- layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
- for layer_id in range(num_input_blocks)
- }
-
- # Retrieves the keys for the middle blocks only
- num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
- middle_blocks = {
- layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
- for layer_id in range(num_middle_blocks)
- }
-
- # Retrieves the keys for the output blocks only
- num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
- output_blocks = {
- layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
- for layer_id in range(num_output_blocks)
- }
-
- for i in range(1, num_input_blocks):
- block_id = (i - 1) // (config["layers_per_block"] + 1)
- layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
-
- resnets = [
- key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
- ]
- attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
-
- if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
- new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
- f"input_blocks.{i}.0.op.weight"
- )
- new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
- f"input_blocks.{i}.0.op.bias"
- )
-
- paths = renew_resnet_paths(resnets)
- meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
- assign_to_checkpoint(
- paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
- )
-
- if len(attentions):
- paths = renew_attention_paths(attentions)
- meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
- assign_to_checkpoint(
- paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
- )
-
- resnet_0 = middle_blocks[0]
- attentions = middle_blocks[1]
- resnet_1 = middle_blocks[2]
-
- resnet_0_paths = renew_resnet_paths(resnet_0)
- assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)
-
- resnet_1_paths = renew_resnet_paths(resnet_1)
- assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)
-
- attentions_paths = renew_attention_paths(attentions)
- meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
- assign_to_checkpoint(
- attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
- )
-
- for i in range(num_output_blocks):
- block_id = i // (config["layers_per_block"] + 1)
- layer_in_block_id = i % (config["layers_per_block"] + 1)
- output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
- output_block_list = {}
-
- for layer in output_block_layers:
- layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
- if layer_id in output_block_list:
- output_block_list[layer_id].append(layer_name)
- else:
- output_block_list[layer_id] = [layer_name]
-
- if len(output_block_list) > 1:
- resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
- attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
-
- resnet_0_paths = renew_resnet_paths(resnets)
- paths = renew_resnet_paths(resnets)
-
- meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
- assign_to_checkpoint(
- paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
- )
-
- if ["conv.weight", "conv.bias"] in output_block_list.values():
- index = list(output_block_list.values()).index(["conv.weight", "conv.bias"])
- new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
- f"output_blocks.{i}.{index}.conv.weight"
- ]
- new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
- f"output_blocks.{i}.{index}.conv.bias"
- ]
-
- # Clear attentions as they have been attributed above.
- if len(attentions) == 2:
- attentions = []
-
- if len(attentions):
- paths = renew_attention_paths(attentions)
- meta_path = {
- "old": f"output_blocks.{i}.1",
- "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
- }
- assign_to_checkpoint(
- paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
- )
- else:
- resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
- for path in resnet_0_paths:
- old_path = ".".join(["output_blocks", str(i), path["old"]])
- new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
-
- new_checkpoint[new_path] = unet_state_dict[old_path]
-
- return new_checkpoint
-
-
-def convert_ldm_vae_checkpoint(checkpoint, config):
- # extract state dict for VAE
- vae_state_dict = {}
- vae_key = "first_stage_model."
- keys = list(checkpoint.keys())
- for key in keys:
- if key.startswith(vae_key):
- vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
-
- new_checkpoint = {}
-
- new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
- new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
- new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
- new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
- new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
- new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]
-
- new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
- new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
- new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
- new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
- new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
- new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]
-
- new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
- new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
- new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
- new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
-
- # Retrieves the keys for the encoder down blocks only
- num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
- down_blocks = {
- layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
- }
-
- # Retrieves the keys for the decoder up blocks only
- num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
- up_blocks = {
- layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
- }
-
- for i in range(num_down_blocks):
- resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
-
- if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
- new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
- f"encoder.down.{i}.downsample.conv.weight"
- )
- new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
- f"encoder.down.{i}.downsample.conv.bias"
- )
-
- paths = renew_vae_resnet_paths(resnets)
- meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
- assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
-
- mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
- num_mid_res_blocks = 2
- for i in range(1, num_mid_res_blocks + 1):
- resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
-
- paths = renew_vae_resnet_paths(resnets)
- meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
- assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
-
- mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
- paths = renew_vae_attention_paths(mid_attentions)
- meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
- assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
- conv_attn_to_linear(new_checkpoint)
-
- for i in range(num_up_blocks):
- block_id = num_up_blocks - 1 - i
- resnets = [
- key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
- ]
-
- if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
- new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
- f"decoder.up.{block_id}.upsample.conv.weight"
- ]
- new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
- f"decoder.up.{block_id}.upsample.conv.bias"
- ]
-
- paths = renew_vae_resnet_paths(resnets)
- meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
- assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
-
- mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
- num_mid_res_blocks = 2
- for i in range(1, num_mid_res_blocks + 1):
- resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
-
- paths = renew_vae_resnet_paths(resnets)
- meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
- assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
-
- mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
- paths = renew_vae_attention_paths(mid_attentions)
- meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
- assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
- conv_attn_to_linear(new_checkpoint)
- return new_checkpoint
-
-
-def create_unet_diffusers_config():
- """
- Creates a config for the diffusers based on the config of the LDM model.
- """
- # unet_params = original_config.model.params.unet_config.params
-
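- # Cross-attention blocks are used only at the downsampling factors listed in
- # UNET_PARAMS_ATTENTION_RESOLUTIONS; the up path mirrors this while walking the resolutions back down.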
- block_out_channels = [UNET_PARAMS_MODEL_CHANNELS * mult for mult in UNET_PARAMS_CHANNEL_MULT]
-
- down_block_types = []
- resolution = 1
- for i in range(len(block_out_channels)):
- block_type = "CrossAttnDownBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "DownBlock2D"
- down_block_types.append(block_type)
- if i != len(block_out_channels) - 1:
- resolution *= 2
-
- up_block_types = []
- for i in range(len(block_out_channels)):
- block_type = "CrossAttnUpBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "UpBlock2D"
- up_block_types.append(block_type)
- resolution //= 2
-
- config = dict(
- sample_size=UNET_PARAMS_IMAGE_SIZE,
- in_channels=UNET_PARAMS_IN_CHANNELS,
- out_channels=UNET_PARAMS_OUT_CHANNELS,
- down_block_types=tuple(down_block_types),
- up_block_types=tuple(up_block_types),
- block_out_channels=tuple(block_out_channels),
- layers_per_block=UNET_PARAMS_NUM_RES_BLOCKS,
- cross_attention_dim=UNET_PARAMS_CONTEXT_DIM,
- attention_head_dim=UNET_PARAMS_NUM_HEADS,
- )
-
- return config
-
-
-def create_vae_diffusers_config():
- """
- Creates a config for the diffusers based on the config of the LDM model.
- """
- # vae_params = original_config.model.params.first_stage_config.params.ddconfig
- # _ = original_config.model.params.first_stage_config.params.embed_dim
- block_out_channels = [VAE_PARAMS_CH * mult for mult in VAE_PARAMS_CH_MULT]
- down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
- up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
-
- config = dict(
- sample_size=VAE_PARAMS_RESOLUTION,
- in_channels=VAE_PARAMS_IN_CHANNELS,
- out_channels=VAE_PARAMS_OUT_CH,
- down_block_types=tuple(down_block_types),
- up_block_types=tuple(up_block_types),
- block_out_channels=tuple(block_out_channels),
- latent_channels=VAE_PARAMS_Z_CHANNELS,
- layers_per_block=VAE_PARAMS_NUM_RES_BLOCKS,
- )
- return config
-
-
-def convert_ldm_clip_checkpoint(checkpoint):
- text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
-
- keys = list(checkpoint.keys())
-
- text_model_dict = {}
-
- for key in keys:
- if key.startswith("cond_stage_model.transformer"):
- text_model_dict[key[len("cond_stage_model.transformer."):]] = checkpoint[key]
-
- text_model.load_state_dict(text_model_dict)
-
- return text_model
-
-# endregion
-
-
-# region Diffusers -> StableDiffusion conversion code
-# copied from convert_diffusers_to_original_stable_diffusion (ASL 2.0)
-
-def convert_unet_state_dict(unet_state_dict):
- unet_conversion_map = [
- # (stable-diffusion, HF Diffusers)
- ("time_embed.0.weight", "time_embedding.linear_1.weight"),
- ("time_embed.0.bias", "time_embedding.linear_1.bias"),
- ("time_embed.2.weight", "time_embedding.linear_2.weight"),
- ("time_embed.2.bias", "time_embedding.linear_2.bias"),
- ("input_blocks.0.0.weight", "conv_in.weight"),
- ("input_blocks.0.0.bias", "conv_in.bias"),
- ("out.0.weight", "conv_norm_out.weight"),
- ("out.0.bias", "conv_norm_out.bias"),
- ("out.2.weight", "conv_out.weight"),
- ("out.2.bias", "conv_out.bias"),
- ]
-
- unet_conversion_map_resnet = [
- # (stable-diffusion, HF Diffusers)
- ("in_layers.0", "norm1"),
- ("in_layers.2", "conv1"),
- ("out_layers.0", "norm2"),
- ("out_layers.3", "conv2"),
- ("emb_layers.1", "time_emb_proj"),
- ("skip_connection", "conv_shortcut"),
- ]
-
- unet_conversion_map_layer = []
- for i in range(4):
- # loop over downblocks/upblocks
-
- for j in range(2):
- # loop over resnets/attentions for downblocks
- hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
- sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
- unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
-
- if i < 3:
- # no attention layers in down_blocks.3
- hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
- sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
- unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
-
- for j in range(3):
- # loop over resnets/attentions for upblocks
- hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
- sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
- unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
-
- if i > 0:
- # no attention layers in up_blocks.0
- hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
- sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
- unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
-
- if i < 3:
- # no downsample in down_blocks.3
- hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
- sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
- unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
-
- # no upsample in up_blocks.3
- hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
- sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
- unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
-
- hf_mid_atn_prefix = "mid_block.attentions.0."
- sd_mid_atn_prefix = "middle_block.1."
- unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
-
- for j in range(2):
- hf_mid_res_prefix = f"mid_block.resnets.{j}."
- sd_mid_res_prefix = f"middle_block.{2*j}."
- unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
-
- # buyer beware: this is a *brittle* function,
- # and correct output requires that all of these pieces interact in
- # the exact order in which I have arranged them.
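- # Three passes follow: direct key renames, resnet-internal renames, then block-prefix renames.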
- mapping = {k: k for k in unet_state_dict.keys()}
- for sd_name, hf_name in unet_conversion_map:
- mapping[hf_name] = sd_name
- for k, v in mapping.items():
- if "resnets" in k:
- for sd_part, hf_part in unet_conversion_map_resnet:
- v = v.replace(hf_part, sd_part)
- mapping[k] = v
- for k, v in mapping.items():
- for sd_part, hf_part in unet_conversion_map_layer:
- v = v.replace(hf_part, sd_part)
- mapping[k] = v
- new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
- return new_state_dict
-
-# endregion
-
-
-def load_checkpoint_with_conversion(ckpt_path):
- # handle models whose text encoder is stored in a different format (without 'text_model')
- TEXT_ENCODER_KEY_REPLACEMENTS = [
- ('cond_stage_model.transformer.embeddings.', 'cond_stage_model.transformer.text_model.embeddings.'),
- ('cond_stage_model.transformer.encoder.', 'cond_stage_model.transformer.text_model.encoder.'),
- ('cond_stage_model.transformer.final_layer_norm.', 'cond_stage_model.transformer.text_model.final_layer_norm.')
- ]
-
- checkpoint = torch.load(ckpt_path, map_location="cpu")
- state_dict = checkpoint["state_dict"]
-
- key_reps = []
- for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS:
- for key in state_dict.keys():
- if key.startswith(rep_from):
- new_key = rep_to + key[len(rep_from):]
- key_reps.append((key, new_key))
-
- for key, new_key in key_reps:
- state_dict[new_key] = state_dict[key]
- del state_dict[key]
-
- return checkpoint
-
-
-def load_models_from_stable_diffusion_checkpoint(ckpt_path):
- checkpoint = load_checkpoint_with_conversion(ckpt_path)
- state_dict = checkpoint["state_dict"]
-
- # Convert the UNet2DConditionModel model.
- unet_config = create_unet_diffusers_config()
- converted_unet_checkpoint = convert_ldm_unet_checkpoint(state_dict, unet_config)
-
- unet = UNet2DConditionModel(**unet_config)
- unet.load_state_dict(converted_unet_checkpoint)
-
- # Convert the VAE model.
- vae_config = create_vae_diffusers_config()
- converted_vae_checkpoint = convert_ldm_vae_checkpoint(state_dict, vae_config)
-
- vae = AutoencoderKL(**vae_config)
- vae.load_state_dict(converted_vae_checkpoint)
-
- # convert text_model
- text_model = convert_ldm_clip_checkpoint(state_dict)
-
- return text_model, vae, unet
-
-
-def save_stable_diffusion_checkpoint(output_file, text_encoder, unet, ckpt_path, epochs, steps):
- # the VAE is not in memory, so load the checkpoint again, including the VAE
- checkpoint = load_checkpoint_with_conversion(ckpt_path)
- state_dict = checkpoint["state_dict"]
-
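- # Only the UNet and text encoder weights are replaced below; the VAE and all other
- # keys are carried over unchanged from the original checkpoint.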
- # Convert the UNet model
- unet_state_dict = convert_unet_state_dict(unet.state_dict())
- for k, v in unet_state_dict.items():
- key = "model.diffusion_model." + k
- assert key in state_dict, f"Illegal key in save SD: {key}"
- state_dict[key] = v
-
- # Convert the text encoder model
- text_enc_dict = text_encoder.state_dict() # no conversion needed
- for k, v in text_enc_dict.items():
- key = "cond_stage_model.transformer." + k
- assert key in state_dict, f"Illegal key in save SD: {key}"
- state_dict[key] = v
-
- # Put together new checkpoint
- new_ckpt = {'state_dict': state_dict}
-
- if 'epoch' in checkpoint:
- epochs += checkpoint['epoch']
- if 'global_step' in checkpoint:
- steps += checkpoint['global_step']
-
- new_ckpt['epoch'] = epochs
- new_ckpt['global_step'] = steps
-
- torch.save(new_ckpt, output_file)
-# endregion
-
-
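-# Collates either cached latents or raw pixel values, together with token ids and
-# per-sample loss weights, into a single training batch.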
-def collate_fn(examples):
- input_ids = [e['input_ids'] for e in examples]
- input_ids = torch.stack(input_ids)
-
- if 'latents' in examples[0]:
- pixel_values = None
- latents = [e['latents'] for e in examples]
- latents = torch.stack(latents)
- else:
- pixel_values = [e['image'] for e in examples]
- pixel_values = torch.stack(pixel_values)
- pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
- latents = None
-
- loss_weights = [e['loss_weight'] for e in examples]
- loss_weights = torch.FloatTensor(loss_weights)
-
- batch = {"input_ids": input_ids, "pixel_values": pixel_values, "latents": latents, "loss_weights": loss_weights}
- return batch
-
-
-def train(args):
- fine_tuning = args.fine_tuning
- cache_latents = args.cache_latents
-
- # check option settings for caching latents
- if cache_latents:
- # assert args.face_crop_aug_range is None and not args.random_crop, "when caching latents, crop aug cannot be used / latentをキャッシュするときは切り出しは使えません"
- # -> keep it usable (it becomes a crop of the initial image)
- assert not args.flip_aug and not args.color_aug, "when caching latents, augmentation cannot be used / latentをキャッシュするときはaugmentationは使えません"
-
- # check model format option settings
- use_stable_diffusion_format = os.path.isfile(args.pretrained_model_name_or_path)
- if not use_stable_diffusion_format:
- assert os.path.exists(
- args.pretrained_model_name_or_path), f"no pretrained model / 学習元モデルがありません : {args.pretrained_model_name_or_path}"
-
- assert args.save_every_n_epochs is None or use_stable_diffusion_format, "when loading Diffusers model, save_every_n_epochs does not work / Diffusersのモデルを読み込むときにはsave_every_n_epochsオプションは無効になります"
-
- if args.seed is not None:
- set_seed(args.seed)
-
- # prepare training data
- def load_dreambooth_dir(dir):
- tokens = os.path.basename(dir).split('_')
- try:
- n_repeats = int(tokens[0])
- except ValueError as e:
- print(f"no 'n_repeats' in directory name / DreamBoothのディレクトリ名に繰り返し回数がないようです: {dir}")
- raise e
-
- caption = '_'.join(tokens[1:])
-
- img_paths = glob.glob(os.path.join(dir, "*.png")) + glob.glob(os.path.join(dir, "*.jpg"))
- return n_repeats, [(ip, caption) for ip in img_paths]
-
- print("prepare train images.")
- train_img_path_captions = []
-
- if fine_tuning:
- img_paths = glob.glob(os.path.join(args.train_data_dir, "*.png")) + glob.glob(os.path.join(args.train_data_dir, "*.jpg"))
- for img_path in tqdm(img_paths):
- # build candidate caption file names
- base_name = os.path.splitext(img_path)[0]
- base_name_face_det = base_name
- tokens = base_name.split("_")
- if len(tokens) >= 5:
- base_name_face_det = "_".join(tokens[:-4])
- cap_paths = [base_name + '.txt', base_name + '.caption', base_name_face_det+'.txt', base_name_face_det+'.caption']
-
- caption = None
- for cap_path in cap_paths:
- if os.path.isfile(cap_path):
- with open(cap_path, "rt", encoding='utf-8') as f:
- caption = f.readlines()[0].strip()
- break
-
- assert caption is not None and len(caption) > 0, f"no caption / キャプションファイルが見つからないか、captionが空です: {cap_paths}"
-
- train_img_path_captions.append((img_path, caption))
-
- if args.dataset_repeats is not None:
- l = []
- for _ in range(args.dataset_repeats):
- l.extend(train_img_path_captions)
- train_img_path_captions = l
- else:
- train_dirs = os.listdir(args.train_data_dir)
- for dir in train_dirs:
- n_repeats, img_caps = load_dreambooth_dir(os.path.join(args.train_data_dir, dir))
- for _ in range(n_repeats):
- train_img_path_captions.extend(img_caps)
- print(f"{len(train_img_path_captions)} train images.")
-
- reg_img_path_captions = []
- if args.reg_data_dir:
- print("prepare reg images.")
- reg_dirs = os.listdir(args.reg_data_dir)
- for dir in reg_dirs:
- n_repeats, img_caps = load_dreambooth_dir(os.path.join(args.reg_data_dir, dir))
- for _ in range(n_repeats):
- reg_img_path_captions.extend(img_caps)
- print(f"{len(reg_img_path_captions)} reg images.")
-
- if args.debug_dataset:
- # shuffle when debugging so it resembles actual dataset usage (during training the data loader shuffles)
- random.shuffle(train_img_path_captions)
- random.shuffle(reg_img_path_captions)
-
- # prepare the dataset
- resolution = tuple([int(r) for r in args.resolution.split(',')])
- if len(resolution) == 1:
- resolution = (resolution[0], resolution[0])
- assert len(
- resolution) == 2, f"resolution must be 'size' or 'width,height' / resolutionは'サイズ'または'幅','高さ'で指定してください: {args.resolution}"
-
- if args.face_crop_aug_range is not None:
- face_crop_aug_range = tuple([float(r) for r in args.face_crop_aug_range.split(',')])
- assert len(
- face_crop_aug_range) == 2, f"face_crop_aug_range must be two floats / face_crop_aug_rangeは'下限,上限'で指定してください: {args.face_crop_aug_range}"
- else:
- face_crop_aug_range = None
-
- # load the tokenizer
- print("prepare tokenizer")
- tokenizer = CLIPTokenizer.from_pretrained(TOKENIZER_PATH)
-
- print("prepare dataset")
- train_dataset = DreamBoothOrFineTuningDataset(fine_tuning, train_img_path_captions,
- reg_img_path_captions, tokenizer, resolution, args.prior_loss_weight, args.flip_aug, args.color_aug, face_crop_aug_range, args.random_crop, args.shuffle_caption, args.no_token_padding, args.debug_dataset)
-
- if args.debug_dataset:
- print(f"Total dataset length / データセットの長さ: {len(train_dataset)}")
- print("Escape for exit. / Escキーで中断、終了します")
- for example in train_dataset:
- im = example['image']
- im = ((im.numpy() + 1.0) * 127.5).astype(np.uint8)
- im = np.transpose(im, (1, 2, 0)) # c,H,W -> H,W,c
- im = im[:, :, ::-1] # RGB -> BGR (OpenCV)
- print(f'caption: "{example["caption"]}", loss weight: {example["loss_weight"]}')
- cv2.imshow("img", im)
- k = cv2.waitKey()
- cv2.destroyAllWindows()
- if k == 27:
- break
- return
-
- # prepare the accelerator
- # gradient accumulation reportedly does not support training multiple models, so fix it to 1
- print("prepare accelerator")
- accelerator = Accelerator(gradient_accumulation_steps=1, mixed_precision=args.mixed_precision)
-
- # load the models
- if use_stable_diffusion_format:
- print("load StableDiffusion checkpoint")
- text_encoder, vae, unet = load_models_from_stable_diffusion_checkpoint(args.pretrained_model_name_or_path)
- else:
- print("load Diffusers pretrained models")
- text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
- vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
- unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
-
- # build xformers or memory efficient attention into the model
- replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
-
- # prepare a dtype matching the mixed precision setting and cast as needed
- weight_dtype = torch.float32
- if args.mixed_precision == "fp16":
- weight_dtype = torch.float16
- elif args.mixed_precision == "bf16":
- weight_dtype = torch.bfloat16
-
- # prepare for training
- if cache_latents:
- # cache latents: creating a new Dataset would break caption shuffling, so keep the cache in the original Dataset (cascading would be another option)
- print("caching latents.")
- vae.to(accelerator.device, dtype=weight_dtype)
-
- for i in tqdm(range(len(train_dataset))):
- example = train_dataset[i]
- if 'latents' not in example:
- image_path = example['image_path']
- with torch.no_grad():
- pixel_values = example["image"].unsqueeze(0).to(device=accelerator.device, dtype=weight_dtype)
- latents = vae.encode(pixel_values).latent_dist.sample().squeeze(0).to("cpu")
- train_dataset.set_cached_latents(image_path, latents)
- # assertion
- for i in range(len(train_dataset)):
- assert 'latents' in train_dataset[i], "internal error: latents not cached"
-
- del vae
- if torch.cuda.is_available():
- torch.cuda.empty_cache()
- else:
- vae.requires_grad_(False)
-
- if args.gradient_checkpointing:
- unet.enable_gradient_checkpointing()
- text_encoder.gradient_checkpointing_enable()
-
- # prepare the classes needed for training
- print("prepare optimizer, data loader etc.")
-
- # use 8-bit Adam
- if args.use_8bit_adam:
- try:
- import bitsandbytes as bnb
- except ImportError:
- raise ImportError("No bitsandbytes / bitsandbytesがインストールされていないようです")
- print("use 8-bit Adam optimizer")
- optimizer_class = bnb.optim.AdamW8bit
- else:
- optimizer_class = torch.optim.AdamW
-
- trainable_params = (itertools.chain(unet.parameters(), text_encoder.parameters()))
-
- # beta and weight decay seem to use the default values in both diffusers DreamBooth and DreamBooth SD, so those options are omitted for now
- optimizer = optimizer_class(trainable_params, lr=args.learning_rate)
-
- # prepare the dataloader
- # number of DataLoader processes: 0 uses the main process
- n_workers = min(8, os.cpu_count() - 1) # cpu_count-1, but at most 8
- train_dataloader = torch.utils.data.DataLoader(
- train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn, num_workers=n_workers)
-
- # prepare the lr scheduler
- lr_scheduler = diffusers.optimization.get_scheduler("constant", optimizer, num_training_steps=args.max_train_steps)
-
- # the accelerator apparently takes care of the rest
- unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
- unet, text_encoder, optimizer, train_dataloader, lr_scheduler)
-
- if not cache_latents:
- vae.to(accelerator.device, dtype=weight_dtype)
-
- # calculate the number of epochs
- num_train_epochs = math.ceil(args.max_train_steps / len(train_dataloader))
-
- # train
- total_batch_size = args.train_batch_size # * accelerator.num_processes
- print("running training / 学習開始")
- print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset.num_train_images}")
- print(f" num reg images / 正則化画像の数: {train_dataset.num_reg_images}")
- print(f" num examples / サンプル数: {len(train_dataset)}")
- print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
- print(f" num epochs / epoch数: {num_train_epochs}")
- print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
- print(f" total train batch size (with parallel & distributed) / 総バッチサイズ(並列学習含む): {total_batch_size}")
- print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
-
- progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process, desc="steps")
- global_step = 0
-
- noise_scheduler = DDPMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)
-
- if accelerator.is_main_process:
- accelerator.init_trackers("dreambooth")
-
- # the following is mostly copy-pasted from train_dreambooth.py
- for epoch in range(num_train_epochs):
- print(f"epoch {epoch+1}/{num_train_epochs}")
- unet.train()
- text_encoder.train() # apparently only unet needs this? -> fixed in the latest upstream version
-
- loss_total = 0
- for step, batch in enumerate(train_dataloader):
- with accelerator.accumulate(unet):
- with torch.no_grad():
- # convert to latents
- if cache_latents:
- latents = batch["latents"].to(accelerator.device)
- else:
- latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
- latents = latents * 0.18215
-
- # Sample noise that we'll add to the latents
- noise = torch.randn_like(latents, device=latents.device)
- b_size = latents.shape[0]
-
- # Sample a random timestep for each image
- timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device)
- timesteps = timesteps.long()
-
- # Add noise to the latents according to the noise magnitude at each timestep
- # (this is the forward diffusion process)
- noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
-
- # Get the text embedding for conditioning
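- # clip_skip=n uses the hidden state n layers from the end of the text encoder,
- # re-normalized with the final layer norm, instead of the last layer output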
- if args.clip_skip is None:
- encoder_hidden_states = text_encoder(batch["input_ids"])[0]
- else:
- enc_out = text_encoder(batch["input_ids"], output_hidden_states=True, return_dict=True)
- encoder_hidden_states = enc_out['hidden_states'][-args.clip_skip]
- encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states)
-
- # Predict the noise residual
- noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
-
- loss = torch.nn.functional.mse_loss(noise_pred.float(), noise.float(), reduction="none")
- loss = loss.mean([1, 2, 3])
-
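- # prior preservation: regularization samples are weighted by prior_loss_weight,
- # training samples by 1.0 (see the dataset's loss_weight field)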
- loss_weights = batch["loss_weights"] # per-sample weights
- loss = loss * loss_weights
-
- loss = loss.mean()
-
- accelerator.backward(loss)
- if accelerator.sync_gradients:
- params_to_clip = (itertools.chain(unet.parameters(), text_encoder.parameters()))
- accelerator.clip_grad_norm_(params_to_clip, 1.0) # args.max_grad_norm)
-
- optimizer.step()
- lr_scheduler.step()
- optimizer.zero_grad(set_to_none=True)
-
- # Checks if the accelerator has performed an optimization step behind the scenes
- if accelerator.sync_gradients:
- progress_bar.update(1)
- global_step += 1
-
- current_loss = loss.detach().item()
- loss_total += current_loss
- avr_loss = loss_total / (step+1)
- logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
- progress_bar.set_postfix(**logs)
- # accelerator.log(logs, step=global_step)
-
- if global_step >= args.max_train_steps:
- break
-
- accelerator.wait_for_everyone()
-
- if use_stable_diffusion_format and args.save_every_n_epochs is not None:
- if (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs:
- print("saving checkpoint.")
- os.makedirs(args.output_dir, exist_ok=True)
- ckpt_file = os.path.join(args.output_dir, EPOCH_CHECKPOINT_NAME.format(epoch + 1))
- save_stable_diffusion_checkpoint(ckpt_file, accelerator.unwrap_model(text_encoder), accelerator.unwrap_model(unet),
- args.pretrained_model_name_or_path, epoch + 1, global_step)
-
- is_main_process = accelerator.is_main_process
- if is_main_process:
- unet = accelerator.unwrap_model(unet)
- text_encoder = accelerator.unwrap_model(text_encoder)
-
- accelerator.end_training()
- del accelerator # free this since memory is needed afterwards
-
- if is_main_process:
- os.makedirs(args.output_dir, exist_ok=True)
- if use_stable_diffusion_format:
- ckpt_file = os.path.join(args.output_dir, LAST_CHECKPOINT_NAME)
- print(f"save trained model as StableDiffusion checkpoint to {ckpt_file}")
- save_stable_diffusion_checkpoint(ckpt_file, text_encoder, unet, args.pretrained_model_name_or_path, epoch, global_step)
- else:
- # Create the pipeline using the trained modules and save it.
- print(f"save trained model as Diffusers to {args.output_dir}")
- pipeline = StableDiffusionPipeline.from_pretrained(
- args.pretrained_model_name_or_path,
- unet=unet,
- text_encoder=text_encoder,
- )
- pipeline.save_pretrained(args.output_dir)
- print("model saved.")
-
-
-# region module replacement
-"""
-Module replacement for speedup
-"""
-
-# CrossAttention that uses FlashAttention
-# based on https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/memory_efficient_attention_pytorch/flash_attention.py
-# LICENSE MIT https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/LICENSE
-
-# constants
-
-EPSILON = 1e-6
-
-# helper functions
-
-
-def exists(val):
- return val is not None
-
-
-def default(val, d):
- return val if exists(val) else d
-
-# flash attention forwards and backwards
-
-# https://arxiv.org/abs/2205.14135
-
-
-class FlashAttentionFunction(Function):
- @staticmethod
- @torch.no_grad()
- def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
- """ Algorithm 2 in the paper """
-
- device = q.device
- dtype = q.dtype
- max_neg_value = -torch.finfo(q.dtype).max
- qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
-
- o = torch.zeros_like(q)
- all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device)
- all_row_maxes = torch.full((*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device)
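- # running per-row max and sum statistics let the softmax be computed block by block
- # (online softmax) without materializing the full attention matrix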
-
- scale = (q.shape[-1] ** -0.5)
-
- if not exists(mask):
- mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size)
- else:
- mask = rearrange(mask, 'b n -> b 1 1 n')
- mask = mask.split(q_bucket_size, dim=-1)
-
- row_splits = zip(
- q.split(q_bucket_size, dim=-2),
- o.split(q_bucket_size, dim=-2),
- mask,
- all_row_sums.split(q_bucket_size, dim=-2),
- all_row_maxes.split(q_bucket_size, dim=-2),
- )
-
- for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits):
- q_start_index = ind * q_bucket_size - qk_len_diff
-
- col_splits = zip(
- k.split(k_bucket_size, dim=-2),
- v.split(k_bucket_size, dim=-2),
- )
-
- for k_ind, (kc, vc) in enumerate(col_splits):
- k_start_index = k_ind * k_bucket_size
-
- attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc) * scale
-
- if exists(row_mask):
- attn_weights.masked_fill_(~row_mask, max_neg_value)
-
- if causal and q_start_index < (k_start_index + k_bucket_size - 1):
- causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool,
- device=device).triu(q_start_index - k_start_index + 1)
- attn_weights.masked_fill_(causal_mask, max_neg_value)
-
- block_row_maxes = attn_weights.amax(dim=-1, keepdims=True)
- attn_weights -= block_row_maxes
- exp_weights = torch.exp(attn_weights)
-
- if exists(row_mask):
- exp_weights.masked_fill_(~row_mask, 0.)
-
- block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp(min=EPSILON)
-
- new_row_maxes = torch.maximum(block_row_maxes, row_maxes)
-
- exp_values = einsum('... i j, ... j d -> ... i d', exp_weights, vc)
-
- exp_row_max_diff = torch.exp(row_maxes - new_row_maxes)
- exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes)
-
- new_row_sums = exp_row_max_diff * row_sums + exp_block_row_max_diff * block_row_sums
-
- oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_((exp_block_row_max_diff / new_row_sums) * exp_values)
-
- row_maxes.copy_(new_row_maxes)
- row_sums.copy_(new_row_sums)
-
- ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size)
- ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes)
-
- return o
-
- @staticmethod
- @torch.no_grad()
- def backward(ctx, do):
- """ Algorithm 4 in the paper """
-
- causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args
- q, k, v, o, l, m = ctx.saved_tensors
-
- device = q.device
-
- max_neg_value = -torch.finfo(q.dtype).max
- qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
-
- dq = torch.zeros_like(q)
- dk = torch.zeros_like(k)
- dv = torch.zeros_like(v)
-
- row_splits = zip(
- q.split(q_bucket_size, dim=-2),
- o.split(q_bucket_size, dim=-2),
- do.split(q_bucket_size, dim=-2),
- mask,
- l.split(q_bucket_size, dim=-2),
- m.split(q_bucket_size, dim=-2),
- dq.split(q_bucket_size, dim=-2)
- )
-
- for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits):
- q_start_index = ind * q_bucket_size - qk_len_diff
-
- col_splits = zip(
- k.split(k_bucket_size, dim=-2),
- v.split(k_bucket_size, dim=-2),
- dk.split(k_bucket_size, dim=-2),
- dv.split(k_bucket_size, dim=-2),
- )
-
- for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits):
- k_start_index = k_ind * k_bucket_size
-
- attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc) * scale
-
- if causal and q_start_index < (k_start_index + k_bucket_size - 1):
- causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool,
- device=device).triu(q_start_index - k_start_index + 1)
- attn_weights.masked_fill_(causal_mask, max_neg_value)
-
- exp_attn_weights = torch.exp(attn_weights - mc)
-
- if exists(row_mask):
- exp_attn_weights.masked_fill_(~row_mask, 0.)
-
- p = exp_attn_weights / lc
-
- dv_chunk = einsum('... i j, ... i d -> ... j d', p, doc)
- dp = einsum('... i d, ... j d -> ... i j', doc, vc)
-
- D = (doc * oc).sum(dim=-1, keepdims=True)
- ds = p * scale * (dp - D)
-
- dq_chunk = einsum('... i j, ... j d -> ... i d', ds, kc)
- dk_chunk = einsum('... i j, ... i d -> ... j d', ds, qc)
-
- dqc.add_(dq_chunk)
- dkc.add_(dk_chunk)
- dvc.add_(dv_chunk)
-
- return dq, dk, dv, None, None, None, None
-
-
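-# Both replacements below monkey-patch diffusers.models.attention.CrossAttention.forward
-# globally, so every CrossAttention module in the loaded UNet is affected.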
-def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers):
- if mem_eff_attn:
- replace_unet_cross_attn_to_memory_efficient()
- elif xformers:
- replace_unet_cross_attn_to_xformers()
-
-
-def replace_unet_cross_attn_to_memory_efficient():
- print("Replace CrossAttention.forward to use FlashAttention")
- flash_func = FlashAttentionFunction
-
- def forward_flash_attn(self, x, context=None, mask=None):
- q_bucket_size = 512
- k_bucket_size = 1024
-
- h = self.heads
- q = self.to_q(x)
-
- context = context if context is not None else x
- context = context.to(x.dtype)
- k = self.to_k(context)
- v = self.to_v(context)
- del context, x
-
- q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v))
-
- out = flash_func.apply(q, k, v, mask, False, q_bucket_size, k_bucket_size)
-
- out = rearrange(out, 'b h n d -> b n (h d)')
- return self.to_out(out)
-
- diffusers.models.attention.CrossAttention.forward = forward_flash_attn
-
-
-def replace_unet_cross_attn_to_xformers():
- print("Replace CrossAttention.forward to use xformers")
- try:
- import xformers.ops
- except ImportError:
- raise ImportError("No xformers / xformersがインストールされていないようです")
-
- def forward_xformers(self, x, context=None, mask=None):
- h = self.heads
- q_in = self.to_q(x)
-
- context = default(context, x)
- context = context.to(x.dtype)
-
- k_in = self.to_k(context)
- v_in = self.to_v(context)
-
- q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in))
- del q_in, k_in, v_in
- out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) # xformers picks the best available kernel
-
- out = rearrange(out, 'b n h d -> b n (h d)', h=h)
- return self.to_out(out)
-
- diffusers.models.attention.CrossAttention.forward = forward_xformers
-# endregion
-
-
-if __name__ == '__main__':
- # torch.cuda.set_per_process_memory_fraction(0.48)
- parser = argparse.ArgumentParser()
- parser.add_argument("--pretrained_model_name_or_path", type=str, default=None,
- help="pretrained model to train, directory to Diffusers model or StableDiffusion checkpoint / 学習元モデル、Diffusers形式モデルのディレクトリまたはStableDiffusionのckptファイル")
- parser.add_argument("--fine_tuning", action="store_true",
- help="fine tune the model instead of DreamBooth / DreamBoothではなくfine tuningする")
- parser.add_argument("--shuffle_caption", action="store_true",
- help="shuffle comma-separated caption when fine tuning / fine tuning時にコンマで区切られたcaptionの各要素をshuffleする")
- parser.add_argument("--train_data_dir", type=str, default=None, help="directory for train images / 学習画像データのディレクトリ")
- parser.add_argument("--reg_data_dir", type=str, default=None, help="directory for regularization images / 正則化画像データのディレクトリ")
- parser.add_argument("--dataset_repeats", type=int, default=None,
- help="repeat dataset in fine tuning / fine tuning時にデータセットを繰り返す回数")
- parser.add_argument("--output_dir", type=str, default=None,
- help="directory to output trained model, save as same format as input / 学習後のモデル出力先ディレクトリ(入力と同じ形式で保存)")
- parser.add_argument("--save_every_n_epochs", type=int, default=None,
- help="save checkpoint every N epochs (only supported for StableDiffusion checkpoints) / 学習中のモデルを指定エポックごとに保存します(StableDiffusion形式のモデルを読み込んだ場合のみ有効)")
- parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="loss weight for regularization images / 正則化画像のlossの重み")
- parser.add_argument("--no_token_padding", action="store_true",
- help="disable token padding (same as Diffuser's DreamBooth) / トークンのpaddingを無効にする(Diffusers版DreamBoothと同じ動作)")
- parser.add_argument("--color_aug", action="store_true", help="enable weak color augmentation / 学習時に色合いのaugmentationを有効にする")
- parser.add_argument("--flip_aug", action="store_true", help="enable horizontal flip augmentation / 学習時に左右反転のaugmentationを有効にする")
- parser.add_argument("--face_crop_aug_range", type=str, default=None,
- help="enable face-centered crop augmentation and its range (e.g. 2.0,4.0) / 学習時に顔を中心とした切り出しaugmentationを有効にするときは倍率を指定する(例:2.0,4.0)")
- parser.add_argument("--random_crop", action="store_true",
- help="enable random crop (for style training in face-centered crop augmentation) / ランダムな切り出しを有効にする(顔を中心としたaugmentationを行うときに画風の学習用に指定する)")
- parser.add_argument("--debug_dataset", action="store_true",
- help="show images for debugging (do not train) / デバッグ用に学習データを画面表示する(学習は行わない)")
- parser.add_argument("--resolution", type=str, default=None,
- help="resolution in training ('size' or 'width,height') / 学習時の画像解像度('サイズ'指定、または'幅,高さ'指定)")
- parser.add_argument("--train_batch_size", type=int, default=1,
- help="batch size for training (1 means one train or reg data, not train/reg pair) / 学習時のバッチサイズ(1でtrain/regをそれぞれ1件ずつ学習)")
- parser.add_argument("--use_8bit_adam", action="store_true",
- help="use 8bit Adam optimizer (requires bitsandbytes) / 8bit Adamオプティマイザを使う(bitsandbytesのインストールが必要)")
- parser.add_argument("--mem_eff_attn", action="store_true",
- help="use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う")
- parser.add_argument("--xformers", action="store_true",
- help="use xformers for CrossAttention / CrossAttentionにxformersを使う")
- parser.add_argument("--cache_latents", action="store_true",
- help="cache latents to reduce memory (augmentations must be disabled) / メモリ削減のためにlatentをcacheする(augmentationは使用不可)")
- parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率")
- parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数")
- parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed")
- parser.add_argument("--gradient_checkpointing", action="store_true",
- help="enable gradient checkpointing / gradient checkpointingを有効にする")
- parser.add_argument("--mixed_precision", type=str, default="no",
- choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度")
- parser.add_argument("--clip_skip", type=int, default=None,
- help="use output of nth layer from back of text encoder (n>=1) / text encoderの後ろからn番目の層の出力を用いる(nは1以上)")
-
- args = parser.parse_args()
- train(args)
diff --git a/train_db_fixed/train_db_fixed_v8.py b/train_db_fixed/train_db_fixed_v8.py
deleted file mode 100644
index 142e5106..00000000
--- a/train_db_fixed/train_db_fixed_v8.py
+++ /dev/null
@@ -1,1626 +0,0 @@
-# This script is licensed under Apache License 2.0, the same as train_dreambooth.py
-# (c) 2022 Kohya S. @kohya_ss
-
-# v7: another text encoder ckpt format, average loss, save epochs/global steps, show num of train/reg images,
-# enable reg images in fine-tuning, add dataset_repeats option
-# v8: supports Diffusers 0.7.2
-
-from torch.autograd.function import Function
-import argparse
-import glob
-import itertools
-import math
-import os
-import random
-
-from tqdm import tqdm
-import torch
-from torchvision import transforms
-from accelerate import Accelerator
-from accelerate.utils import set_seed
-from transformers import CLIPTextModel, CLIPTokenizer
-import diffusers
-from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
-import albumentations as albu
-import numpy as np
-from PIL import Image
-import cv2
-from einops import rearrange
-from torch import einsum
-
-# Tokenizer: use the pre-provided tokenizer instead of loading it from the checkpoint
-TOKENIZER_PATH = "openai/clip-vit-large-patch14"
-
-# StableDiffusion model parameters
-NUM_TRAIN_TIMESTEPS = 1000
-BETA_START = 0.00085
-BETA_END = 0.0120
-
-UNET_PARAMS_MODEL_CHANNELS = 320
-UNET_PARAMS_CHANNEL_MULT = [1, 2, 4, 4]
-UNET_PARAMS_ATTENTION_RESOLUTIONS = [4, 2, 1]
-UNET_PARAMS_IMAGE_SIZE = 32 # unused
-UNET_PARAMS_IN_CHANNELS = 4
-UNET_PARAMS_OUT_CHANNELS = 4
-UNET_PARAMS_NUM_RES_BLOCKS = 2
-UNET_PARAMS_CONTEXT_DIM = 768
-UNET_PARAMS_NUM_HEADS = 8
-
-VAE_PARAMS_Z_CHANNELS = 4
-VAE_PARAMS_RESOLUTION = 256
-VAE_PARAMS_IN_CHANNELS = 3
-VAE_PARAMS_OUT_CH = 3
-VAE_PARAMS_CH = 128
-VAE_PARAMS_CH_MULT = [1, 2, 4, 4]
-VAE_PARAMS_NUM_RES_BLOCKS = 2
-
-# checkpoint file names
-LAST_CHECKPOINT_NAME = "last.ckpt"
-EPOCH_CHECKPOINT_NAME = "epoch-{:06d}.ckpt"
-
-
-class DreamBoothOrFineTuningDataset(torch.utils.data.Dataset):
- def __init__(self, fine_tuning, train_img_path_captions, reg_img_path_captions, tokenizer, resolution, prior_loss_weight, flip_aug, color_aug, face_crop_aug_range, random_crop, shuffle_caption, disable_padding, debug_dataset) -> None:
- super().__init__()
-
- self.fine_tuning = fine_tuning
- self.train_img_path_captions = train_img_path_captions
- self.reg_img_path_captions = reg_img_path_captions
- self.tokenizer = tokenizer
- self.width, self.height = resolution
- self.size = min(self.width, self.height) # the shorter side
- self.prior_loss_weight = prior_loss_weight
- self.face_crop_aug_range = face_crop_aug_range
- self.random_crop = random_crop
- self.debug_dataset = debug_dataset
- self.shuffle_caption = shuffle_caption
- self.disable_padding = disable_padding
- self.latents_cache = None
-
- # augmentation
- flip_p = 0.5 if flip_aug else 0.0
- if color_aug:
- # fairly weak color augmentation: brightness/contrast change the image's max/min pixel values, which seems undesirable, so only gamma/hue/saturation are touched
- self.aug = albu.Compose([
- albu.OneOf([
- # albu.RandomBrightnessContrast(0.05, 0.05, p=.2),
- albu.HueSaturationValue(5, 8, 0, p=.2),
- # albu.RGBShift(5, 5, 5, p=.1),
- albu.RandomGamma((95, 105), p=.5),
- ], p=.33),
- albu.HorizontalFlip(p=flip_p)
- ], p=1.)
- elif flip_aug:
- self.aug = albu.Compose([
- albu.HorizontalFlip(p=flip_p)
- ], p=1.)
- else:
- self.aug = None
-
- self.num_train_images = len(self.train_img_path_captions)
- self.num_reg_images = len(self.reg_img_path_captions)
-
- self.enable_reg_images = self.num_reg_images > 0
-
- if not self.enable_reg_images:
- self._length = self.num_train_images
- else:
- # twice the number of training images; train and reg images alternate
- self._length = self.num_train_images * 2
- if self._length // 2 < self.num_reg_images:
- print("some of reg images are not used / 正則化画像の数が多いので、一部使用されない正則化画像があります")
-
- self.image_transforms = transforms.Compose(
- [
- transforms.ToTensor(),
- transforms.Normalize([0.5], [0.5]),
- ]
- )
-
- def load_image(self, image_path):
- image = Image.open(image_path)
- if not image.mode == "RGB":
- image = image.convert("RGB")
- img = np.array(image, np.uint8)
-
- face_cx = face_cy = face_w = face_h = 0
- if self.face_crop_aug_range is not None:
- tokens = os.path.splitext(os.path.basename(image_path))[0].split('_')
- if len(tokens) >= 5:
- face_cx = int(tokens[-4])
- face_cy = int(tokens[-3])
- face_w = int(tokens[-2])
- face_h = int(tokens[-1])
-
- return img, face_cx, face_cy, face_w, face_h
-
- # crop out a nice region around the face
- def crop_target(self, image, face_cx, face_cy, face_w, face_h):
- height, width = image.shape[0:2]
- if height == self.height and width == self.width:
- return image
-
- # the image is larger than size, so resize it
- face_size = max(face_w, face_h)
- min_scale = max(self.height / height, self.width / width) # scale at which the image exactly matches the model input size (minimum scale)
- min_scale = min(1.0, max(min_scale, self.size / (face_size * self.face_crop_aug_range[1]))) # specified minimum face size
- max_scale = min(1.0, max(min_scale, self.size / (face_size * self.face_crop_aug_range[0]))) # specified maximum face size
- if min_scale >= max_scale: # the range specifies min == max
- scale = min_scale
- else:
- scale = random.uniform(min_scale, max_scale)
-
- nh = int(height * scale + .5)
- nw = int(width * scale + .5)
- assert nh >= self.height and nw >= self.width, f"internal error. small scale {scale}, {width}*{height}"
- image = cv2.resize(image, (nw, nh), interpolation=cv2.INTER_AREA)
- face_cx = int(face_cx * scale + .5)
- face_cy = int(face_cy * scale + .5)
- height, width = nh, nw
-
-    # crop to e.g. 448x640 centered on the face
- for axis, (target_size, length, face_p) in enumerate(zip((self.height, self.width), (height, width), (face_cy, face_cx))):
-      p1 = face_p - target_size // 2  # crop position that puts the face at the center
-
- if self.random_crop:
-        # shift the crop so the face tends to stay near the center while still including some background
-        range = max(length - face_p, face_p)  # the longer distance from the image edge to the face center
-        p1 = p1 + (random.randint(0, range) + random.randint(0, range)) - range  # a roughly triangular random offset in [-range, +range]
-      else:
-        # only when a range is specified, add a small random jitter (fairly ad hoc)
- if self.face_crop_aug_range[0] != self.face_crop_aug_range[1]:
- if face_size > self.size // 10 and face_size >= 40:
- p1 = p1 + random.randint(-face_size // 20, +face_size // 20)
-
- p1 = max(0, min(p1, length - target_size))
-
- if axis == 0:
- image = image[p1:p1 + target_size, :]
- else:
- image = image[:, p1:p1 + target_size]
-
- return image
-
- def __len__(self):
- return self._length
-
- def set_cached_latents(self, image_path, latents):
- if self.latents_cache is None:
- self.latents_cache = {}
- self.latents_cache[image_path] = latents
-
- def __getitem__(self, index_arg):
- example = {}
-
- if not self.enable_reg_images:
- index = index_arg
- img_path_captions = self.train_img_path_captions
- reg = False
- else:
-      # even indices return train images, odd indices return reg images
- if index_arg % 2 == 0:
- img_path_captions = self.train_img_path_captions
- reg = False
- else:
- img_path_captions = self.reg_img_path_captions
- reg = True
- index = index_arg // 2
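-    # reg samples use prior_loss_weight (prior preservation); train samples, and everything in fine-tuning mode, use 1.0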
- example['loss_weight'] = 1.0 if (not reg or self.fine_tuning) else self.prior_loss_weight
-
- index = index % len(img_path_captions)
- image_path, caption = img_path_captions[index]
- example['image_path'] = image_path
-
-    # process the image / latents
-    if self.latents_cache is not None and image_path in self.latents_cache:
-      # latents are already cached
-      example['latents'] = self.latents_cache[image_path]
-    else:
-      # load the image and crop it if necessary
- img, face_cx, face_cy, face_w, face_h = self.load_image(image_path)
- im_h, im_w = img.shape[0:2]
-      if face_cx > 0:  # face position info is available
- img = self.crop_target(img, face_cx, face_cy, face_w, face_h)
- elif im_h > self.height or im_w > self.width:
- assert self.random_crop, f"image too large, and face_crop_aug_range and random_crop are disabled / 画像サイズが大きいのでface_crop_aug_rangeかrandom_cropを有効にしてください"
- if im_h > self.height:
- p = random.randint(0, im_h - self.height)
- img = img[p:p + self.height]
- if im_w > self.width:
- p = random.randint(0, im_w - self.width)
- img = img[:, p:p + self.width]
-
- im_h, im_w = img.shape[0:2]
- assert im_h == self.height and im_w == self.width, f"image too small / 画像サイズが小さいようです: {image_path}"
-
- # augmentation
- if self.aug is not None:
- img = self.aug(image=img)['image']
-
-      example['image'] = self.image_transforms(img)  # becomes a torch.Tensor in [-1.0, 1.0]
-
-    # process the caption
-    if self.fine_tuning and self.shuffle_caption:  # shuffle the caption only when fine tuning
- tokens = caption.strip().split(",")
- random.shuffle(tokens)
- caption = ",".join(tokens).strip()
-
- input_ids = self.tokenizer(caption, padding="do_not_pad", truncation=True,
- max_length=self.tokenizer.model_max_length).input_ids
-
-    # pad and convert to Tensor
-    if self.disable_padding:
-      # no padding here: padding=True only pads to the longest sequence in the batch (still looks like an upstream bug?)
-      input_ids = self.tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids
-    else:
-      # pad to the model max length
- input_ids = self.tokenizer.pad({"input_ids": input_ids}, padding='max_length', max_length=self.tokenizer.model_max_length,
- return_tensors='pt').input_ids
-
- example['input_ids'] = input_ids
-
- if self.debug_dataset:
- example['caption'] = caption
- return example
-
-
-# region checkpoint conversion / loading / saving ###############################
-
-# region StableDiffusion -> Diffusers conversion code
-# copied from convert_original_stable_diffusion_to_diffusers (Apache License 2.0)
-
-def shave_segments(path, n_shave_prefix_segments=1):
- """
- Removes segments. Positive values shave the first segments, negative shave the last segments.
- """
- if n_shave_prefix_segments >= 0:
- return ".".join(path.split(".")[n_shave_prefix_segments:])
- else:
- return ".".join(path.split(".")[:n_shave_prefix_segments])
-
-
-def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
- """
- Updates paths inside resnets to the new naming scheme (local renaming)
- """
- mapping = []
- for old_item in old_list:
- new_item = old_item.replace("in_layers.0", "norm1")
- new_item = new_item.replace("in_layers.2", "conv1")
-
- new_item = new_item.replace("out_layers.0", "norm2")
- new_item = new_item.replace("out_layers.3", "conv2")
-
- new_item = new_item.replace("emb_layers.1", "time_emb_proj")
- new_item = new_item.replace("skip_connection", "conv_shortcut")
-
- new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
-
- mapping.append({"old": old_item, "new": new_item})
-
- return mapping
-
-
-def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
- """
- Updates paths inside resnets to the new naming scheme (local renaming)
- """
- mapping = []
- for old_item in old_list:
- new_item = old_item
-
- new_item = new_item.replace("nin_shortcut", "conv_shortcut")
- new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
-
- mapping.append({"old": old_item, "new": new_item})
-
- return mapping
-
-
-def renew_attention_paths(old_list, n_shave_prefix_segments=0):
- """
- Updates paths inside attentions to the new naming scheme (local renaming)
- """
- mapping = []
- for old_item in old_list:
- new_item = old_item
-
- # new_item = new_item.replace('norm.weight', 'group_norm.weight')
- # new_item = new_item.replace('norm.bias', 'group_norm.bias')
-
- # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
- # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
-
- # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
-
- mapping.append({"old": old_item, "new": new_item})
-
- return mapping
-
-
-def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
- """
- Updates paths inside attentions to the new naming scheme (local renaming)
- """
- mapping = []
- for old_item in old_list:
- new_item = old_item
-
- new_item = new_item.replace("norm.weight", "group_norm.weight")
- new_item = new_item.replace("norm.bias", "group_norm.bias")
-
- new_item = new_item.replace("q.weight", "query.weight")
- new_item = new_item.replace("q.bias", "query.bias")
-
- new_item = new_item.replace("k.weight", "key.weight")
- new_item = new_item.replace("k.bias", "key.bias")
-
- new_item = new_item.replace("v.weight", "value.weight")
- new_item = new_item.replace("v.bias", "value.bias")
-
- new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
- new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
-
- new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
-
- mapping.append({"old": old_item, "new": new_item})
-
- return mapping
-
-
-def assign_to_checkpoint(
- paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
-):
- """
- This does the final conversion step: take locally converted weights and apply a global renaming
- to them. It splits attention layers, and takes into account additional replacements
- that may arise.
-
- Assigns the weights to the new checkpoint.
- """
- assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
-
- # Splits the attention layers into three variables.
- if attention_paths_to_split is not None:
- for path, path_map in attention_paths_to_split.items():
- old_tensor = old_checkpoint[path]
- channels = old_tensor.shape[0] // 3
-
- target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
-
- num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
-
- old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
- query, key, value = old_tensor.split(channels // num_heads, dim=1)
-
- checkpoint[path_map["query"]] = query.reshape(target_shape)
- checkpoint[path_map["key"]] = key.reshape(target_shape)
- checkpoint[path_map["value"]] = value.reshape(target_shape)
-
- for path in paths:
- new_path = path["new"]
-
- # These have already been assigned
- if attention_paths_to_split is not None and new_path in attention_paths_to_split:
- continue
-
- # Global renaming happens here
- new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
- new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
- new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")
-
- if additional_replacements is not None:
- for replacement in additional_replacements:
- new_path = new_path.replace(replacement["old"], replacement["new"])
-
- # proj_attn.weight has to be converted from conv 1D to linear
- if "proj_attn.weight" in new_path:
- checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
- else:
- checkpoint[new_path] = old_checkpoint[path["old"]]
-
-
-def conv_attn_to_linear(checkpoint):
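-  # LDM VAE attention uses 1x1 convs for q/k/v and proj_attn; drop the trailing spatial dims so the weights load into Linear layers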
- keys = list(checkpoint.keys())
- attn_keys = ["query.weight", "key.weight", "value.weight"]
- for key in keys:
- if ".".join(key.split(".")[-2:]) in attn_keys:
- if checkpoint[key].ndim > 2:
- checkpoint[key] = checkpoint[key][:, :, 0, 0]
- elif "proj_attn.weight" in key:
- if checkpoint[key].ndim > 2:
- checkpoint[key] = checkpoint[key][:, :, 0]
-
-
-def convert_ldm_unet_checkpoint(checkpoint, config):
- """
- Takes a state dict and a config, and returns a converted checkpoint.
- """
-
- # extract state_dict for UNet
- unet_state_dict = {}
- unet_key = "model.diffusion_model."
- keys = list(checkpoint.keys())
- for key in keys:
- if key.startswith(unet_key):
- unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
-
- new_checkpoint = {}
-
- new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
- new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
- new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
- new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
-
- new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
- new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
-
- new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
- new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
- new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
- new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
-
- # Retrieves the keys for the input blocks only
- num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
- input_blocks = {
- layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
- for layer_id in range(num_input_blocks)
- }
-
- # Retrieves the keys for the middle blocks only
- num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
- middle_blocks = {
- layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
- for layer_id in range(num_middle_blocks)
- }
-
- # Retrieves the keys for the output blocks only
- num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
- output_blocks = {
- layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
- for layer_id in range(num_output_blocks)
- }
-
- for i in range(1, num_input_blocks):
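-    # LDM input_blocks are a flat list; each group of (layers_per_block + 1) entries maps to one down block (resnets plus an optional downsampler)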
- block_id = (i - 1) // (config["layers_per_block"] + 1)
- layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
-
- resnets = [
- key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
- ]
- attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
-
- if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
- new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
- f"input_blocks.{i}.0.op.weight"
- )
- new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
- f"input_blocks.{i}.0.op.bias"
- )
-
- paths = renew_resnet_paths(resnets)
- meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
- assign_to_checkpoint(
- paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
- )
-
- if len(attentions):
- paths = renew_attention_paths(attentions)
- meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
- assign_to_checkpoint(
- paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
- )
-
- resnet_0 = middle_blocks[0]
- attentions = middle_blocks[1]
- resnet_1 = middle_blocks[2]
-
- resnet_0_paths = renew_resnet_paths(resnet_0)
- assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)
-
- resnet_1_paths = renew_resnet_paths(resnet_1)
- assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)
-
- attentions_paths = renew_attention_paths(attentions)
- meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
- assign_to_checkpoint(
- attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
- )
-
- for i in range(num_output_blocks):
- block_id = i // (config["layers_per_block"] + 1)
- layer_in_block_id = i % (config["layers_per_block"] + 1)
- output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
- output_block_list = {}
-
- for layer in output_block_layers:
- layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
- if layer_id in output_block_list:
- output_block_list[layer_id].append(layer_name)
- else:
- output_block_list[layer_id] = [layer_name]
-
- if len(output_block_list) > 1:
- resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
- attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
-
- resnet_0_paths = renew_resnet_paths(resnets)
- paths = renew_resnet_paths(resnets)
-
- meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
- assign_to_checkpoint(
- paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
- )
-
- if ["conv.weight", "conv.bias"] in output_block_list.values():
- index = list(output_block_list.values()).index(["conv.weight", "conv.bias"])
- new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
- f"output_blocks.{i}.{index}.conv.weight"
- ]
- new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
- f"output_blocks.{i}.{index}.conv.bias"
- ]
-
- # Clear attentions as they have been attributed above.
- if len(attentions) == 2:
- attentions = []
-
- if len(attentions):
- paths = renew_attention_paths(attentions)
- meta_path = {
- "old": f"output_blocks.{i}.1",
- "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
- }
- assign_to_checkpoint(
- paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
- )
- else:
- resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
- for path in resnet_0_paths:
- old_path = ".".join(["output_blocks", str(i), path["old"]])
- new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
-
- new_checkpoint[new_path] = unet_state_dict[old_path]
-
- return new_checkpoint
-
-
-def convert_ldm_vae_checkpoint(checkpoint, config):
- # extract state dict for VAE
- vae_state_dict = {}
- vae_key = "first_stage_model."
- keys = list(checkpoint.keys())
- for key in keys:
- if key.startswith(vae_key):
- vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
-
- new_checkpoint = {}
-
- new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
- new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
- new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
- new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
- new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
- new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]
-
- new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
- new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
- new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
- new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
- new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
- new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]
-
- new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
- new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
- new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
- new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
-
- # Retrieves the keys for the encoder down blocks only
- num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
- down_blocks = {
- layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
- }
-
- # Retrieves the keys for the decoder up blocks only
- num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
- up_blocks = {
- layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
- }
-
- for i in range(num_down_blocks):
- resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
-
- if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
- new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
- f"encoder.down.{i}.downsample.conv.weight"
- )
- new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
- f"encoder.down.{i}.downsample.conv.bias"
- )
-
- paths = renew_vae_resnet_paths(resnets)
- meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
- assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
-
- mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
- num_mid_res_blocks = 2
- for i in range(1, num_mid_res_blocks + 1):
- resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
-
- paths = renew_vae_resnet_paths(resnets)
- meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
- assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
-
- mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
- paths = renew_vae_attention_paths(mid_attentions)
- meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
- assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
- conv_attn_to_linear(new_checkpoint)
-
- for i in range(num_up_blocks):
- block_id = num_up_blocks - 1 - i
- resnets = [
- key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
- ]
-
- if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
- new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
- f"decoder.up.{block_id}.upsample.conv.weight"
- ]
- new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
- f"decoder.up.{block_id}.upsample.conv.bias"
- ]
-
- paths = renew_vae_resnet_paths(resnets)
- meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
- assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
-
- mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
- num_mid_res_blocks = 2
- for i in range(1, num_mid_res_blocks + 1):
- resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
-
- paths = renew_vae_resnet_paths(resnets)
- meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
- assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
-
- mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
- paths = renew_vae_attention_paths(mid_attentions)
- meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
- assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
- conv_attn_to_linear(new_checkpoint)
- return new_checkpoint
-
-
-def create_unet_diffusers_config():
- """
- Creates a config for the diffusers based on the config of the LDM model.
- """
- # unet_params = original_config.model.params.unet_config.params
-
- block_out_channels = [UNET_PARAMS_MODEL_CHANNELS * mult for mult in UNET_PARAMS_CHANNEL_MULT]
-
- down_block_types = []
- resolution = 1
- for i in range(len(block_out_channels)):
- block_type = "CrossAttnDownBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "DownBlock2D"
- down_block_types.append(block_type)
- if i != len(block_out_channels) - 1:
- resolution *= 2
-
- up_block_types = []
- for i in range(len(block_out_channels)):
- block_type = "CrossAttnUpBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "UpBlock2D"
- up_block_types.append(block_type)
- resolution //= 2
-
- config = dict(
- sample_size=UNET_PARAMS_IMAGE_SIZE,
- in_channels=UNET_PARAMS_IN_CHANNELS,
- out_channels=UNET_PARAMS_OUT_CHANNELS,
- down_block_types=tuple(down_block_types),
- up_block_types=tuple(up_block_types),
- block_out_channels=tuple(block_out_channels),
- layers_per_block=UNET_PARAMS_NUM_RES_BLOCKS,
- cross_attention_dim=UNET_PARAMS_CONTEXT_DIM,
- attention_head_dim=UNET_PARAMS_NUM_HEADS,
- )
-
- return config
-
-
-def create_vae_diffusers_config():
- """
- Creates a config for the diffusers based on the config of the LDM model.
- """
- # vae_params = original_config.model.params.first_stage_config.params.ddconfig
- # _ = original_config.model.params.first_stage_config.params.embed_dim
- block_out_channels = [VAE_PARAMS_CH * mult for mult in VAE_PARAMS_CH_MULT]
- down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
- up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
-
- config = dict(
- sample_size=VAE_PARAMS_RESOLUTION,
- in_channels=VAE_PARAMS_IN_CHANNELS,
- out_channels=VAE_PARAMS_OUT_CH,
- down_block_types=tuple(down_block_types),
- up_block_types=tuple(up_block_types),
- block_out_channels=tuple(block_out_channels),
- latent_channels=VAE_PARAMS_Z_CHANNELS,
- layers_per_block=VAE_PARAMS_NUM_RES_BLOCKS,
- )
- return config
-
-
-def convert_ldm_clip_checkpoint(checkpoint):
- text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
-
- keys = list(checkpoint.keys())
-
- text_model_dict = {}
-
- for key in keys:
- if key.startswith("cond_stage_model.transformer"):
- text_model_dict[key[len("cond_stage_model.transformer."):]] = checkpoint[key]
-
- text_model.load_state_dict(text_model_dict)
-
- return text_model
-
-# endregion
-
-
-# region Diffusers -> StableDiffusion conversion code
-# copied from convert_diffusers_to_original_stable_diffusion (Apache License 2.0)
-
-def convert_unet_state_dict(unet_state_dict):
- unet_conversion_map = [
- # (stable-diffusion, HF Diffusers)
- ("time_embed.0.weight", "time_embedding.linear_1.weight"),
- ("time_embed.0.bias", "time_embedding.linear_1.bias"),
- ("time_embed.2.weight", "time_embedding.linear_2.weight"),
- ("time_embed.2.bias", "time_embedding.linear_2.bias"),
- ("input_blocks.0.0.weight", "conv_in.weight"),
- ("input_blocks.0.0.bias", "conv_in.bias"),
- ("out.0.weight", "conv_norm_out.weight"),
- ("out.0.bias", "conv_norm_out.bias"),
- ("out.2.weight", "conv_out.weight"),
- ("out.2.bias", "conv_out.bias"),
- ]
-
- unet_conversion_map_resnet = [
- # (stable-diffusion, HF Diffusers)
- ("in_layers.0", "norm1"),
- ("in_layers.2", "conv1"),
- ("out_layers.0", "norm2"),
- ("out_layers.3", "conv2"),
- ("emb_layers.1", "time_emb_proj"),
- ("skip_connection", "conv_shortcut"),
- ]
-
- unet_conversion_map_layer = []
- for i in range(4):
- # loop over downblocks/upblocks
-
- for j in range(2):
- # loop over resnets/attentions for downblocks
- hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
- sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
- unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
-
- if i < 3:
- # no attention layers in down_blocks.3
- hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
- sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
- unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
-
- for j in range(3):
- # loop over resnets/attentions for upblocks
- hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
- sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
- unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
-
- if i > 0:
- # no attention layers in up_blocks.0
- hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
- sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
- unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
-
- if i < 3:
- # no downsample in down_blocks.3
- hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
- sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
- unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
-
- # no upsample in up_blocks.3
- hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
- sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
- unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
-
- hf_mid_atn_prefix = "mid_block.attentions.0."
- sd_mid_atn_prefix = "middle_block.1."
- unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
-
- for j in range(2):
- hf_mid_res_prefix = f"mid_block.resnets.{j}."
- sd_mid_res_prefix = f"middle_block.{2*j}."
- unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
-
- # buyer beware: this is a *brittle* function,
- # and correct output requires that all of these pieces interact in
- # the exact order in which I have arranged them.
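-  # build an HF->SD name map: start from an identity map, apply the top-level renames, then resnet-internal renames, then per-block prefix renames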
- mapping = {k: k for k in unet_state_dict.keys()}
- for sd_name, hf_name in unet_conversion_map:
- mapping[hf_name] = sd_name
- for k, v in mapping.items():
- if "resnets" in k:
- for sd_part, hf_part in unet_conversion_map_resnet:
- v = v.replace(hf_part, sd_part)
- mapping[k] = v
- for k, v in mapping.items():
- for sd_part, hf_part in unet_conversion_map_layer:
- v = v.replace(hf_part, sd_part)
- mapping[k] = v
- new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
- return new_state_dict
-
-# endregion
-
-
-def load_checkpoint_with_conversion(ckpt_path):
-  # support models whose text encoder is stored in a different layout (keys missing 'text_model')
- TEXT_ENCODER_KEY_REPLACEMENTS = [
- ('cond_stage_model.transformer.embeddings.', 'cond_stage_model.transformer.text_model.embeddings.'),
- ('cond_stage_model.transformer.encoder.', 'cond_stage_model.transformer.text_model.encoder.'),
- ('cond_stage_model.transformer.final_layer_norm.', 'cond_stage_model.transformer.text_model.final_layer_norm.')
- ]
-
- checkpoint = torch.load(ckpt_path, map_location="cpu")
- state_dict = checkpoint["state_dict"]
-
- key_reps = []
- for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS:
- for key in state_dict.keys():
- if key.startswith(rep_from):
- new_key = rep_to + key[len(rep_from):]
- key_reps.append((key, new_key))
-
- for key, new_key in key_reps:
- state_dict[new_key] = state_dict[key]
- del state_dict[key]
-
- return checkpoint
-
-
-def load_models_from_stable_diffusion_checkpoint(ckpt_path):
- checkpoint = load_checkpoint_with_conversion(ckpt_path)
- state_dict = checkpoint["state_dict"]
-
- # Convert the UNet2DConditionModel model.
- unet_config = create_unet_diffusers_config()
- converted_unet_checkpoint = convert_ldm_unet_checkpoint(state_dict, unet_config)
-
- unet = UNet2DConditionModel(**unet_config)
- unet.load_state_dict(converted_unet_checkpoint)
-
- # Convert the VAE model.
- vae_config = create_vae_diffusers_config()
- converted_vae_checkpoint = convert_ldm_vae_checkpoint(state_dict, vae_config)
-
- vae = AutoencoderKL(**vae_config)
- vae.load_state_dict(converted_vae_checkpoint)
-
- # convert text_model
- text_model = convert_ldm_clip_checkpoint(state_dict)
-
- return text_model, vae, unet
-
-
-def save_stable_diffusion_checkpoint(output_file, text_encoder, unet, ckpt_path, epochs, steps):
-  # the VAE is no longer in memory, so load the checkpoint again (including the VAE)
- checkpoint = load_checkpoint_with_conversion(ckpt_path)
- state_dict = checkpoint["state_dict"]
-
- # Convert the UNet model
- unet_state_dict = convert_unet_state_dict(unet.state_dict())
- for k, v in unet_state_dict.items():
- key = "model.diffusion_model." + k
- assert key in state_dict, f"Illegal key in save SD: {key}"
- state_dict[key] = v
-
- # Convert the text encoder model
-  text_enc_dict = text_encoder.state_dict()  # no conversion needed
- for k, v in text_enc_dict.items():
- key = "cond_stage_model.transformer." + k
- assert key in state_dict, f"Illegal key in save SD: {key}"
- state_dict[key] = v
-
- # Put together new checkpoint
- new_ckpt = {'state_dict': state_dict}
-
- if 'epoch' in checkpoint:
- epochs += checkpoint['epoch']
- if 'global_step' in checkpoint:
- steps += checkpoint['global_step']
-
- new_ckpt['epoch'] = epochs
- new_ckpt['global_step'] = steps
-
- torch.save(new_ckpt, output_file)
-# endregion
-
-
-def collate_fn(examples):
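-  # a batch carries either cached latents or raw pixel values, never both; the unused field is None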
- input_ids = [e['input_ids'] for e in examples]
- input_ids = torch.stack(input_ids)
-
- if 'latents' in examples[0]:
- pixel_values = None
- latents = [e['latents'] for e in examples]
- latents = torch.stack(latents)
- else:
- pixel_values = [e['image'] for e in examples]
- pixel_values = torch.stack(pixel_values)
- pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
- latents = None
-
- loss_weights = [e['loss_weight'] for e in examples]
- loss_weights = torch.FloatTensor(loss_weights)
-
- batch = {"input_ids": input_ids, "pixel_values": pixel_values, "latents": latents, "loss_weights": loss_weights}
- return batch
-
-
-def train(args):
- fine_tuning = args.fine_tuning
- cache_latents = args.cache_latents
-
-  # check option settings when caching latents
- if cache_latents:
- # assert args.face_crop_aug_range is None and not args.random_crop, "when caching latents, crop aug cannot be used / latentをキャッシュするときは切り出しは使えません"
-    # -> keep these usable (the crop is applied to the initial image)
- assert not args.flip_aug and not args.color_aug, "when caching latents, augmentation cannot be used / latentをキャッシュするときはaugmentationは使えません"
-
-  # check the model format options
- use_stable_diffusion_format = os.path.isfile(args.pretrained_model_name_or_path)
- if not use_stable_diffusion_format:
- assert os.path.exists(
- args.pretrained_model_name_or_path), f"no pretrained model / 学習元モデルがありません : {args.pretrained_model_name_or_path}"
-
- assert args.save_every_n_epochs is None or use_stable_diffusion_format, "when loading Diffusers model, save_every_n_epochs does not work / Diffusersのモデルを読み込むときにはsave_every_n_epochsオプションは無効になります"
-
- if args.seed is not None:
- set_seed(args.seed)
-
-  # prepare the training data
- def load_dreambooth_dir(dir):
- tokens = os.path.basename(dir).split('_')
- try:
- n_repeats = int(tokens[0])
- except ValueError as e:
- print(f"no 'n_repeats' in directory name / DreamBoothのディレクトリ名に繰り返し回数がないようです: {dir}")
- raise e
-
- caption = '_'.join(tokens[1:])
-
- img_paths = glob.glob(os.path.join(dir, "*.png")) + glob.glob(os.path.join(dir, "*.jpg"))
- return n_repeats, [(ip, caption) for ip in img_paths]
-
- print("prepare train images.")
- train_img_path_captions = []
-
- if fine_tuning:
- img_paths = glob.glob(os.path.join(args.train_data_dir, "*.png")) + glob.glob(os.path.join(args.train_data_dir, "*.jpg"))
- for img_path in tqdm(img_paths):
-      # build candidate caption file names
- base_name = os.path.splitext(img_path)[0]
- base_name_face_det = base_name
- tokens = base_name.split("_")
- if len(tokens) >= 5:
- base_name_face_det = "_".join(tokens[:-4])
- cap_paths = [base_name + '.txt', base_name + '.caption', base_name_face_det+'.txt', base_name_face_det+'.caption']
-
- caption = None
- for cap_path in cap_paths:
- if os.path.isfile(cap_path):
- with open(cap_path, "rt", encoding='utf-8') as f:
- caption = f.readlines()[0].strip()
- break
-
- assert caption is not None and len(caption) > 0, f"no caption / キャプションファイルが見つからないか、captionが空です: {cap_paths}"
-
- train_img_path_captions.append((img_path, caption))
-
- if args.dataset_repeats is not None:
- l = []
- for _ in range(args.dataset_repeats):
- l.extend(train_img_path_captions)
- train_img_path_captions = l
- else:
- train_dirs = os.listdir(args.train_data_dir)
- for dir in train_dirs:
- n_repeats, img_caps = load_dreambooth_dir(os.path.join(args.train_data_dir, dir))
- for _ in range(n_repeats):
- train_img_path_captions.extend(img_caps)
- print(f"{len(train_img_path_captions)} train images.")
-
- reg_img_path_captions = []
- if args.reg_data_dir:
- print("prepare reg images.")
- reg_dirs = os.listdir(args.reg_data_dir)
- for dir in reg_dirs:
- n_repeats, img_caps = load_dreambooth_dir(os.path.join(args.reg_data_dir, dir))
- for _ in range(n_repeats):
- reg_img_path_captions.extend(img_caps)
- print(f"{len(reg_img_path_captions)} reg images.")
-
- if args.debug_dataset:
-    # shuffle in debug mode to better approximate real training (during training the data loader shuffles)
- random.shuffle(train_img_path_captions)
- random.shuffle(reg_img_path_captions)
-
-  # prepare the dataset
- resolution = tuple([int(r) for r in args.resolution.split(',')])
- if len(resolution) == 1:
- resolution = (resolution[0], resolution[0])
- assert len(
- resolution) == 2, f"resolution must be 'size' or 'width,height' / resolutionは'サイズ'または'幅','高さ'で指定してください: {args.resolution}"
-
- if args.face_crop_aug_range is not None:
- face_crop_aug_range = tuple([float(r) for r in args.face_crop_aug_range.split(',')])
- assert len(
- face_crop_aug_range) == 2, f"face_crop_aug_range must be two floats / face_crop_aug_rangeは'下限,上限'で指定してください: {args.face_crop_aug_range}"
- else:
- face_crop_aug_range = None
-
-  # load the tokenizer
- print("prepare tokenizer")
- tokenizer = CLIPTokenizer.from_pretrained(TOKENIZER_PATH)
-
- print("prepare dataset")
- train_dataset = DreamBoothOrFineTuningDataset(fine_tuning, train_img_path_captions,
- reg_img_path_captions, tokenizer, resolution, args.prior_loss_weight, args.flip_aug, args.color_aug, face_crop_aug_range, args.random_crop, args.shuffle_caption, args.no_token_padding, args.debug_dataset)
-
- if args.debug_dataset:
- print(f"Total dataset length / データセットの長さ: {len(train_dataset)}")
- print("Escape for exit. / Escキーで中断、終了します")
- for example in train_dataset:
- im = example['image']
- im = ((im.numpy() + 1.0) * 127.5).astype(np.uint8)
- im = np.transpose(im, (1, 2, 0)) # c,H,W -> H,W,c
- im = im[:, :, ::-1] # RGB -> BGR (OpenCV)
- print(f'caption: "{example["caption"]}", loss weight: {example["loss_weight"]}')
- cv2.imshow("img", im)
- k = cv2.waitKey()
- cv2.destroyAllWindows()
- if k == 27:
- break
- return
-
-  # prepare the accelerator
-  # gradient accumulation reportedly does not support training multiple models, so it is fixed at 1
- print("prepare accelerator")
- accelerator = Accelerator(gradient_accumulation_steps=1, mixed_precision=args.mixed_precision)
-
-  # load the models
- if use_stable_diffusion_format:
- print("load StableDiffusion checkpoint")
- text_encoder, vae, unet = load_models_from_stable_diffusion_checkpoint(args.pretrained_model_name_or_path)
- else:
- print("load Diffusers pretrained models")
- text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
- vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
- unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
-
-  # hook xformers or memory efficient attention into the model
- replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
-
-  # prepare the dtype matching the mixed precision setting and cast to it as needed
- weight_dtype = torch.float32
- if args.mixed_precision == "fp16":
- weight_dtype = torch.float16
- elif args.mixed_precision == "bf16":
- weight_dtype = torch.bfloat16
-
-  # prepare for training
- if cache_latents:
-    # cache latents: creating a new Dataset would break caption shuffling, so keep the cache inside the original Dataset (cascading datasets would also work)
- print("caching latents.")
- vae.to(accelerator.device, dtype=weight_dtype)
-
- for i in tqdm(range(len(train_dataset))):
- example = train_dataset[i]
- if 'latents' not in example:
- image_path = example['image_path']
- with torch.no_grad():
- pixel_values = example["image"].unsqueeze(0).to(device=accelerator.device, dtype=weight_dtype)
- latents = vae.encode(pixel_values).latent_dist.sample().squeeze(0).to("cpu")
- train_dataset.set_cached_latents(image_path, latents)
- # assertion
- for i in range(len(train_dataset)):
- assert 'latents' in train_dataset[i], "internal error: latents not cached"
-
- del vae
- if torch.cuda.is_available():
- torch.cuda.empty_cache()
- else:
- vae.requires_grad_(False)
-
- if args.gradient_checkpointing:
- unet.enable_gradient_checkpointing()
- text_encoder.gradient_checkpointing_enable()
-
-  # prepare the classes needed for training
- print("prepare optimizer, data loader etc.")
-
-  # use 8-bit Adam
- if args.use_8bit_adam:
- try:
- import bitsandbytes as bnb
- except ImportError:
-      raise ImportError("No bitsandbytes / bitsandbytesがインストールされていないようです")
- print("use 8-bit Adam optimizer")
- optimizer_class = bnb.optim.AdamW8bit
- else:
- optimizer_class = torch.optim.AdamW
-
- trainable_params = (itertools.chain(unet.parameters(), text_encoder.parameters()))
-
-  # beta and weight decay appear to use the default values in both diffusers DreamBooth and DreamBooth SD, so those options are omitted for now
- optimizer = optimizer_class(trainable_params, lr=args.learning_rate)
-
-  # prepare the data loader
-  # number of DataLoader worker processes: 0 means everything runs in the main process
-  n_workers = min(8, os.cpu_count() - 1)  # cpu_count - 1, capped at 8
- train_dataloader = torch.utils.data.DataLoader(
- train_dataset, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn, num_workers=n_workers)
-
-  # prepare the lr scheduler
- lr_scheduler = diffusers.optimization.get_scheduler("constant", optimizer, num_training_steps=args.max_train_steps)
-
-  # accelerator apparently takes care of device placement and wrapping
- unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
- unet, text_encoder, optimizer, train_dataloader, lr_scheduler)
-
- if not cache_latents:
- vae.to(accelerator.device, dtype=weight_dtype)
-
-  # compute the number of epochs
- num_train_epochs = math.ceil(args.max_train_steps / len(train_dataloader))
-
-  # start training
- total_batch_size = args.train_batch_size # * accelerator.num_processes
- print("running training / 学習開始")
- print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset.num_train_images}")
- print(f" num reg images / 正則化画像の数: {train_dataset.num_reg_images}")
- print(f" num examples / サンプル数: {len(train_dataset)}")
- print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
- print(f" num epochs / epoch数: {num_train_epochs}")
- print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
- print(f" total train batch size (with parallel & distributed) / 総バッチサイズ(並列学習含む): {total_batch_size}")
- print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
-
- progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process, desc="steps")
- global_step = 0
-
- noise_scheduler = DDPMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)
-
- if accelerator.is_main_process:
- accelerator.init_trackers("dreambooth")
-
-  # the following is mostly copied from train_dreambooth.py
- for epoch in range(num_train_epochs):
- print(f"epoch {epoch+1}/{num_train_epochs}")
- unet.train()
-    text_encoder.train()  # supposedly only the unet needs this? -> it was fixed in the latest upstream version
-
- loss_total = 0
- for step, batch in enumerate(train_dataloader):
- with accelerator.accumulate(unet):
- with torch.no_grad():
-          # convert to latents
- if cache_latents:
- latents = batch["latents"].to(accelerator.device)
- else:
- latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
- latents = latents * 0.18215
-
- # Sample noise that we'll add to the latents
- noise = torch.randn_like(latents, device=latents.device)
- b_size = latents.shape[0]
-
- # Sample a random timestep for each image
- timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device)
- timesteps = timesteps.long()
-
- # Add noise to the latents according to the noise magnitude at each timestep
- # (this is the forward diffusion process)
- noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
-
- # Get the text embedding for conditioning
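-        # with clip_skip=n, use the hidden state n layers from the end of the text encoder and apply its final layer norm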
- if args.clip_skip is None:
- encoder_hidden_states = text_encoder(batch["input_ids"])[0]
- else:
- enc_out = text_encoder(batch["input_ids"], output_hidden_states=True, return_dict=True)
- encoder_hidden_states = enc_out['hidden_states'][-args.clip_skip]
- encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states)
-
- # Predict the noise residual
- noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
-
- loss = torch.nn.functional.mse_loss(noise_pred.float(), noise.float(), reduction="none")
- loss = loss.mean([1, 2, 3])
-
-        loss_weights = batch["loss_weights"]  # per-sample weights
- loss = loss * loss_weights
-
- loss = loss.mean()
-
- accelerator.backward(loss)
- if accelerator.sync_gradients:
- params_to_clip = (itertools.chain(unet.parameters(), text_encoder.parameters()))
- accelerator.clip_grad_norm_(params_to_clip, 1.0) # args.max_grad_norm)
-
- optimizer.step()
- lr_scheduler.step()
- optimizer.zero_grad(set_to_none=True)
-
- # Checks if the accelerator has performed an optimization step behind the scenes
- if accelerator.sync_gradients:
- progress_bar.update(1)
- global_step += 1
-
- current_loss = loss.detach().item()
- loss_total += current_loss
- avr_loss = loss_total / (step+1)
- logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
- progress_bar.set_postfix(**logs)
- # accelerator.log(logs, step=global_step)
-
- if global_step >= args.max_train_steps:
- break
-
- accelerator.wait_for_everyone()
-
- if use_stable_diffusion_format and args.save_every_n_epochs is not None:
- if (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs:
- print("saving check point.")
- os.makedirs(args.output_dir, exist_ok=True)
- ckpt_file = os.path.join(args.output_dir, EPOCH_CHECKPOINT_NAME.format(epoch + 1))
- save_stable_diffusion_checkpoint(ckpt_file, accelerator.unwrap_model(text_encoder), accelerator.unwrap_model(unet),
- args.pretrained_model_name_or_path, epoch + 1, global_step)
-
- is_main_process = accelerator.is_main_process
- if is_main_process:
- unet = accelerator.unwrap_model(unet)
- text_encoder = accelerator.unwrap_model(text_encoder)
-
- accelerator.end_training()
-  del accelerator  # delete this because memory is needed afterwards
-
- if is_main_process:
- os.makedirs(args.output_dir, exist_ok=True)
- if use_stable_diffusion_format:
- ckpt_file = os.path.join(args.output_dir, LAST_CHECKPOINT_NAME)
- print(f"save trained model as StableDiffusion checkpoint to {ckpt_file}")
- save_stable_diffusion_checkpoint(ckpt_file, text_encoder, unet, args.pretrained_model_name_or_path, epoch, global_step)
- else:
-      # Create the pipeline using the trained modules and save it.
- print(f"save trained model as Diffusers to {args.output_dir}")
- pipeline = StableDiffusionPipeline.from_pretrained(
- args.pretrained_model_name_or_path,
- unet=unet,
- text_encoder=text_encoder,
- )
- pipeline.save_pretrained(args.output_dir)
- print("model saved.")
-
-
-# region module replacement
-"""
-Module replacements for speed-up.
-"""
-
-# CrossAttention that uses FlashAttention
-# based on https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/memory_efficient_attention_pytorch/flash_attention.py
-# LICENSE MIT https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/LICENSE
-
-# constants
-
-EPSILON = 1e-6
-
-# helper functions
-
-
-def exists(val):
- return val is not None
-
-
-def default(val, d):
- return val if exists(val) else d
-
-# flash attention forwards and backwards
-
-# https://arxiv.org/abs/2205.14135
-
-
-class FlashAttentionFunction(Function):
- @ staticmethod
- @ torch.no_grad()
- def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
- """ Algorithm 2 in the paper """
-
- device = q.device
- dtype = q.dtype
- max_neg_value = -torch.finfo(q.dtype).max
- qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
-
- o = torch.zeros_like(q)
- all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device)
- all_row_maxes = torch.full((*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device)
-
- scale = (q.shape[-1] ** -0.5)
-
- if not exists(mask):
- mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size)
- else:
- mask = rearrange(mask, 'b n -> b 1 1 n')
- mask = mask.split(q_bucket_size, dim=-1)
-
- row_splits = zip(
- q.split(q_bucket_size, dim=-2),
- o.split(q_bucket_size, dim=-2),
- mask,
- all_row_sums.split(q_bucket_size, dim=-2),
- all_row_maxes.split(q_bucket_size, dim=-2),
- )
-
- for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits):
- q_start_index = ind * q_bucket_size - qk_len_diff
-
- col_splits = zip(
- k.split(k_bucket_size, dim=-2),
- v.split(k_bucket_size, dim=-2),
- )
-
- for k_ind, (kc, vc) in enumerate(col_splits):
- k_start_index = k_ind * k_bucket_size
-
- attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc) * scale
-
- if exists(row_mask):
- attn_weights.masked_fill_(~row_mask, max_neg_value)
-
- if causal and q_start_index < (k_start_index + k_bucket_size - 1):
- causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool,
- device=device).triu(q_start_index - k_start_index + 1)
- attn_weights.masked_fill_(causal_mask, max_neg_value)
-
- block_row_maxes = attn_weights.amax(dim=-1, keepdims=True)
- attn_weights -= block_row_maxes
- exp_weights = torch.exp(attn_weights)
-
- if exists(row_mask):
- exp_weights.masked_fill_(~row_mask, 0.)
-
- block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp(min=EPSILON)
-
- new_row_maxes = torch.maximum(block_row_maxes, row_maxes)
-
- exp_values = einsum('... i j, ... j d -> ... i d', exp_weights, vc)
-
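-        # online softmax update: rescale the accumulated output and row sums to the new running max before adding this block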
- exp_row_max_diff = torch.exp(row_maxes - new_row_maxes)
- exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes)
-
- new_row_sums = exp_row_max_diff * row_sums + exp_block_row_max_diff * block_row_sums
-
- oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_((exp_block_row_max_diff / new_row_sums) * exp_values)
-
- row_maxes.copy_(new_row_maxes)
- row_sums.copy_(new_row_sums)
-
- ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size)
- ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes)
-
- return o
-
- @ staticmethod
- @ torch.no_grad()
- def backward(ctx, do):
- """ Algorithm 4 in the paper """
-
- causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args
- q, k, v, o, l, m = ctx.saved_tensors
-
- device = q.device
-
- max_neg_value = -torch.finfo(q.dtype).max
- qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
-
- dq = torch.zeros_like(q)
- dk = torch.zeros_like(k)
- dv = torch.zeros_like(v)
-
- row_splits = zip(
- q.split(q_bucket_size, dim=-2),
- o.split(q_bucket_size, dim=-2),
- do.split(q_bucket_size, dim=-2),
- mask,
- l.split(q_bucket_size, dim=-2),
- m.split(q_bucket_size, dim=-2),
- dq.split(q_bucket_size, dim=-2)
- )
-
- for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits):
- q_start_index = ind * q_bucket_size - qk_len_diff
-
- col_splits = zip(
- k.split(k_bucket_size, dim=-2),
- v.split(k_bucket_size, dim=-2),
- dk.split(k_bucket_size, dim=-2),
- dv.split(k_bucket_size, dim=-2),
- )
-
- for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits):
- k_start_index = k_ind * k_bucket_size
-
- attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc) * scale
-
- if causal and q_start_index < (k_start_index + k_bucket_size - 1):
- causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool,
- device=device).triu(q_start_index - k_start_index + 1)
- attn_weights.masked_fill_(causal_mask, max_neg_value)
-
- exp_attn_weights = torch.exp(attn_weights - mc)
-
- if exists(row_mask):
- exp_attn_weights.masked_fill_(~row_mask, 0.)
-
- p = exp_attn_weights / lc
-
- dv_chunk = einsum('... i j, ... i d -> ... j d', p, doc)
- dp = einsum('... i d, ... j d -> ... i j', doc, vc)
-
- D = (doc * oc).sum(dim=-1, keepdims=True)
- ds = p * scale * (dp - D)
-
- dq_chunk = einsum('... i j, ... j d -> ... i d', ds, kc)
- dk_chunk = einsum('... i j, ... i d -> ... j d', ds, qc)
-
- dqc.add_(dq_chunk)
- dkc.add_(dk_chunk)
- dvc.add_(dv_chunk)
-
- return dq, dk, dv, None, None, None, None
-
-
-def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers):
- if mem_eff_attn:
- replace_unet_cross_attn_to_memory_efficient()
- elif xformers:
- replace_unet_cross_attn_to_xformers()
-
-
-def replace_unet_cross_attn_to_memory_efficient():
- print("Replace CrossAttention.forward to use FlashAttention")
- flash_func = FlashAttentionFunction
-
- def forward_flash_attn(self, x, context=None, mask=None):
- q_bucket_size = 512
- k_bucket_size = 1024
-
- h = self.heads
- q = self.to_q(x)
-
- context = context if context is not None else x
- context = context.to(x.dtype)
- k = self.to_k(context)
- v = self.to_v(context)
- del context, x
-
- q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v))
-
- out = flash_func.apply(q, k, v, mask, False, q_bucket_size, k_bucket_size)
-
- out = rearrange(out, 'b h n d -> b n (h d)')
-
- # diffusers 0.6.0
- if type(self.to_out) is torch.nn.Sequential:
- return self.to_out(out)
-
- # diffusers 0.7.0~
- out = self.to_out[0](out)
- out = self.to_out[1](out)
- return out
-
- diffusers.models.attention.CrossAttention.forward = forward_flash_attn
-
-
-def replace_unet_cross_attn_to_xformers():
- print("Replace CrossAttention.forward to use xformers")
- try:
- import xformers.ops
- except ImportError:
- raise ImportError("No xformers / xformersがインストールされていないようです")
-
- def forward_xformers(self, x, context=None, mask=None):
- h = self.heads
- q_in = self.to_q(x)
-
- context = default(context, x)
- context = context.to(x.dtype)
-
- k_in = self.to_k(context)
- v_in = self.to_v(context)
-
- q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in))
- del q_in, k_in, v_in
-    out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)  # xformers picks the best available kernel
-
- out = rearrange(out, 'b n h d -> b n (h d)', h=h)
-
- # diffusers 0.6.0
- if type(self.to_out) is torch.nn.Sequential:
- return self.to_out(out)
-
- # diffusers 0.7.0~
- out = self.to_out[0](out)
- out = self.to_out[1](out)
- return out
-
- diffusers.models.attention.CrossAttention.forward = forward_xformers
-# endregion
-
-
-if __name__ == '__main__':
- # torch.cuda.set_per_process_memory_fraction(0.48)
- parser = argparse.ArgumentParser()
- parser.add_argument("--pretrained_model_name_or_path", type=str, default=None,
- help="pretrained model to train, directory to Diffusers model or StableDiffusion checkpoint / 学習元モデル、Diffusers形式モデルのディレクトリまたはStableDiffusionのckptファイル")
- parser.add_argument("--fine_tuning", action="store_true",
- help="fine tune the model instead of DreamBooth / DreamBoothではなくfine tuningする")
- parser.add_argument("--shuffle_caption", action="store_true",
- help="shuffle comma-separated caption when fine tuning / fine tuning時にコンマで区切られたcaptionの各要素をshuffleする")
- parser.add_argument("--train_data_dir", type=str, default=None, help="directory for train images / 学習画像データのディレクトリ")
- parser.add_argument("--reg_data_dir", type=str, default=None, help="directory for regularization images / 正則化画像データのディレクトリ")
- parser.add_argument("--dataset_repeats", type=int, default=None,
- help="repeat dataset in fine tuning / fine tuning時にデータセットを繰り返す回数")
- parser.add_argument("--output_dir", type=str, default=None,
- help="directory to output trained model, save as same format as input / 学習後のモデル出力先ディレクトリ(入力と同じ形式で保存)")
- parser.add_argument("--save_every_n_epochs", type=int, default=None,
- help="save checkpoint every N epochs (only supports in StableDiffusion checkpoint) / 学習中のモデルを指定エポックごとに保存します(StableDiffusion形式のモデルを読み込んだ場合のみ有効)")
- parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="loss weight for regularization images / 正則化画像のlossの重み")
- parser.add_argument("--no_token_padding", action="store_true",
- help="disable token padding (same as Diffuser's DreamBooth) / トークンのpaddingを無効にする(Diffusers版DreamBoothと同じ動作)")
- parser.add_argument("--color_aug", action="store_true", help="enable weak color augmentation / 学習時に色合いのaugmentationを有効にする")
- parser.add_argument("--flip_aug", action="store_true", help="enable horizontal flip augmentation / 学習時に左右反転のaugmentationを有効にする")
- parser.add_argument("--face_crop_aug_range", type=str, default=None,
- help="enable face-centered crop augmentation and its range (e.g. 2.0,4.0) / 学習時に顔を中心とした切り出しaugmentationを有効にするときは倍率を指定する(例:2.0,4.0)")
- parser.add_argument("--random_crop", action="store_true",
- help="enable random crop (for style training in face-centered crop augmentation) / ランダムな切り出しを有効にする(顔を中心としたaugmentationを行うときに画風の学習用に指定する)")
- parser.add_argument("--debug_dataset", action="store_true",
- help="show images for debugging (do not train) / デバッグ用に学習データを画面表示する(学習は行わない)")
- parser.add_argument("--resolution", type=str, default=None,
- help="resolution in training ('size' or 'width,height') / 学習時の画像解像度('サイズ'指定、または'幅,高さ'指定)")
- parser.add_argument("--train_batch_size", type=int, default=1,
- help="batch size for training (1 means one train or reg data, not train/reg pair) / 学習時のバッチサイズ(1でtrain/regをそれぞれ1件ずつ学習)")
- parser.add_argument("--use_8bit_adam", action="store_true",
- help="use 8bit Adam optimizer (requires bitsandbytes) / 8bit Adamオプティマイザを使う(bitsandbytesのインストールが必要)")
- parser.add_argument("--mem_eff_attn", action="store_true",
- help="use memory efficient attention for CrossAttention / CrossAttentionに省メモリ版attentionを使う")
- parser.add_argument("--xformers", action="store_true",
- help="use xformers for CrossAttention / CrossAttentionにxformersを使う")
- parser.add_argument("--cache_latents", action="store_true",
- help="cache latents to reduce memory (augmentations must be disabled) / メモリ削減のためにlatentをcacheする(augmentationは使用不可)")
- parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate / 学習率")
- parser.add_argument("--max_train_steps", type=int, default=1600, help="training steps / 学習ステップ数")
- parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed")
- parser.add_argument("--gradient_checkpointing", action="store_true",
- help="enable gradient checkpointing / grandient checkpointingを有効にする")
- parser.add_argument("--mixed_precision", type=str, default="no",
- choices=["no", "fp16", "bf16"], help="use mixed precision / 混合精度を使う場合、その精度")
- parser.add_argument("--clip_skip", type=int, default=None,
- help="use output of nth layer from back of text encoder (n>=1) / text encoderの後ろからn番目の層の出力を用いる(nは1以上)")
-
- args = parser.parse_args()
- train(args)
diff --git a/train_db_fixed/train_db_fixed_v9.py b/train_db_fixed/train_db_fixed_v9.py
deleted file mode 100644
index b6b5bdde..00000000
--- a/train_db_fixed/train_db_fixed_v9.py
+++ /dev/null
@@ -1,1803 +0,0 @@
-# This script is licensed under Apache License 2.0, the same as train_dreambooth.py
-# (c) 2022 Kohya S. @kohya_ss
-
-# v7: another text encoder ckpt format, average loss, save epochs/global steps, show num of train/reg images,
-# enable reg images in fine-tuning, add dataset_repeats option
-# v8: supports Diffusers 0.7.2
-# v9: add bucketing option
-
-import time
-from torch.autograd.function import Function
-import argparse
-import glob
-import itertools
-import math
-import os
-import random
-
-from tqdm import tqdm
-import torch
-from torchvision import transforms
-from accelerate import Accelerator
-from accelerate.utils import set_seed
-from transformers import CLIPTextModel, CLIPTokenizer
-import diffusers
-from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
-import albumentations as albu
-import numpy as np
-from PIL import Image
-import cv2
-from einops import rearrange
-from torch import einsum
-
-# Tokenizer: use the pre-provided tokenizer instead of loading it from the checkpoint
-TOKENIZER_PATH = "openai/clip-vit-large-patch14"
-
-# Stable Diffusion model parameters
-NUM_TRAIN_TIMESTEPS = 1000
-BETA_START = 0.00085
-BETA_END = 0.0120
-
-UNET_PARAMS_MODEL_CHANNELS = 320
-UNET_PARAMS_CHANNEL_MULT = [1, 2, 4, 4]
-UNET_PARAMS_ATTENTION_RESOLUTIONS = [4, 2, 1]
-UNET_PARAMS_IMAGE_SIZE = 32 # unused
-UNET_PARAMS_IN_CHANNELS = 4
-UNET_PARAMS_OUT_CHANNELS = 4
-UNET_PARAMS_NUM_RES_BLOCKS = 2
-UNET_PARAMS_CONTEXT_DIM = 768
-UNET_PARAMS_NUM_HEADS = 8
-
-VAE_PARAMS_Z_CHANNELS = 4
-VAE_PARAMS_RESOLUTION = 256
-VAE_PARAMS_IN_CHANNELS = 3
-VAE_PARAMS_OUT_CH = 3
-VAE_PARAMS_CH = 128
-VAE_PARAMS_CH_MULT = [1, 2, 4, 4]
-VAE_PARAMS_NUM_RES_BLOCKS = 2
-
-# checkpoint file names
-LAST_CHECKPOINT_NAME = "last.ckpt"
-LAST_STATE_NAME = "last-state"
-EPOCH_CHECKPOINT_NAME = "epoch-{:06d}.ckpt"
-EPOCH_STATE_NAME = "epoch-{:06d}-state"
-
-
-def make_bucket_resolutions(max_reso, min_size=256, max_size=1024, divisible=64):
- max_width, max_height = max_reso
- max_area = (max_width // divisible) * (max_height // divisible)
-
- resos = set()
-
- size = int(math.sqrt(max_area)) * divisible
- resos.add((size, size))
-
- size = min_size
- while size <= max_size:
- width = size
- height = min(max_size, (max_area // (width // divisible)) * divisible)
- resos.add((width, height))
- resos.add((height, width))
- size += divisible
-
- resos = list(resos)
- resos.sort()
-
- aspect_ratios = [w / h for w, h in resos]
- return resos, aspect_ratios
-
-
-class DreamBoothOrFineTuningDataset(torch.utils.data.Dataset):
- def __init__(self, batch_size, fine_tuning, train_img_path_captions, reg_img_path_captions, tokenizer, resolution, prior_loss_weight, flip_aug, color_aug, face_crop_aug_range, random_crop, shuffle_caption, disable_padding, debug_dataset) -> None:
- super().__init__()
-
- self.batch_size = batch_size
- self.fine_tuning = fine_tuning
- self.train_img_path_captions = train_img_path_captions
- self.reg_img_path_captions = reg_img_path_captions
- self.tokenizer = tokenizer
- self.width, self.height = resolution
-    self.size = min(self.width, self.height)  # the shorter side
- self.prior_loss_weight = prior_loss_weight
- self.face_crop_aug_range = face_crop_aug_range
- self.random_crop = random_crop
- self.debug_dataset = debug_dataset
- self.shuffle_caption = shuffle_caption
- self.disable_padding = disable_padding
- self.latents_cache = None
- self.enable_bucket = False
-
- # augmentation
- flip_p = 0.5 if flip_aug else 0.0
- if color_aug:
-      # fairly weak color augmentation: brightness/contrast would change the min/max pixel values of the image, which seems undesirable, so only gamma/hue/saturation are touched
- self.aug = albu.Compose([
- albu.OneOf([
- # albu.RandomBrightnessContrast(0.05, 0.05, p=.2),
- albu.HueSaturationValue(5, 8, 0, p=.2),
- # albu.RGBShift(5, 5, 5, p=.1),
- albu.RandomGamma((95, 105), p=.5),
- ], p=.33),
- albu.HorizontalFlip(p=flip_p)
- ], p=1.)
- elif flip_aug:
- self.aug = albu.Compose([
- albu.HorizontalFlip(p=flip_p)
- ], p=1.)
- else:
- self.aug = None
-
- self.num_train_images = len(self.train_img_path_captions)
- self.num_reg_images = len(self.reg_img_path_captions)
-
- self.enable_reg_images = self.num_reg_images > 0
-
- if self.enable_reg_images and self.num_train_images < self.num_reg_images:
- print("some of reg images are not used / 正則化画像の数が多いので、一部使用されない正則化画像があります")
-
- self.image_transforms = transforms.Compose(
- [
- transforms.ToTensor(),
- transforms.Normalize([0.5], [0.5]),
- ]
- )
-
-  # must be called even when bucketing is disabled (a single bucket is created)
- def make_buckets_with_caching(self, enable_bucket, vae):
- self.enable_bucket = enable_bucket
-
- cache_latents = vae is not None
- if cache_latents:
- if enable_bucket:
- print("cache latents with bucketing")
- else:
- print("cache latents")
- else:
- if enable_bucket:
- print("make buckets")
- else:
- print("prepare dataset")
-
-    # prepare bucketing
- if enable_bucket:
- bucket_resos, bucket_aspect_ratios = make_bucket_resolutions((self.width, self.height))
- else:
-      # only one bucket; all images share the same resolution
- bucket_resos = [(self.width, self.height)]
- bucket_aspect_ratios = [self.width / self.height]
- bucket_aspect_ratios = np.array(bucket_aspect_ratios)
-
-    # get image resolutions (and latents) in advance
- img_ar_errors = []
- self.size_lat_cache = {}
- for image_path, _ in tqdm(self.train_img_path_captions + self.reg_img_path_captions):
- if image_path in self.size_lat_cache:
- continue
-
- image = self.load_image(image_path)[0]
- image_height, image_width = image.shape[0:2]
-
- if not enable_bucket:
- # assert image_width == self.width and image_height == self.height, \
- # f"all images must have specific resolution when bucketing is disabled / bucketを使わない場合、すべての画像のサイズを統一してください: {image_path}"
- reso = (self.width, self.height)
- else:
-        # determine the bucket
- aspect_ratio = image_width / image_height
- ar_errors = bucket_aspect_ratios - aspect_ratio
- bucket_id = np.abs(ar_errors).argmin()
- reso = bucket_resos[bucket_id]
- ar_error = ar_errors[bucket_id]
- img_ar_errors.append(ar_error)
-
- if cache_latents:
- image = self.resize_and_trim(image, reso)
-
-      # get the latents
- if cache_latents:
- img_tensor = self.image_transforms(image)
- img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype)
- latents = vae.encode(img_tensor).latent_dist.sample().squeeze(0).to("cpu")
- else:
- latents = None
-
- self.size_lat_cache[image_path] = (reso, latents)
-
-    # split the images into buckets
- self.buckets = [[] for _ in range(len(bucket_resos))]
- reso_to_index = {}
- for i, reso in enumerate(bucket_resos):
- reso_to_index[reso] = i
-
- def split_to_buckets(is_reg, img_path_captions):
- for image_path, caption in img_path_captions:
- reso, _ = self.size_lat_cache[image_path]
- bucket_index = reso_to_index[reso]
- self.buckets[bucket_index].append((is_reg, image_path, caption))
-
- split_to_buckets(False, self.train_img_path_captions)
-
- if self.enable_reg_images:
- l = []
- while len(l) < len(self.train_img_path_captions):
- l += self.reg_img_path_captions
- l = l[:len(self.train_img_path_captions)]
- split_to_buckets(True, l)
-
- if enable_bucket:
- print("number of images with repeats / 繰り返し回数込みの各bucketの画像枚数")
- for i, (reso, imgs) in enumerate(zip(bucket_resos, self.buckets)):
- print(f"bucket {i}: resolution {reso}, count: {len(imgs)}")
- img_ar_errors = np.array(img_ar_errors)
- print(f"mean ar error: {np.mean(np.abs(img_ar_errors))}")
-
-    # build indices for lookup
- self.buckets_indices = []
- for bucket_index, bucket in enumerate(self.buckets):
- batch_count = int(math.ceil(len(bucket) / self.batch_size))
- for batch_index in range(batch_count):
- self.buckets_indices.append((bucket_index, batch_index))
-
- self.shuffle_buckets()
- self._length = len(self.buckets_indices)
-
-  # resize to the bucket resolution, then trim the excess
- def resize_and_trim(self, image, reso):
- image_height, image_width = image.shape[0:2]
- ar_img = image_width / image_height
- ar_reso = reso[0] / reso[1]
-    if ar_img > ar_reso:  # image is wider than the target: fit the height
- scale = reso[1] / image_height
- else:
- scale = reso[0] / image_width
- resized_size = (int(image_width * scale + .5), int(image_height * scale + .5))
-
-    image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA)  # resize with cv2 because we want INTER_AREA
- if resized_size[0] > reso[0]:
- trim_size = resized_size[0] - reso[0]
- image = image[:, trim_size//2:trim_size//2 + reso[0]]
- elif resized_size[1] > reso[1]:
- trim_size = resized_size[1] - reso[1]
- image = image[trim_size//2:trim_size//2 + reso[1]]
- assert image.shape[0] == reso[1] and image.shape[1] == reso[0], \
- f"internal error, illegal trimmed size: {image.shape}, {reso}"
- return image
-
- def shuffle_buckets(self):
- random.shuffle(self.buckets_indices)
- for bucket in self.buckets:
- random.shuffle(bucket)
-
- def load_image(self, image_path):
- image = Image.open(image_path)
- if not image.mode == "RGB":
- image = image.convert("RGB")
- img = np.array(image, np.uint8)
-
- face_cx = face_cy = face_w = face_h = 0
- if self.face_crop_aug_range is not None:
- tokens = os.path.splitext(os.path.basename(image_path))[0].split('_')
- if len(tokens) >= 5:
- face_cx = int(tokens[-4])
- face_cy = int(tokens[-3])
- face_w = int(tokens[-2])
- face_h = int(tokens[-1])
-
- return img, face_cx, face_cy, face_w, face_h
-
-  # crop nicely around the face
- def crop_target(self, image, face_cx, face_cy, face_w, face_h):
- height, width = image.shape[0:2]
- if height == self.height and width == self.width:
- return image
-
-    # the image is larger than the target size, so resize it
- face_size = max(face_w, face_h)
-    min_scale = max(self.height / height, self.width / width)  # scale at which the image exactly fits the model input size (the minimum scale)
-    min_scale = min(1.0, max(min_scale, self.size / (face_size * self.face_crop_aug_range[1])))  # specified minimum face size
-    max_scale = min(1.0, max(min_scale, self.size / (face_size * self.face_crop_aug_range[0])))  # specified maximum face size
-    if min_scale >= max_scale:  # the specified range has min == max
- scale = min_scale
- else:
- scale = random.uniform(min_scale, max_scale)
-
- nh = int(height * scale + .5)
- nw = int(width * scale + .5)
- assert nh >= self.height and nw >= self.width, f"internal error. small scale {scale}, {width}*{height}"
- image = cv2.resize(image, (nw, nh), interpolation=cv2.INTER_AREA)
- face_cx = int(face_cx * scale + .5)
- face_cy = int(face_cy * scale + .5)
- height, width = nh, nw
-
-    # crop to something like 448*640 centered on the face
- for axis, (target_size, length, face_p) in enumerate(zip((self.height, self.width), (height, width), (face_cy, face_cx))):
-      p1 = face_p - target_size // 2  # crop position that centers the face
-
- if self.random_crop:
-        # shift the crop while keeping a bias toward centering the face, so some background is included
-        range = max(length - face_p, face_p)  # the longer of the distances from the image edges to the face center
-        p1 = p1 + (random.randint(0, range) + random.randint(0, range)) - range  # a nicely distributed random offset in -range ~ +range
- else:
-        # only when a range is specified, add a small random shift (fairly rough)
- if self.face_crop_aug_range[0] != self.face_crop_aug_range[1]:
- if face_size > self.size // 10 and face_size >= 40:
- p1 = p1 + random.randint(-face_size // 20, +face_size // 20)
-
- p1 = max(0, min(p1, length - target_size))
-
- if axis == 0:
- image = image[p1:p1 + target_size, :]
- else:
- image = image[:, p1:p1 + target_size]
-
- return image
-
- def __len__(self):
- return self._length
-
- def __getitem__(self, index):
- if index == 0:
- self.shuffle_buckets()
-
- bucket = self.buckets[self.buckets_indices[index][0]]
- image_index = self.buckets_indices[index][1] * self.batch_size
-
- latents_list = []
- images = []
- captions = []
- loss_weights = []
-
- for is_reg, image_path, caption in bucket[image_index:image_index + self.batch_size]:
- loss_weights.append(1.0 if is_reg else self.prior_loss_weight)
-
-      # process the image/latents
- reso, latents = self.size_lat_cache[image_path]
-
- if latents is None:
-        # load the image and crop it if necessary
- img, face_cx, face_cy, face_w, face_h = self.load_image(image_path)
- im_h, im_w = img.shape[0:2]
-
- if self.enable_bucket:
- img = self.resize_and_trim(img, reso)
- else:
-          if face_cx > 0:  # face position info is available
- img = self.crop_target(img, face_cx, face_cy, face_w, face_h)
- elif im_h > self.height or im_w > self.width:
- assert self.random_crop, f"image too large, and face_crop_aug_range and random_crop are disabled / 画像サイズが大きいのでface_crop_aug_rangeかrandom_cropを有効にしてください"
- if im_h > self.height:
- p = random.randint(0, im_h - self.height)
- img = img[p:p + self.height]
- if im_w > self.width:
- p = random.randint(0, im_w - self.width)
- img = img[:, p:p + self.width]
-
- im_h, im_w = img.shape[0:2]
- assert im_h == self.height and im_w == self.width, f"image size is small / 画像サイズが小さいようです: {image_path}"
-
- # augmentation
- if self.aug is not None:
- img = self.aug(image=img)['image']
-
-        image = self.image_transforms(img)  # becomes a torch.Tensor in the range -1.0 to 1.0
- else:
- image = None
-
- images.append(image)
- latents_list.append(latents)
-
-      # process the caption
-      if self.fine_tuning and self.shuffle_caption:  # shuffle the caption during fine tuning
- tokens = caption.strip().split(",")
- random.shuffle(tokens)
- caption = ",".join(tokens).strip()
- captions.append(caption)
-
-    # pad input_ids and convert to Tensor
- if self.disable_padding:
-      # no padding: padding==True only pads to the longest sequence in the batch (arguably a bug...?)
- input_ids = self.tokenizer(captions, padding=True, truncation=True, return_tensors="pt").input_ids
- else:
-      # pad to the model max length
- input_ids = self.tokenizer(captions, padding='max_length', truncation=True, return_tensors='pt').input_ids
-
- example = {}
- example['loss_weights'] = torch.FloatTensor(loss_weights)
- example['input_ids'] = input_ids
- if images[0] is not None:
- images = torch.stack(images)
- images = images.to(memory_format=torch.contiguous_format).float()
- else:
- images = None
- example['images'] = images
- example['latents'] = torch.stack(latents_list) if latents_list[0] is not None else None
- if self.debug_dataset:
- example['image_paths'] = [image_path for _, image_path, _ in bucket[image_index:image_index + self.batch_size]]
- example['captions'] = captions
- return example
-
-
-# region checkpoint conversion, loading and saving ###############################
-
-# region StableDiffusion -> Diffusers conversion code
-# copied from convert_original_stable_diffusion_to_diffusers (ASL 2.0)
-
-def shave_segments(path, n_shave_prefix_segments=1):
- """
- Removes segments. Positive values shave the first segments, negative shave the last segments.
- """
- if n_shave_prefix_segments >= 0:
- return ".".join(path.split(".")[n_shave_prefix_segments:])
- else:
- return ".".join(path.split(".")[:n_shave_prefix_segments])
-
-
-def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
- """
- Updates paths inside resnets to the new naming scheme (local renaming)
- """
- mapping = []
- for old_item in old_list:
- new_item = old_item.replace("in_layers.0", "norm1")
- new_item = new_item.replace("in_layers.2", "conv1")
-
- new_item = new_item.replace("out_layers.0", "norm2")
- new_item = new_item.replace("out_layers.3", "conv2")
-
- new_item = new_item.replace("emb_layers.1", "time_emb_proj")
- new_item = new_item.replace("skip_connection", "conv_shortcut")
-
- new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
-
- mapping.append({"old": old_item, "new": new_item})
-
- return mapping
-
-
-def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
- """
- Updates paths inside resnets to the new naming scheme (local renaming)
- """
- mapping = []
- for old_item in old_list:
- new_item = old_item
-
- new_item = new_item.replace("nin_shortcut", "conv_shortcut")
- new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
-
- mapping.append({"old": old_item, "new": new_item})
-
- return mapping
-
-
-def renew_attention_paths(old_list, n_shave_prefix_segments=0):
- """
- Updates paths inside attentions to the new naming scheme (local renaming)
- """
- mapping = []
- for old_item in old_list:
- new_item = old_item
-
- # new_item = new_item.replace('norm.weight', 'group_norm.weight')
- # new_item = new_item.replace('norm.bias', 'group_norm.bias')
-
- # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
- # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
-
- # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
-
- mapping.append({"old": old_item, "new": new_item})
-
- return mapping
-
-
-def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
- """
- Updates paths inside attentions to the new naming scheme (local renaming)
- """
- mapping = []
- for old_item in old_list:
- new_item = old_item
-
- new_item = new_item.replace("norm.weight", "group_norm.weight")
- new_item = new_item.replace("norm.bias", "group_norm.bias")
-
- new_item = new_item.replace("q.weight", "query.weight")
- new_item = new_item.replace("q.bias", "query.bias")
-
- new_item = new_item.replace("k.weight", "key.weight")
- new_item = new_item.replace("k.bias", "key.bias")
-
- new_item = new_item.replace("v.weight", "value.weight")
- new_item = new_item.replace("v.bias", "value.bias")
-
- new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
- new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
-
- new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
-
- mapping.append({"old": old_item, "new": new_item})
-
- return mapping
-
-
-def assign_to_checkpoint(
- paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
-):
- """
- This does the final conversion step: take locally converted weights and apply a global renaming
- to them. It splits attention layers, and takes into account additional replacements
- that may arise.
-
- Assigns the weights to the new checkpoint.
- """
- assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
-
- # Splits the attention layers into three variables.
- if attention_paths_to_split is not None:
- for path, path_map in attention_paths_to_split.items():
- old_tensor = old_checkpoint[path]
- channels = old_tensor.shape[0] // 3
-
- target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
-
- num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
-
- old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
- query, key, value = old_tensor.split(channels // num_heads, dim=1)
-
- checkpoint[path_map["query"]] = query.reshape(target_shape)
- checkpoint[path_map["key"]] = key.reshape(target_shape)
- checkpoint[path_map["value"]] = value.reshape(target_shape)
-
- for path in paths:
- new_path = path["new"]
-
- # These have already been assigned
- if attention_paths_to_split is not None and new_path in attention_paths_to_split:
- continue
-
- # Global renaming happens here
- new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
- new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
- new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")
-
- if additional_replacements is not None:
- for replacement in additional_replacements:
- new_path = new_path.replace(replacement["old"], replacement["new"])
-
- # proj_attn.weight has to be converted from conv 1D to linear
- if "proj_attn.weight" in new_path:
- checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
- else:
- checkpoint[new_path] = old_checkpoint[path["old"]]
-
-
-def conv_attn_to_linear(checkpoint):
- keys = list(checkpoint.keys())
- attn_keys = ["query.weight", "key.weight", "value.weight"]
- for key in keys:
- if ".".join(key.split(".")[-2:]) in attn_keys:
- if checkpoint[key].ndim > 2:
- checkpoint[key] = checkpoint[key][:, :, 0, 0]
- elif "proj_attn.weight" in key:
- if checkpoint[key].ndim > 2:
- checkpoint[key] = checkpoint[key][:, :, 0]
-
-
-def convert_ldm_unet_checkpoint(checkpoint, config):
- """
- Takes a state dict and a config, and returns a converted checkpoint.
- """
-
- # extract state_dict for UNet
- unet_state_dict = {}
- unet_key = "model.diffusion_model."
- keys = list(checkpoint.keys())
- for key in keys:
- if key.startswith(unet_key):
- unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
-
- new_checkpoint = {}
-
- new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
- new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
- new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
- new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
-
- new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
- new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
-
- new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
- new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
- new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
- new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
-
- # Retrieves the keys for the input blocks only
- num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
- input_blocks = {
- layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
- for layer_id in range(num_input_blocks)
- }
-
- # Retrieves the keys for the middle blocks only
- num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
- middle_blocks = {
- layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
- for layer_id in range(num_middle_blocks)
- }
-
- # Retrieves the keys for the output blocks only
- num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
- output_blocks = {
- layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
- for layer_id in range(num_output_blocks)
- }
-
- for i in range(1, num_input_blocks):
- block_id = (i - 1) // (config["layers_per_block"] + 1)
- layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
-
- resnets = [
- key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
- ]
- attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
-
- if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
- new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
- f"input_blocks.{i}.0.op.weight"
- )
- new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
- f"input_blocks.{i}.0.op.bias"
- )
-
- paths = renew_resnet_paths(resnets)
- meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
- assign_to_checkpoint(
- paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
- )
-
- if len(attentions):
- paths = renew_attention_paths(attentions)
- meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
- assign_to_checkpoint(
- paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
- )
-
- resnet_0 = middle_blocks[0]
- attentions = middle_blocks[1]
- resnet_1 = middle_blocks[2]
-
- resnet_0_paths = renew_resnet_paths(resnet_0)
- assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)
-
- resnet_1_paths = renew_resnet_paths(resnet_1)
- assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)
-
- attentions_paths = renew_attention_paths(attentions)
- meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
- assign_to_checkpoint(
- attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
- )
-
- for i in range(num_output_blocks):
- block_id = i // (config["layers_per_block"] + 1)
- layer_in_block_id = i % (config["layers_per_block"] + 1)
- output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
- output_block_list = {}
-
- for layer in output_block_layers:
- layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
- if layer_id in output_block_list:
- output_block_list[layer_id].append(layer_name)
- else:
- output_block_list[layer_id] = [layer_name]
-
- if len(output_block_list) > 1:
- resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
- attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
-
- resnet_0_paths = renew_resnet_paths(resnets)
- paths = renew_resnet_paths(resnets)
-
- meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
- assign_to_checkpoint(
- paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
- )
-
- if ["conv.weight", "conv.bias"] in output_block_list.values():
- index = list(output_block_list.values()).index(["conv.weight", "conv.bias"])
- new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
- f"output_blocks.{i}.{index}.conv.weight"
- ]
- new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
- f"output_blocks.{i}.{index}.conv.bias"
- ]
-
- # Clear attentions as they have been attributed above.
- if len(attentions) == 2:
- attentions = []
-
- if len(attentions):
- paths = renew_attention_paths(attentions)
- meta_path = {
- "old": f"output_blocks.{i}.1",
- "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
- }
- assign_to_checkpoint(
- paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
- )
- else:
- resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
- for path in resnet_0_paths:
- old_path = ".".join(["output_blocks", str(i), path["old"]])
- new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
-
- new_checkpoint[new_path] = unet_state_dict[old_path]
-
- return new_checkpoint
-
-
-def convert_ldm_vae_checkpoint(checkpoint, config):
- # extract state dict for VAE
- vae_state_dict = {}
- vae_key = "first_stage_model."
- keys = list(checkpoint.keys())
- for key in keys:
- if key.startswith(vae_key):
- vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
-
- new_checkpoint = {}
-
- new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
- new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
- new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
- new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
- new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
- new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]
-
- new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
- new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
- new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
- new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
- new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
- new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]
-
- new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
- new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
- new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
- new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
-
- # Retrieves the keys for the encoder down blocks only
- num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
- down_blocks = {
- layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
- }
-
- # Retrieves the keys for the decoder up blocks only
- num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
- up_blocks = {
- layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
- }
-
- for i in range(num_down_blocks):
- resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
-
- if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
- new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
- f"encoder.down.{i}.downsample.conv.weight"
- )
- new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
- f"encoder.down.{i}.downsample.conv.bias"
- )
-
- paths = renew_vae_resnet_paths(resnets)
- meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
- assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
-
- mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
- num_mid_res_blocks = 2
- for i in range(1, num_mid_res_blocks + 1):
- resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
-
- paths = renew_vae_resnet_paths(resnets)
- meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
- assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
-
- mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
- paths = renew_vae_attention_paths(mid_attentions)
- meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
- assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
- conv_attn_to_linear(new_checkpoint)
-
- for i in range(num_up_blocks):
- block_id = num_up_blocks - 1 - i
- resnets = [
- key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
- ]
-
- if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
- new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
- f"decoder.up.{block_id}.upsample.conv.weight"
- ]
- new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
- f"decoder.up.{block_id}.upsample.conv.bias"
- ]
-
- paths = renew_vae_resnet_paths(resnets)
- meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
- assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
-
- mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
- num_mid_res_blocks = 2
- for i in range(1, num_mid_res_blocks + 1):
- resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
-
- paths = renew_vae_resnet_paths(resnets)
- meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
- assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
-
- mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
- paths = renew_vae_attention_paths(mid_attentions)
- meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
- assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
- conv_attn_to_linear(new_checkpoint)
- return new_checkpoint
-
-
-def create_unet_diffusers_config():
- """
- Creates a config for the diffusers based on the config of the LDM model.
- """
- # unet_params = original_config.model.params.unet_config.params
-
- block_out_channels = [UNET_PARAMS_MODEL_CHANNELS * mult for mult in UNET_PARAMS_CHANNEL_MULT]
-
- down_block_types = []
- resolution = 1
- for i in range(len(block_out_channels)):
- block_type = "CrossAttnDownBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "DownBlock2D"
- down_block_types.append(block_type)
- if i != len(block_out_channels) - 1:
- resolution *= 2
-
- up_block_types = []
- for i in range(len(block_out_channels)):
- block_type = "CrossAttnUpBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "UpBlock2D"
- up_block_types.append(block_type)
- resolution //= 2
-
- config = dict(
- sample_size=UNET_PARAMS_IMAGE_SIZE,
- in_channels=UNET_PARAMS_IN_CHANNELS,
- out_channels=UNET_PARAMS_OUT_CHANNELS,
- down_block_types=tuple(down_block_types),
- up_block_types=tuple(up_block_types),
- block_out_channels=tuple(block_out_channels),
- layers_per_block=UNET_PARAMS_NUM_RES_BLOCKS,
- cross_attention_dim=UNET_PARAMS_CONTEXT_DIM,
- attention_head_dim=UNET_PARAMS_NUM_HEADS,
- )
-
- return config
-
-
-def create_vae_diffusers_config():
- """
- Creates a config for the diffusers based on the config of the LDM model.
- """
- # vae_params = original_config.model.params.first_stage_config.params.ddconfig
- # _ = original_config.model.params.first_stage_config.params.embed_dim
- block_out_channels = [VAE_PARAMS_CH * mult for mult in VAE_PARAMS_CH_MULT]
- down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
- up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
-
- config = dict(
- sample_size=VAE_PARAMS_RESOLUTION,
- in_channels=VAE_PARAMS_IN_CHANNELS,
- out_channels=VAE_PARAMS_OUT_CH,
- down_block_types=tuple(down_block_types),
- up_block_types=tuple(up_block_types),
- block_out_channels=tuple(block_out_channels),
- latent_channels=VAE_PARAMS_Z_CHANNELS,
- layers_per_block=VAE_PARAMS_NUM_RES_BLOCKS,
- )
- return config
-
-
-def convert_ldm_clip_checkpoint(checkpoint):
- text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
-
- keys = list(checkpoint.keys())
-
- text_model_dict = {}
-
- for key in keys:
- if key.startswith("cond_stage_model.transformer"):
- text_model_dict[key[len("cond_stage_model.transformer."):]] = checkpoint[key]
-
- text_model.load_state_dict(text_model_dict)
-
- return text_model
-
-# endregion
-
-
-# region Diffusers -> StableDiffusion conversion code
-# copied from convert_diffusers_to_original_stable_diffusion (ASL 2.0)
-
-def convert_unet_state_dict(unet_state_dict):
- unet_conversion_map = [
- # (stable-diffusion, HF Diffusers)
- ("time_embed.0.weight", "time_embedding.linear_1.weight"),
- ("time_embed.0.bias", "time_embedding.linear_1.bias"),
- ("time_embed.2.weight", "time_embedding.linear_2.weight"),
- ("time_embed.2.bias", "time_embedding.linear_2.bias"),
- ("input_blocks.0.0.weight", "conv_in.weight"),
- ("input_blocks.0.0.bias", "conv_in.bias"),
- ("out.0.weight", "conv_norm_out.weight"),
- ("out.0.bias", "conv_norm_out.bias"),
- ("out.2.weight", "conv_out.weight"),
- ("out.2.bias", "conv_out.bias"),
- ]
-
- unet_conversion_map_resnet = [
- # (stable-diffusion, HF Diffusers)
- ("in_layers.0", "norm1"),
- ("in_layers.2", "conv1"),
- ("out_layers.0", "norm2"),
- ("out_layers.3", "conv2"),
- ("emb_layers.1", "time_emb_proj"),
- ("skip_connection", "conv_shortcut"),
- ]
-
- unet_conversion_map_layer = []
- for i in range(4):
- # loop over downblocks/upblocks
-
- for j in range(2):
- # loop over resnets/attentions for downblocks
- hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
- sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
- unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
-
- if i < 3:
- # no attention layers in down_blocks.3
- hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
- sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
- unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
-
- for j in range(3):
- # loop over resnets/attentions for upblocks
- hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
- sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
- unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
-
- if i > 0:
- # no attention layers in up_blocks.0
- hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
- sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
- unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
-
- if i < 3:
- # no downsample in down_blocks.3
- hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
- sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
- unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
-
- # no upsample in up_blocks.3
- hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
- sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
- unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
-
- hf_mid_atn_prefix = "mid_block.attentions.0."
- sd_mid_atn_prefix = "middle_block.1."
- unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
-
- for j in range(2):
- hf_mid_res_prefix = f"mid_block.resnets.{j}."
- sd_mid_res_prefix = f"middle_block.{2*j}."
- unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
-
- # buyer beware: this is a *brittle* function,
- # and correct output requires that all of these pieces interact in
- # the exact order in which I have arranged them.
- mapping = {k: k for k in unet_state_dict.keys()}
- for sd_name, hf_name in unet_conversion_map:
- mapping[hf_name] = sd_name
- for k, v in mapping.items():
- if "resnets" in k:
- for sd_part, hf_part in unet_conversion_map_resnet:
- v = v.replace(hf_part, sd_part)
- mapping[k] = v
- for k, v in mapping.items():
- for sd_part, hf_part in unet_conversion_map_layer:
- v = v.replace(hf_part, sd_part)
- mapping[k] = v
- new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
- return new_state_dict
-
-# endregion
-
-
-def load_checkpoint_with_conversion(ckpt_path):
-  # support models whose text encoder is stored in a different format (without 'text_model')
- TEXT_ENCODER_KEY_REPLACEMENTS = [
- ('cond_stage_model.transformer.embeddings.', 'cond_stage_model.transformer.text_model.embeddings.'),
- ('cond_stage_model.transformer.encoder.', 'cond_stage_model.transformer.text_model.encoder.'),
- ('cond_stage_model.transformer.final_layer_norm.', 'cond_stage_model.transformer.text_model.final_layer_norm.')
- ]
-
- checkpoint = torch.load(ckpt_path, map_location="cpu")
- state_dict = checkpoint["state_dict"]
-
- key_reps = []
- for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS:
- for key in state_dict.keys():
- if key.startswith(rep_from):
- new_key = rep_to + key[len(rep_from):]
- key_reps.append((key, new_key))
-
- for key, new_key in key_reps:
- state_dict[new_key] = state_dict[key]
- del state_dict[key]
-
- return checkpoint
-
-
-def load_models_from_stable_diffusion_checkpoint(ckpt_path):
- checkpoint = load_checkpoint_with_conversion(ckpt_path)
- state_dict = checkpoint["state_dict"]
-
- # Convert the UNet2DConditionModel model.
- unet_config = create_unet_diffusers_config()
- converted_unet_checkpoint = convert_ldm_unet_checkpoint(state_dict, unet_config)
-
- unet = UNet2DConditionModel(**unet_config)
- unet.load_state_dict(converted_unet_checkpoint)
-
- # Convert the VAE model.
- vae_config = create_vae_diffusers_config()
- converted_vae_checkpoint = convert_ldm_vae_checkpoint(state_dict, vae_config)
-
- vae = AutoencoderKL(**vae_config)
- vae.load_state_dict(converted_vae_checkpoint)
-
- # convert text_model
- text_model = convert_ldm_clip_checkpoint(state_dict)
-
- return text_model, vae, unet
-
-
-def save_stable_diffusion_checkpoint(output_file, text_encoder, unet, ckpt_path, epochs, steps, save_dtype=None):
-  # the VAE is not kept in memory, so reload the checkpoint including the VAE
- checkpoint = load_checkpoint_with_conversion(ckpt_path)
- state_dict = checkpoint["state_dict"]
-
- # Convert the UNet model
- unet_state_dict = convert_unet_state_dict(unet.state_dict())
- for k, v in unet_state_dict.items():
- key = "model.diffusion_model." + k
- assert key in state_dict, f"Illegal key in save SD: {key}"
- if save_dtype is not None:
- v = v.detach().clone().to("cpu").to(save_dtype)
- state_dict[key] = v
-
- # Convert the text encoder model
-  text_enc_dict = text_encoder.state_dict()  # no conversion needed
- for k, v in text_enc_dict.items():
- key = "cond_stage_model.transformer." + k
- assert key in state_dict, f"Illegal key in save SD: {key}"
- if save_dtype is not None:
- v = v.detach().clone().to("cpu").to(save_dtype)
- state_dict[key] = v
-
- # Put together new checkpoint
- new_ckpt = {'state_dict': state_dict}
-
- if 'epoch' in checkpoint:
- epochs += checkpoint['epoch']
- if 'global_step' in checkpoint:
- steps += checkpoint['global_step']
-
- new_ckpt['epoch'] = epochs
- new_ckpt['global_step'] = steps
-
- torch.save(new_ckpt, output_file)
-# endregion
-
-
-def collate_fn(examples):
- return examples[0]
-
-
-def train(args):
- fine_tuning = args.fine_tuning
- cache_latents = args.cache_latents
-
-  # check option settings when caching latents
- if cache_latents:
- # assert args.face_crop_aug_range is None and not args.random_crop, "when caching latents, crop aug cannot be used / latentをキャッシュするときは切り出しは使えません"
-    # -> keep these usable (the crop is applied to the initial image)
- assert not args.flip_aug and not args.color_aug, "when caching latents, augmentation cannot be used / latentをキャッシュするときはaugmentationは使えません"
-
-  # check option settings for the model format
- use_stable_diffusion_format = os.path.isfile(args.pretrained_model_name_or_path)
- if not use_stable_diffusion_format:
- assert os.path.exists(
- args.pretrained_model_name_or_path), f"no pretrained model / 学習元モデルがありません : {args.pretrained_model_name_or_path}"
-
- assert args.save_every_n_epochs is None or use_stable_diffusion_format, "when loading Diffusers model, save_every_n_epochs does not work / Diffusersのモデルを読み込むときにはsave_every_n_epochsオプションは無効になります"
-
- if args.seed is not None:
- set_seed(args.seed)
-
-  # prepare the training data
- def load_dreambooth_dir(dir):
- tokens = os.path.basename(dir).split('_')
- try:
- n_repeats = int(tokens[0])
- except ValueError as e:
- # print(f"no 'n_repeats' in directory name / DreamBoothのディレクトリ名に繰り返し回数がないようです: {dir}")
- # raise e
- return 0, []
-
- caption = '_'.join(tokens[1:])
-
- print(f"found directory {n_repeats}_{caption}")
-
- img_paths = glob.glob(os.path.join(dir, "*.png")) + glob.glob(os.path.join(dir, "*.jpg"))
- return n_repeats, [(ip, caption) for ip in img_paths]
-
- print("prepare train images.")
- train_img_path_captions = []
-
- if fine_tuning:
- img_paths = glob.glob(os.path.join(args.train_data_dir, "*.png")) + glob.glob(os.path.join(args.train_data_dir, "*.jpg"))
- for img_path in tqdm(img_paths):
-      # build candidate caption file names
- base_name = os.path.splitext(img_path)[0]
- base_name_face_det = base_name
- tokens = base_name.split("_")
- if len(tokens) >= 5:
- base_name_face_det = "_".join(tokens[:-4])
- cap_paths = [base_name + '.txt', base_name + '.caption', base_name_face_det+'.txt', base_name_face_det+'.caption']
-
- caption = None
- for cap_path in cap_paths:
- if os.path.isfile(cap_path):
- with open(cap_path, "rt", encoding='utf-8') as f:
- caption = f.readlines()[0].strip()
- break
-
- assert caption is not None and len(caption) > 0, f"no caption / キャプションファイルが見つからないか、captionが空です: {cap_paths}"
-
- train_img_path_captions.append((img_path, caption))
-
- if args.dataset_repeats is not None:
- l = []
- for _ in range(args.dataset_repeats):
- l.extend(train_img_path_captions)
- train_img_path_captions = l
- else:
- train_dirs = os.listdir(args.train_data_dir)
- for dir in train_dirs:
- n_repeats, img_caps = load_dreambooth_dir(os.path.join(args.train_data_dir, dir))
- for _ in range(n_repeats):
- train_img_path_captions.extend(img_caps)
- print(f"{len(train_img_path_captions)} train images with repeating.")
-
- reg_img_path_captions = []
- if args.reg_data_dir:
- print("prepare reg images.")
- reg_dirs = os.listdir(args.reg_data_dir)
- for dir in reg_dirs:
- n_repeats, img_caps = load_dreambooth_dir(os.path.join(args.reg_data_dir, dir))
- for _ in range(n_repeats):
- reg_img_path_captions.extend(img_caps)
- print(f"{len(reg_img_path_captions)} reg images.")
-
-  # prepare the dataset
- resolution = tuple([int(r) for r in args.resolution.split(',')])
- if len(resolution) == 1:
- resolution = (resolution[0], resolution[0])
- assert len(
- resolution) == 2, f"resolution must be 'size' or 'width,height' / resolutionは'サイズ'または'幅','高さ'で指定してください: {args.resolution}"
-
- if args.face_crop_aug_range is not None:
- face_crop_aug_range = tuple([float(r) for r in args.face_crop_aug_range.split(',')])
- assert len(
- face_crop_aug_range) == 2, f"face_crop_aug_range must be two floats / face_crop_aug_rangeは'下限,上限'で指定してください: {args.face_crop_aug_range}"
- else:
- face_crop_aug_range = None
-
-  # load the tokenizer
- print("prepare tokenizer")
- tokenizer = CLIPTokenizer.from_pretrained(TOKENIZER_PATH)
-
- print("prepare dataset")
- train_dataset = DreamBoothOrFineTuningDataset(args.train_batch_size, fine_tuning, train_img_path_captions, reg_img_path_captions, tokenizer, resolution,
- args.prior_loss_weight, args.flip_aug, args.color_aug, face_crop_aug_range, args.random_crop,
- args.shuffle_caption, args.no_token_padding, args.debug_dataset)
-
- if args.debug_dataset:
-    train_dataset.make_buckets_with_caching(args.enable_bucket, None)  # build without caching for debugging
- print(f"Total dataset length (steps) / データセットの長さ(ステップ数): {len(train_dataset)}")
- print("Escape for exit. / Escキーで中断、終了します")
- for example in train_dataset:
- for im, cap, lw in zip(example['images'], example['captions'], example['loss_weights']):
- im = ((im.numpy() + 1.0) * 127.5).astype(np.uint8)
- im = np.transpose(im, (1, 2, 0)) # c,H,W -> H,W,c
- im = im[:, :, ::-1] # RGB -> BGR (OpenCV)
- print(f'size: {im.shape[1]}*{im.shape[0]}, caption: "{cap}", loss weight: {lw}')
- cv2.imshow("img", im)
- k = cv2.waitKey()
- cv2.destroyAllWindows()
- if k == 27:
- break
- if k == 27:
- break
- return
-
-  # prepare the accelerator
-  # gradient accumulation reportedly does not support training multiple models, so it is fixed to 1
- print("prepare accelerator")
- if args.logging_dir is None:
- log_with = None
- logging_dir = None
- else:
- log_with = "tensorboard"
- logging_dir = args.logging_dir + "/" + time.strftime('%Y%m%d%H%M%S', time.localtime())
- accelerator = Accelerator(gradient_accumulation_steps=1, mixed_precision=args.mixed_precision,
- log_with=log_with, logging_dir=logging_dir)
-
-  # load the models
- if use_stable_diffusion_format:
- print("load StableDiffusion checkpoint")
- text_encoder, vae, unet = load_models_from_stable_diffusion_checkpoint(args.pretrained_model_name_or_path)
- else:
- print("load Diffusers pretrained models")
- text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
- vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
- unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
-
-  # hook xformers or memory efficient attention into the model
- replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
-
-  # prepare the dtype matching the mixed precision setting and cast as needed
- weight_dtype = torch.float32
- if args.mixed_precision == "fp16":
- weight_dtype = torch.float16
- elif args.mixed_precision == "bf16":
- weight_dtype = torch.bfloat16
-
- save_dtype = None
- if args.save_precision == "fp16":
- save_dtype = torch.float16
- elif args.save_precision == "bf16":
- save_dtype = torch.bfloat16
- elif args.save_precision == "float":
- save_dtype = torch.float32
-
-  # prepare for training
- if cache_latents:
- vae.to(accelerator.device, dtype=weight_dtype)
- with torch.no_grad():
- train_dataset.make_buckets_with_caching(args.enable_bucket, vae)
- del vae
- if torch.cuda.is_available():
- torch.cuda.empty_cache()
- else:
- train_dataset.make_buckets_with_caching(args.enable_bucket, None)
- vae.requires_grad_(False)
-
- if args.gradient_checkpointing:
- unet.enable_gradient_checkpointing()
- text_encoder.gradient_checkpointing_enable()
-
-  # prepare the classes needed for training
- print("prepare optimizer, data loader etc.")
-
-  # use 8-bit Adam
- if args.use_8bit_adam:
- try:
- import bitsandbytes as bnb
- except ImportError:
-      raise ImportError("No bitsandbytes / bitsandbytesがインストールされていないようです")
- print("use 8-bit Adam optimizer")
- optimizer_class = bnb.optim.AdamW8bit
- else:
- optimizer_class = torch.optim.AdamW
-
- trainable_params = (itertools.chain(unet.parameters(), text_encoder.parameters()))
-
-  # diffusers DreamBooth and DreamBooth SD both appear to use the default beta and weight decay, so those options are omitted for now
- optimizer = optimizer_class(trainable_params, lr=args.learning_rate)
-
-  # prepare the dataloader
-  # number of DataLoader worker processes: 0 means the main process
-  n_workers = min(8, os.cpu_count() - 1)  # cpu_count - 1, but at most 8
- train_dataloader = torch.utils.data.DataLoader(
- train_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn, num_workers=n_workers)
-
-  # prepare the lr scheduler
- lr_scheduler = diffusers.optimization.get_scheduler("constant", optimizer, num_training_steps=args.max_train_steps)
-
-  # the accelerator apparently takes care of the rest
- unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
- unet, text_encoder, optimizer, train_dataloader, lr_scheduler)
-
- if not cache_latents:
- vae.to(accelerator.device, dtype=weight_dtype)
-
-  # resume training
- if args.resume is not None:
- print(f"resume training from state: {args.resume}")
- accelerator.load_state(args.resume)
-
-  # compute the number of epochs
- num_train_epochs = math.ceil(args.max_train_steps / len(train_dataloader))
-
-  # run the training loop
- total_batch_size = args.train_batch_size # * accelerator.num_processes
- print("running training / 学習開始")
- print(f" num train images * repeats / 学習画像の数×繰り返し回数: {train_dataset.num_train_images}")
- print(f" num reg images / 正則化画像の数: {train_dataset.num_reg_images}")
- print(f" num examples / サンプル数: {train_dataset.num_train_images * 2}")
- print(f" num batches per epoch / 1epochのバッチ数: {len(train_dataloader)}")
- print(f" num epochs / epoch数: {num_train_epochs}")
- print(f" batch size per device / バッチサイズ: {args.train_batch_size}")
- print(f" total train batch size (with parallel & distributed) / 総バッチサイズ(並列学習含む): {total_batch_size}")
- print(f" total optimization steps / 学習ステップ数: {args.max_train_steps}")
-
- progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process, desc="steps")
- global_step = 0
-
- noise_scheduler = DDPMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)
-
- if accelerator.is_main_process:
- accelerator.init_trackers("dreambooth")
-
-  # the following is mostly copied from train_dreambooth.py
- for epoch in range(num_train_epochs):
- print(f"epoch {epoch+1}/{num_train_epochs}")
- unet.train()
-    text_encoder.train()  # apparently only the unet needs this? -> it was fixed in the latest version; things are a bit rough
-
- loss_total = 0
- for step, batch in enumerate(train_dataloader):
- with accelerator.accumulate(unet):
- with torch.no_grad():
-          # convert to latents
- if cache_latents:
- latents = batch["latents"].to(accelerator.device)
- else:
- latents = vae.encode(batch["images"].to(dtype=weight_dtype)).latent_dist.sample()
- latents = latents * 0.18215
-
- # Sample noise that we'll add to the latents
- noise = torch.randn_like(latents, device=latents.device)
- b_size = latents.shape[0]
-
- # Sample a random timestep for each image
- timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (b_size,), device=latents.device)
- timesteps = timesteps.long()
-
- # Add noise to the latents according to the noise magnitude at each timestep
- # (this is the forward diffusion process)
- noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
-
- # Get the text embedding for conditioning
- if args.clip_skip is None:
- encoder_hidden_states = text_encoder(batch["input_ids"])[0]
- else:
- enc_out = text_encoder(batch["input_ids"], output_hidden_states=True, return_dict=True)
- encoder_hidden_states = enc_out['hidden_states'][-args.clip_skip]
- encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states)
-
- # Predict the noise residual
- noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
-
- loss = torch.nn.functional.mse_loss(noise_pred.float(), noise.float(), reduction="none")
- loss = loss.mean([1, 2, 3])
-
-        loss_weights = batch["loss_weights"]  # per-sample weight
- loss = loss * loss_weights
-
- loss = loss.mean()
-
- accelerator.backward(loss)
- if accelerator.sync_gradients:
- params_to_clip = (itertools.chain(unet.parameters(), text_encoder.parameters()))
- accelerator.clip_grad_norm_(params_to_clip, 1.0) # args.max_grad_norm)
-
- optimizer.step()
- lr_scheduler.step()
- optimizer.zero_grad(set_to_none=True)
-
- # Checks if the accelerator has performed an optimization step behind the scenes
- if accelerator.sync_gradients:
- progress_bar.update(1)
- global_step += 1
-
- current_loss = loss.detach().item()
- if args.logging_dir is not None:
- logs = {"loss": current_loss, "lr": lr_scheduler.get_last_lr()[0]}
- accelerator.log(logs, step=global_step)
-
- loss_total += current_loss
- avr_loss = loss_total / (step+1)
- logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}
- progress_bar.set_postfix(**logs)
-
- if global_step >= args.max_train_steps:
- break
-
- if args.logging_dir is not None:
- logs = {"epoch_loss": loss_total / len(train_dataloader)}
- accelerator.log(logs, step=epoch+1)
-
- accelerator.wait_for_everyone()
-
- if use_stable_diffusion_format and args.save_every_n_epochs is not None:
- if (epoch + 1) % args.save_every_n_epochs == 0 and (epoch + 1) < num_train_epochs:
-        print("saving checkpoint.")
- os.makedirs(args.output_dir, exist_ok=True)
- ckpt_file = os.path.join(args.output_dir, EPOCH_CHECKPOINT_NAME.format(epoch + 1))
- save_stable_diffusion_checkpoint(ckpt_file, accelerator.unwrap_model(text_encoder), accelerator.unwrap_model(unet),
- args.pretrained_model_name_or_path, epoch + 1, global_step, save_dtype)
-
- if args.save_state:
- print("saving state.")
- accelerator.save_state(os.path.join(args.output_dir, EPOCH_STATE_NAME.format(epoch + 1)))
-
- is_main_process = accelerator.is_main_process
- if is_main_process:
- unet = accelerator.unwrap_model(unet)
- text_encoder = accelerator.unwrap_model(text_encoder)
-
- accelerator.end_training()
-
- if args.save_state:
- print("saving last state.")
- accelerator.save_state(os.path.join(args.output_dir, LAST_STATE_NAME))
-
-  del accelerator  # free this since memory is needed afterward
-
- if is_main_process:
- os.makedirs(args.output_dir, exist_ok=True)
- if use_stable_diffusion_format:
- ckpt_file = os.path.join(args.output_dir, LAST_CHECKPOINT_NAME)
- print(f"save trained model as StableDiffusion checkpoint to {ckpt_file}")
- save_stable_diffusion_checkpoint(ckpt_file, text_encoder, unet,
- args.pretrained_model_name_or_path, epoch, global_step, save_dtype)
- else:
- # Create the pipeline using using the trained modules and save it.
- print(f"save trained model as Diffusers to {args.output_dir}")
- pipeline = StableDiffusionPipeline.from_pretrained(
- args.pretrained_model_name_or_path,
- unet=unet,
- text_encoder=text_encoder,
- )
- pipeline.save_pretrained(args.output_dir)
- print("model saved.")
-
-
-# region module replacement
-"""
-Module replacement for speed-up
-"""
-
-# CrossAttention that uses FlashAttention
-# based on https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/memory_efficient_attention_pytorch/flash_attention.py
-# LICENSE MIT https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/LICENSE
-
-# constants
-
-EPSILON = 1e-6
-
-# helper functions
-
-
-def exists(val):
- return val is not None
-
-
-def default(val, d):
- return val if exists(val) else d
-
-# flash attention forwards and backwards
-
-# https://arxiv.org/abs/2205.14135
-
-
-class FlashAttentionFunction(Function):
- @ staticmethod
- @ torch.no_grad()
- def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size):
- """ Algorithm 2 in the paper """
-
- device = q.device
- dtype = q.dtype
- max_neg_value = -torch.finfo(q.dtype).max
- qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
-
- o = torch.zeros_like(q)
- all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device)
- all_row_maxes = torch.full((*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device)
-
- scale = (q.shape[-1] ** -0.5)
-
- if not exists(mask):
- mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size)
- else:
- mask = rearrange(mask, 'b n -> b 1 1 n')
- mask = mask.split(q_bucket_size, dim=-1)
-
- row_splits = zip(
- q.split(q_bucket_size, dim=-2),
- o.split(q_bucket_size, dim=-2),
- mask,
- all_row_sums.split(q_bucket_size, dim=-2),
- all_row_maxes.split(q_bucket_size, dim=-2),
- )
-
- for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits):
- q_start_index = ind * q_bucket_size - qk_len_diff
-
- col_splits = zip(
- k.split(k_bucket_size, dim=-2),
- v.split(k_bucket_size, dim=-2),
- )
-
- for k_ind, (kc, vc) in enumerate(col_splits):
- k_start_index = k_ind * k_bucket_size
-
- attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc) * scale
-
- if exists(row_mask):
- attn_weights.masked_fill_(~row_mask, max_neg_value)
-
- if causal and q_start_index < (k_start_index + k_bucket_size - 1):
- causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool,
- device=device).triu(q_start_index - k_start_index + 1)
- attn_weights.masked_fill_(causal_mask, max_neg_value)
-
- block_row_maxes = attn_weights.amax(dim=-1, keepdims=True)
- attn_weights -= block_row_maxes
- exp_weights = torch.exp(attn_weights)
-
- if exists(row_mask):
- exp_weights.masked_fill_(~row_mask, 0.)
-
- block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp(min=EPSILON)
-
- new_row_maxes = torch.maximum(block_row_maxes, row_maxes)
-
- exp_values = einsum('... i j, ... j d -> ... i d', exp_weights, vc)
-
- exp_row_max_diff = torch.exp(row_maxes - new_row_maxes)
- exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes)
-
- new_row_sums = exp_row_max_diff * row_sums + exp_block_row_max_diff * block_row_sums
-
- oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_((exp_block_row_max_diff / new_row_sums) * exp_values)
-
- row_maxes.copy_(new_row_maxes)
- row_sums.copy_(new_row_sums)
-
- ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size)
- ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes)
-
- return o
-
-  @staticmethod
-  @torch.no_grad()
- def backward(ctx, do):
- """ Algorithm 4 in the paper """
-
- causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args
- q, k, v, o, l, m = ctx.saved_tensors
-
- device = q.device
-
- max_neg_value = -torch.finfo(q.dtype).max
- qk_len_diff = max(k.shape[-2] - q.shape[-2], 0)
-
- dq = torch.zeros_like(q)
- dk = torch.zeros_like(k)
- dv = torch.zeros_like(v)
-
- row_splits = zip(
- q.split(q_bucket_size, dim=-2),
- o.split(q_bucket_size, dim=-2),
- do.split(q_bucket_size, dim=-2),
- mask,
- l.split(q_bucket_size, dim=-2),
- m.split(q_bucket_size, dim=-2),
- dq.split(q_bucket_size, dim=-2)
- )
-
- for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits):
- q_start_index = ind * q_bucket_size - qk_len_diff
-
- col_splits = zip(
- k.split(k_bucket_size, dim=-2),
- v.split(k_bucket_size, dim=-2),
- dk.split(k_bucket_size, dim=-2),
- dv.split(k_bucket_size, dim=-2),
- )
-
- for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits):
- k_start_index = k_ind * k_bucket_size
-
- attn_weights = einsum('... i d, ... j d -> ... i j', qc, kc) * scale
-
- if causal and q_start_index < (k_start_index + k_bucket_size - 1):
- causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool,
- device=device).triu(q_start_index - k_start_index + 1)
- attn_weights.masked_fill_(causal_mask, max_neg_value)
-
- exp_attn_weights = torch.exp(attn_weights - mc)
-
- if exists(row_mask):
- exp_attn_weights.masked_fill_(~row_mask, 0.)
-
- p = exp_attn_weights / lc
-
- dv_chunk = einsum('... i j, ... i d -> ... j d', p, doc)
- dp = einsum('... i d, ... j d -> ... i j', doc, vc)
-
- D = (doc * oc).sum(dim=-1, keepdims=True)
- ds = p * scale * (dp - D)
-
- dq_chunk = einsum('... i j, ... j d -> ... i d', ds, kc)
- dk_chunk = einsum('... i j, ... i d -> ... j d', ds, qc)
-
- dqc.add_(dq_chunk)
- dkc.add_(dk_chunk)
- dvc.add_(dv_chunk)
-
- return dq, dk, dv, None, None, None, None
-
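-# Illustrative usage of FlashAttentionFunction (not executed; the tensor shapes
-# below are example values, not taken from this script). Inputs are laid out as
-# (batch, heads, seq_len, head_dim), mask may be None, causal is a bool, and the
-# bucket sizes control the tiling:
-#
-#   q = k = v = torch.randn(1, 8, 4096, 40)
-#   out = FlashAttentionFunction.apply(q, k, v, None, False, 512, 1024)
-#   # out has the same shape as q
-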
-
-def replace_unet_modules(unet: diffusers.models.unet_2d_condition.UNet2DConditionModel, mem_eff_attn, xformers):
- if mem_eff_attn:
- replace_unet_cross_attn_to_memory_efficient()
- elif xformers:
- replace_unet_cross_attn_to_xformers()
-
-
-def replace_unet_cross_attn_to_memory_efficient():
- print("Replace CrossAttention.forward to use FlashAttention")
- flash_func = FlashAttentionFunction
-
- def forward_flash_attn(self, x, context=None, mask=None):
- q_bucket_size = 512
- k_bucket_size = 1024
-
- h = self.heads
- q = self.to_q(x)
-
- context = context if context is not None else x
- context = context.to(x.dtype)
- k = self.to_k(context)
- v = self.to_v(context)
- del context, x
-
- q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v))
-
- out = flash_func.apply(q, k, v, mask, False, q_bucket_size, k_bucket_size)
-
- out = rearrange(out, 'b h n d -> b n (h d)')
-
-    # diffusers 0.6.0: to_out is an nn.Sequential
- if type(self.to_out) is torch.nn.Sequential:
- return self.to_out(out)
-
-    # diffusers 0.7.0 and later: to_out is an nn.ModuleList (Linear, Dropout)
- out = self.to_out[0](out)
- out = self.to_out[1](out)
- return out
-
- diffusers.models.attention.CrossAttention.forward = forward_flash_attn
-
-
-def replace_unet_cross_attn_to_xformers():
- print("Replace CrossAttention.forward to use xformers")
- try:
- import xformers.ops
- except ImportError:
-    raise ImportError("No xformers / xformers does not appear to be installed")
-
- def forward_xformers(self, x, context=None, mask=None):
- h = self.heads
- q_in = self.to_q(x)
-
- context = default(context, x)
- context = context.to(x.dtype)
-
- k_in = self.to_k(context)
- v_in = self.to_v(context)
-
- q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in))
- del q_in, k_in, v_in
-    out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)  # picks the best available kernel automatically
-
- out = rearrange(out, 'b n h d -> b n (h d)', h=h)
-
-    # diffusers 0.6.0: to_out is an nn.Sequential
- if type(self.to_out) is torch.nn.Sequential:
- return self.to_out(out)
-
-    # diffusers 0.7.0 and later: to_out is an nn.ModuleList (Linear, Dropout)
- out = self.to_out[0](out)
- out = self.to_out[1](out)
- return out
-
- diffusers.models.attention.CrossAttention.forward = forward_xformers
-# endregion
-
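-# Illustrative usage of replace_unet_modules() (not executed; model_dir is a
-# placeholder path). The patch is applied at class level to
-# diffusers.models.attention.CrossAttention.forward, so it affects every
-# instance and only needs to be called once, after the UNet is loaded and
-# before training starts. The flags mirror --mem_eff_attn / --xformers below.
-#
-#   unet = diffusers.UNet2DConditionModel.from_pretrained(model_dir, subfolder="unet")
-#   replace_unet_modules(unet, args.mem_eff_attn, args.xformers)
-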
-
-if __name__ == '__main__':
- # torch.cuda.set_per_process_memory_fraction(0.48)
- parser = argparse.ArgumentParser()
-  parser.add_argument("--pretrained_model_name_or_path", type=str, default=None,
-                      help="pretrained model to train: directory of a Diffusers model or a StableDiffusion checkpoint file")
-  parser.add_argument("--fine_tuning", action="store_true",
-                      help="fine tune the model instead of DreamBooth")
-  parser.add_argument("--shuffle_caption", action="store_true",
-                      help="shuffle the comma-separated elements of each caption when fine tuning")
-  parser.add_argument("--train_data_dir", type=str, default=None, help="directory of training images")
-  parser.add_argument("--reg_data_dir", type=str, default=None, help="directory of regularization images")
-  parser.add_argument("--dataset_repeats", type=int, default=None,
-                      help="number of times to repeat the dataset when fine tuning")
-  parser.add_argument("--output_dir", type=str, default=None,
-                      help="directory to output the trained model (saved in the same format as the input model)")
-  parser.add_argument("--save_every_n_epochs", type=int, default=None,
-                      help="save a checkpoint every N epochs (only effective when the model was loaded from a StableDiffusion checkpoint)")
-  parser.add_argument("--save_state", action="store_true",
-                      help="additionally save the training state (optimizer states etc.)")
-  parser.add_argument("--resume", type=str, default=None,
-                      help="saved state to resume training from")
-  parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="loss weight for regularization images")
-  parser.add_argument("--no_token_padding", action="store_true",
-                      help="disable token padding (same behavior as the Diffusers DreamBooth script)")
-  parser.add_argument("--color_aug", action="store_true", help="enable weak color augmentation during training")
-  parser.add_argument("--flip_aug", action="store_true", help="enable horizontal flip augmentation during training")
-  parser.add_argument("--face_crop_aug_range", type=str, default=None,
-                      help="enable face-centered crop augmentation and specify its scale range (e.g. 2.0,4.0)")
-  parser.add_argument("--random_crop", action="store_true",
-                      help="enable random crops (use for style training together with face-centered crop augmentation)")
-  parser.add_argument("--debug_dataset", action="store_true",
-                      help="show training images for debugging (no training is performed)")
-  parser.add_argument("--resolution", type=str, default=None,
-                      help="training resolution ('size' or 'width,height')")
-  parser.add_argument("--train_batch_size", type=int, default=1,
-                      help="batch size for training (a batch element is a single train or reg image, not a train/reg pair)")
-  parser.add_argument("--use_8bit_adam", action="store_true",
-                      help="use the 8bit Adam optimizer (requires bitsandbytes)")
-  parser.add_argument("--mem_eff_attn", action="store_true",
-                      help="use memory-efficient attention for CrossAttention")
-  parser.add_argument("--xformers", action="store_true",
-                      help="use xformers for CrossAttention")
-  parser.add_argument("--cache_latents", action="store_true",
-                      help="cache latents to reduce memory usage (augmentations cannot be used)")
-  parser.add_argument("--enable_bucket", action="store_true",
-                      help="enable buckets for multi-aspect-ratio training")
-  parser.add_argument("--learning_rate", type=float, default=2.0e-6, help="learning rate")
-  parser.add_argument("--max_train_steps", type=int, default=1600, help="number of training steps")
-  parser.add_argument("--seed", type=int, default=None, help="random seed for training")
-  parser.add_argument("--gradient_checkpointing", action="store_true",
-                      help="enable gradient checkpointing")
-  parser.add_argument("--mixed_precision", type=str, default="no",
-                      choices=["no", "fp16", "bf16"], help="whether to use mixed precision, and with which dtype")
-  parser.add_argument("--save_precision", type=str, default=None,
-                      choices=[None, "float", "fp16", "bf16"], help="convert the model to this precision when saving")
-  parser.add_argument("--clip_skip", type=int, default=None,
-                      help="use the output of the nth layer from the end of the text encoder (n>=1)")
-  parser.add_argument("--logging_dir", type=str, default=None,
-                      help="enable logging and write TensorBoard logs to this directory")
-
- args = parser.parse_args()
- train(args)
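-
-# Example invocation (illustrative only; the script filename, paths and
-# hyperparameter values below are placeholders, not taken from this repository):
-#
-#   python train_script.py \
-#     --pretrained_model_name_or_path=model.ckpt \
-#     --train_data_dir=train_images --reg_data_dir=reg_images \
-#     --output_dir=output --resolution=512 --train_batch_size=1 \
-#     --max_train_steps=1600 --use_8bit_adam --xformers \
-#     --cache_latents --mixed_precision=fp16
-#
-# If the training loop is built on HuggingFace Accelerate (as --mixed_precision
-# and --save_state suggest), launching with `accelerate launch` instead of
-# `python` may be preferable.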