diff --git a/README.md b/README.md
index 1dbbbd0..3b161a9 100644
--- a/README.md
+++ b/README.md
@@ -36,16 +36,17 @@ As a result, Sana-0.6B is very competitive with modern giant diffusion model (e.
 ## πŸ”₯πŸ”₯ News
+- (πŸ”₯ New) \[2024/12/20\] 1.6B 2K resolution [Sana models](https://huggingface.co/collections/Efficient-Large-Model/sana-673efba2a57ed99843f11f9e) are released: [\[BF16 pth\]](https://huggingface.co/Efficient-Large-Model/Sana_1600M_2Kpx_BF16) or [\[BF16 diffusers\]](https://huggingface.co/Efficient-Large-Model/Sana_1600M_2Kpx_BF16_diffusers). πŸš€ Get your 2K resolution images within 4 seconds! Find more samples on the [Sana page](https://nvlabs.github.io/Sana/).
 - (πŸ”₯ New) \[2024/12/18\] `diffusers` supports Sana-LoRA fine-tuning! Sana-LoRA's training and convergence speed is super fast. [\[Guidance\]](asset/docs/sana_lora_dreambooth.md) or [\[diffusers docs\]](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/README_sana.md).
 - (πŸ”₯ New) \[2024/12/13\] `diffusers` has Sana! [All Sana models in diffusers safetensors](https://huggingface.co/collections/Efficient-Large-Model/sana-673efba2a57ed99843f11f9e) are released, and `SanaPipeline`, `SanaPAGPipeline`, and `DPMSolverMultistepScheduler` (with FlowMatching) are all supported now. We prepared a [Model Card](asset/docs/model_zoo.md) for you to choose from.
 - (πŸ”₯ New) \[2024/12/10\] 1.6B BF16 [Sana model](https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px_BF16) is released for stable fine-tuning.
 - (πŸ”₯ New) \[2024/12/9\] We release the [ComfyUI node](https://github.com/Efficient-Large-Model/ComfyUI_ExtraModels) for Sana. [\[Guidance\]](asset/docs/ComfyUI/comfyui.md)
-- (πŸ”₯ New) \[2024/11\] All multi-linguistic (Emoji & Chinese & English) SFT models are released: [1.6B-512px](https://huggingface.co/Efficient-Large-Model/Sana_1600M_512px_MultiLing), [1.6B-1024px](https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px_MultiLing), [600M-512px](https://huggingface.co/Efficient-Large-Model/Sana_600M_512px), [600M-1024px](https://huggingface.co/Efficient-Large-Model/Sana_600M_1024px). The metric performance is shown [here](#performance)
-- (πŸ”₯ New) \[2024/11\] Sana Replicate API is launching at [Sana-API](https://replicate.com/chenxwh/sana).
-- (πŸ”₯ New) \[2024/11\] Sana code-base license changed to Apache 2.0.
-- (πŸ”₯ New) \[2024/11\] 1.6B [Sana models](https://huggingface.co/collections/Efficient-Large-Model/sana-673efba2a57ed99843f11f9e) are released.
-- (πŸ”₯ New) \[2024/11\] Training & Inference & Metrics code are released.
-- (πŸ”₯ New) \[2024/11\] Working on [`diffusers`](https://github.com/huggingface/diffusers/pull/9982).
+- βœ… \[2024/11\] All multi-linguistic (Emoji & Chinese & English) SFT models are released: [1.6B-512px](https://huggingface.co/Efficient-Large-Model/Sana_1600M_512px_MultiLing), [1.6B-1024px](https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px_MultiLing), [600M-512px](https://huggingface.co/Efficient-Large-Model/Sana_600M_512px), [600M-1024px](https://huggingface.co/Efficient-Large-Model/Sana_600M_1024px). The metric performance is shown [here](#performance).
+- βœ… \[2024/11\] Sana Replicate API is available at [Sana-API](https://replicate.com/chenxwh/sana).
+- βœ… \[2024/11\] Sana code-base license changed to Apache 2.0.
+- βœ… \[2024/11\] 1.6B [Sana models](https://huggingface.co/collections/Efficient-Large-Model/sana-673efba2a57ed99843f11f9e) are released.
+- βœ… \[2024/11\] Training & Inference & Metrics code are released.
+- βœ… \[2024/11\] Working on [`diffusers`](https://github.com/huggingface/diffusers/pull/9982).
 - \[2024/10\] [Demo](https://nv-sana.mit.edu/) is released.
 - \[2024/10\] [DC-AE Code](https://github.com/mit-han-lab/efficientvit/blob/master/applications/dc_ae/README.md) and [weights](https://huggingface.co/collections/mit-han-lab/dc-ae-670085b9400ad7197bb1009b) are released!
 - \[2024/10\] [Paper](https://arxiv.org/abs/2410.10629) is on Arxiv!
@@ -314,18 +315,19 @@ Refer to [Toolkit Manual](asset/docs/metrics_toolkit.md).
 We will try our best to release
-- \[x\] Training code
-- \[x\] Inference code
-- \[x\] Model zoo
-- \[x\] ComfyUI
-- \[x\] DC-AE Diffusers
-- \[x\] Sana merged in Diffusers(https://github.com/huggingface/diffusers/pull/9982)
-- \[x\] LoRA training by [@paul](https://github.com/sayakpaul)(`diffusers`: https://github.com/huggingface/diffusers/pull/10234)
-- \[ \] ControlNet (train & inference & models)
-- \[ \] 8bit / 4bit Laptop development
-- \[ \] Larger model size
-- \[ \] Better re-construction F32/F64 VAEs.
-- \[ \] **Sana1.5 (Focus on: Human body / Human face / Text rendering / Realism / Efficiency)**
+- \[βœ…\] Training code
+- \[βœ…\] Inference code
+- \[βœ…\] Model zoo
+- \[βœ…\] ComfyUI
+- \[βœ…\] DC-AE Diffusers
+- \[βœ…\] Sana merged in Diffusers (https://github.com/huggingface/diffusers/pull/9982)
+- \[βœ…\] LoRA training by [@paul](https://github.com/sayakpaul) (`diffusers`: https://github.com/huggingface/diffusers/pull/10234)
+- \[βœ…\] 2K resolution models
+- \[πŸ’»\] ControlNet (train & inference & models)
+- \[πŸ’»\] 8bit / 4bit Laptop development
+- \[πŸ’»\] Larger model size
+- \[πŸ’»\] Better reconstruction F32/F64 VAEs.
+- \[πŸ’»\] **Sana1.5 (Focus on: Human body / Human face / Text rendering / Realism / Efficiency)**
 # πŸ€—Acknowledgements
diff --git a/app/app_sana.py b/app/app_sana.py
index 8c9c30b..df69aca 100755
--- a/app/app_sana.py
+++ b/app/app_sana.py
@@ -386,14 +386,14 @@ def generate(
                 minimum=256,
                 maximum=MAX_IMAGE_SIZE,
                 step=32,
-                value=1024,
+                value=args.image_size,
             )
             width = gr.Slider(
                 label="Width",
                 minimum=256,
                 maximum=MAX_IMAGE_SIZE,
                 step=32,
-                value=1024,
+                value=args.image_size,
             )
         with gr.Row():
             flow_dpms_inference_steps = gr.Slider(
@@ -401,7 +401,7 @@
                 minimum=5,
                 maximum=40,
                 step=1,
-                value=18,
+                value=20,
             )
             flow_dpms_guidance_scale = gr.Slider(
                 label="CFG Guidance scale",
diff --git a/asset/docs/model_zoo.md b/asset/docs/model_zoo.md
index 7368c1e..91e86ba 100644
--- a/asset/docs/model_zoo.md
+++ b/asset/docs/model_zoo.md
@@ -9,6 +9,7 @@
 | Sana-1.6B | 1024px | [Sana_1600M_1024px](https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px) | [Efficient-Large-Model/Sana_1600M_1024px_diffusers](https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px_diffusers) | fp16/fp32 | - |
 | Sana-1.6B | 1024px | [Sana_1600M_1024px_MultiLing](https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px_MultiLing) | [Efficient-Large-Model/Sana_1600M_1024px_MultiLing_diffusers](https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px_MultiLing_diffusers) | fp16/fp32 | Multi-Language |
 | Sana-1.6B | 1024px | [Sana_1600M_1024px_BF16](https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px_BF16) | [Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers](https://huggingface.co/Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers) | **bf16**/fp32 | Multi-Language |
+| Sana-1.6B | 2Kpx | [Sana_1600M_2Kpx_BF16](https://huggingface.co/Efficient-Large-Model/Sana_1600M_2Kpx_BF16) |
[Efficient-Large-Model/Sana_1600M_2Kpx_BF16_diffusers](https://huggingface.co/Efficient-Large-Model/Sana_1600M_2Kpx_BF16_diffusers) | **bf16**/fp32 | Multi-Language | ## ❗ 2. Make sure to use correct precision(fp16/bf16/fp32) for training and inference. diff --git a/tools/convert_sana_pag_to_diffusers.py b/tools/convert_sana_pag_to_diffusers.py deleted file mode 100644 index cf550a1..0000000 --- a/tools/convert_sana_pag_to_diffusers.py +++ /dev/null @@ -1,301 +0,0 @@ -#!/usr/bin/env python -from __future__ import annotations - -import argparse -import os -from contextlib import nullcontext - -import torch -from accelerate import init_empty_weights -from diffusers import ( - DCAE, - DCAE_HF, - FlowDPMSolverMultistepScheduler, - FlowMatchEulerDiscreteScheduler, - SanaPAGPipeline, - SanaTransformer2DModel, -) -from diffusers.models.modeling_utils import load_model_dict_into_meta -from diffusers.utils.import_utils import is_accelerate_available -from termcolor import colored -from transformers import AutoModelForCausalLM, AutoTokenizer - -CTX = init_empty_weights if is_accelerate_available else nullcontext - -ckpt_id = "Sana" -# https://github.com/NVlabs/Sana/blob/main/scripts/inference.py - - -def main(args): - all_state_dict = torch.load(args.orig_ckpt_path, map_location=torch.device("cpu")) - state_dict = all_state_dict.pop("state_dict") - converted_state_dict = {} - - # Patch embeddings. - converted_state_dict["pos_embed.proj.weight"] = state_dict.pop("x_embedder.proj.weight") - converted_state_dict["pos_embed.proj.bias"] = state_dict.pop("x_embedder.proj.bias") - - # Caption projection. - converted_state_dict["caption_projection.linear_1.weight"] = state_dict.pop("y_embedder.y_proj.fc1.weight") - converted_state_dict["caption_projection.linear_1.bias"] = state_dict.pop("y_embedder.y_proj.fc1.bias") - converted_state_dict["caption_projection.linear_2.weight"] = state_dict.pop("y_embedder.y_proj.fc2.weight") - converted_state_dict["caption_projection.linear_2.bias"] = state_dict.pop("y_embedder.y_proj.fc2.bias") - - # AdaLN-single LN - converted_state_dict["adaln_single.emb.timestep_embedder.linear_1.weight"] = state_dict.pop( - "t_embedder.mlp.0.weight" - ) - converted_state_dict["adaln_single.emb.timestep_embedder.linear_1.bias"] = state_dict.pop("t_embedder.mlp.0.bias") - converted_state_dict["adaln_single.emb.timestep_embedder.linear_2.weight"] = state_dict.pop( - "t_embedder.mlp.2.weight" - ) - converted_state_dict["adaln_single.emb.timestep_embedder.linear_2.bias"] = state_dict.pop("t_embedder.mlp.2.bias") - - # Shared norm. - converted_state_dict["adaln_single.linear.weight"] = state_dict.pop("t_block.1.weight") - converted_state_dict["adaln_single.linear.bias"] = state_dict.pop("t_block.1.bias") - - # y norm - converted_state_dict["caption_norm.weight"] = state_dict.pop("attention_y_norm.weight") - - if args.model_type == "SanaMS_1600M_P1_D20": - layer_num = 20 - flow_shift = 3.0 - elif args.model_type == "SanaMS_600M_P1_D28": - layer_num = 28 - flow_shift = 4.0 - else: - raise ValueError(f"{args.model_type} is not supported.") - - for depth in range(layer_num): - # Transformer blocks. - converted_state_dict[f"transformer_blocks.{depth}.scale_shift_table"] = state_dict.pop( - f"blocks.{depth}.scale_shift_table" - ) - # Linear Attention is all you need 🀘 - - # Self attention. 
- q, k, v = torch.chunk(state_dict.pop(f"blocks.{depth}.attn.qkv.weight"), 3, dim=0) - converted_state_dict[f"transformer_blocks.{depth}.attn1.to_q.weight"] = q - converted_state_dict[f"transformer_blocks.{depth}.attn1.to_k.weight"] = k - converted_state_dict[f"transformer_blocks.{depth}.attn1.to_v.weight"] = v - # Projection. - converted_state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.weight"] = state_dict.pop( - f"blocks.{depth}.attn.proj.weight" - ) - converted_state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.bias"] = state_dict.pop( - f"blocks.{depth}.attn.proj.bias" - ) - - # Feed-forward. - converted_state_dict[f"transformer_blocks.{depth}.ff.inverted_conv.conv.weight"] = state_dict.pop( - f"blocks.{depth}.mlp.inverted_conv.conv.weight" - ) - converted_state_dict[f"transformer_blocks.{depth}.ff.inverted_conv.conv.bias"] = state_dict.pop( - f"blocks.{depth}.mlp.inverted_conv.conv.bias" - ) - converted_state_dict[f"transformer_blocks.{depth}.ff.depth_conv.conv.weight"] = state_dict.pop( - f"blocks.{depth}.mlp.depth_conv.conv.weight" - ) - converted_state_dict[f"transformer_blocks.{depth}.ff.depth_conv.conv.bias"] = state_dict.pop( - f"blocks.{depth}.mlp.depth_conv.conv.bias" - ) - converted_state_dict[f"transformer_blocks.{depth}.ff.point_conv.conv.weight"] = state_dict.pop( - f"blocks.{depth}.mlp.point_conv.conv.weight" - ) - - # Cross-attention. - q = state_dict.pop(f"blocks.{depth}.cross_attn.q_linear.weight") - q_bias = state_dict.pop(f"blocks.{depth}.cross_attn.q_linear.bias") - k, v = torch.chunk(state_dict.pop(f"blocks.{depth}.cross_attn.kv_linear.weight"), 2, dim=0) - k_bias, v_bias = torch.chunk(state_dict.pop(f"blocks.{depth}.cross_attn.kv_linear.bias"), 2, dim=0) - - converted_state_dict[f"transformer_blocks.{depth}.attn2.to_q.weight"] = q - converted_state_dict[f"transformer_blocks.{depth}.attn2.to_q.bias"] = q_bias - converted_state_dict[f"transformer_blocks.{depth}.attn2.to_k.weight"] = k - converted_state_dict[f"transformer_blocks.{depth}.attn2.to_k.bias"] = k_bias - converted_state_dict[f"transformer_blocks.{depth}.attn2.to_v.weight"] = v - converted_state_dict[f"transformer_blocks.{depth}.attn2.to_v.bias"] = v_bias - - converted_state_dict[f"transformer_blocks.{depth}.attn2.to_out.0.weight"] = state_dict.pop( - f"blocks.{depth}.cross_attn.proj.weight" - ) - converted_state_dict[f"transformer_blocks.{depth}.attn2.to_out.0.bias"] = state_dict.pop( - f"blocks.{depth}.cross_attn.proj.bias" - ) - - # Final block. 
- converted_state_dict["proj_out.weight"] = state_dict.pop("final_layer.linear.weight") - converted_state_dict["proj_out.bias"] = state_dict.pop("final_layer.linear.bias") - converted_state_dict["scale_shift_table"] = state_dict.pop("final_layer.scale_shift_table") - - # Transformer - with CTX(): - transformer = SanaTransformer2DModel( - num_attention_heads=model_kwargs[args.model_type]["num_attention_heads"], - attention_head_dim=model_kwargs[args.model_type]["attention_head_dim"], - num_cross_attention_heads=model_kwargs[args.model_type]["num_cross_attention_heads"], - cross_attention_head_dim=model_kwargs[args.model_type]["cross_attention_head_dim"], - in_channels=32, - out_channels=32, - num_layers=model_kwargs[args.model_type]["num_layers"], - cross_attention_dim=model_kwargs[args.model_type]["cross_attention_dim"], - attention_bias=False, - sample_size=32, - patch_size=1, - activation_fn=("silu", "silu", None), - upcast_attention=False, - norm_type="ada_norm_single", - norm_elementwise_affine=False, - norm_eps=1e-6, - use_additional_conditions=False, - caption_channels=2304, - use_caption_norm=True, - caption_norm_scale_factor=0.1, - attention_type="default", - use_pe=False, - expand_ratio=2.5, - ff_bias=(True, True, False), - ff_norm=(None, None, None), - ) - if is_accelerate_available(): - load_model_dict_into_meta(transformer, converted_state_dict) - else: - transformer.load_state_dict(converted_state_dict, strict=True) - - try: - state_dict.pop("y_embedder.y_embedding") - state_dict.pop("pos_embed") - except: - pass - assert len(state_dict) == 0, f"State dict is not empty, {state_dict.keys()}" - - num_model_params = sum(p.numel() for p in transformer.parameters()) - print(f"Total number of transformer parameters: {num_model_params}") - - if not args.save_full_pipeline: - print( - colored( - f"Only saving transformer model of {args.model_type}. 
" - f"Set --save_full_pipeline to save the whole SanaPipeline", - "green", - attrs=["bold"], - ) - ) - transformer.to(weight_dtype).save_pretrained(os.path.join(args.dump_path, "transformer")) - else: - print(colored(f"Saving the whole SanaPAGPipeline containing {args.model_type}", "green", attrs=["bold"])) - # VAE - dc_ae = DCAE_HF.from_pretrained(f"mit-han-lab/dc-ae-f32c32-sana-1.0") - dc_ae_state_dict = dc_ae.state_dict() - dc_ae = DCAE( - in_channels=3, - latent_channels=32, - encoder_width_list=[128, 256, 512, 512, 1024, 1024], - encoder_depth_list=[2, 2, 2, 3, 3, 3], - encoder_block_type=["ResBlock", "ResBlock", "ResBlock", "EViTS5_GLU", "EViTS5_GLU", "EViTS5_GLU"], - encoder_norm="rms2d", - encoder_act="silu", - downsample_block_type="Conv", - decoder_width_list=[128, 256, 512, 512, 1024, 1024], - decoder_depth_list=[3, 3, 3, 3, 3, 3], - decoder_block_type=["ResBlock", "ResBlock", "ResBlock", "EViTS5_GLU", "EViTS5_GLU", "EViTS5_GLU"], - decoder_norm="rms2d", - decoder_act="silu", - upsample_block_type="InterpolateConv", - scaling_factor=0.41407, - ) - dc_ae.load_state_dict(dc_ae_state_dict, strict=True) - dc_ae.to(torch.float32).to(device) - - # Text Encoder - text_encoder_model_path = "google/gemma-2-2b-it" - tokenizer = AutoTokenizer.from_pretrained(text_encoder_model_path) - tokenizer.padding_side = "right" - text_encoder = ( - AutoModelForCausalLM.from_pretrained(text_encoder_model_path, torch_dtype=torch.bfloat16) - .get_decoder() - .to(device) - ) - - # Scheduler - if args.scheduler_type == "flow-dpm_solver": - scheduler = FlowDPMSolverMultistepScheduler(flow_shift=flow_shift) - elif args.scheduler_type == "flow-euler": - scheduler = FlowMatchEulerDiscreteScheduler(shift=flow_shift) - else: - raise ValueError(f"Scheduler type {args.scheduler_type} is not supported") - - # transformer - transformer.to(device).to(weight_dtype) - - pipe = SanaPAGPipeline( - tokenizer=tokenizer, - text_encoder=text_encoder, - transformer=transformer, - vae=dc_ae, - scheduler=scheduler, - pag_applied_layers="blocks.8", - ) - - image = pipe( - "a dog", - height=1024, - width=1024, - guidance_scale=5.0, - pag_scale=2.0, - )[0] - - image[0].save("sana_pag.png") - - pipe.save_pretrained(args.dump_path) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - - parser.add_argument( - "--orig_ckpt_path", default=None, type=str, required=False, help="Path to the checkpoint to convert." 
- ) - parser.add_argument( - "--image_size", - default=1024, - type=int, - choices=[512, 1024], - required=False, - help="Image size of pretrained model, 512 or 1024.", - ) - parser.add_argument( - "--model_type", default="SanaMS_1600M_P1_D20", type=str, choices=["SanaMS_1600M_P1_D20", "SanaMS_600M_P1_D28"] - ) - parser.add_argument( - "--scheduler_type", default="flow-dpm_solver", type=str, choices=["flow-dpm_solver", "flow-euler"] - ) - parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output pipeline.") - parser.add_argument("--save_full_pipeline", action="store_true", help="save all the pipelien elemets in one.") - - args = parser.parse_args() - - model_kwargs = { - "SanaMS_1600M_P1_D20": { - "num_attention_heads": 70, - "attention_head_dim": 32, - "num_cross_attention_heads": 20, - "cross_attention_head_dim": 112, - "cross_attention_dim": 2240, - "num_layers": 20, - }, - "SanaMS_600M_P1_D28": { - "num_attention_heads": 36, - "attention_head_dim": 32, - "num_cross_attention_heads": 16, - "cross_attention_head_dim": 72, - "cross_attention_dim": 1152, - "num_layers": 28, - }, - } - - device = "cuda" if torch.cuda.is_available() else "cpu" - weight_dtype = torch.float16 - - main(args) diff --git a/tools/convert_sana_to_diffusers.py b/tools/convert_sana_to_diffusers.py index 8b93282..0239512 100644 --- a/tools/convert_sana_to_diffusers.py +++ b/tools/convert_sana_to_diffusers.py @@ -8,32 +8,60 @@ import torch from accelerate import init_empty_weights from diffusers import ( - DCAE, - DCAE_HF, - FlowDPMSolverMultistepScheduler, + AutoencoderDC, + DPMSolverMultistepScheduler, FlowMatchEulerDiscreteScheduler, SanaPipeline, SanaTransformer2DModel, ) from diffusers.models.modeling_utils import load_model_dict_into_meta from diffusers.utils.import_utils import is_accelerate_available +from huggingface_hub import hf_hub_download, snapshot_download from termcolor import colored from transformers import AutoModelForCausalLM, AutoTokenizer CTX = init_empty_weights if is_accelerate_available else nullcontext -ckpt_id = "Sana" +ckpt_ids = [ + "Efficient-Large-Model/Sana_1600M_2Kpx_BF16/checkpoints/Sana_1600M_2Kpx_BF16.pth", + "Efficient-Large-Model/Sana_1600M_1024px_MultiLing/checkpoints/Sana_1600M_1024px_MultiLing.pth", + "Efficient-Large-Model/Sana_1600M_1024px_BF16/checkpoints/Sana_1600M_1024px_BF16.pth", + "Efficient-Large-Model/Sana_1600M_512px_MultiLing/checkpoints/Sana_1600M_512px_MultiLing.pth", + "Efficient-Large-Model/Sana_1600M_1024px/checkpoints/Sana_1600M_1024px.pth", + "Efficient-Large-Model/Sana_1600M_512px/checkpoints/Sana_1600M_512px.pth", + "Efficient-Large-Model/Sana_600M_1024px/checkpoints/Sana_600M_1024px_MultiLing.pth", + "Efficient-Large-Model/Sana_600M_512px/checkpoints/Sana_600M_512px_MultiLing.pth", +] # https://github.com/NVlabs/Sana/blob/main/scripts/inference.py def main(args): - all_state_dict = torch.load(args.orig_ckpt_path, map_location=torch.device("cpu")) + cache_dir_path = os.path.expanduser("~/.cache/huggingface/hub") + + if args.orig_ckpt_path is None or args.orig_ckpt_path in ckpt_ids: + ckpt_id = args.orig_ckpt_path or ckpt_ids[0] + snapshot_download( + repo_id=f"{'/'.join(ckpt_id.split('/')[:2])}", + cache_dir=cache_dir_path, + repo_type="model", + ) + file_path = hf_hub_download( + repo_id=f"{'/'.join(ckpt_id.split('/')[:2])}", + filename=f"{'/'.join(ckpt_id.split('/')[2:])}", + cache_dir=cache_dir_path, + repo_type="model", + ) + else: + file_path = args.orig_ckpt_path + + print(colored(f"Loading checkpoint from 
{file_path}", "green", attrs=["bold"])) + all_state_dict = torch.load(file_path, weights_only=True) state_dict = all_state_dict.pop("state_dict") converted_state_dict = {} # Patch embeddings. - converted_state_dict["pos_embed.proj.weight"] = state_dict.pop("x_embedder.proj.weight") - converted_state_dict["pos_embed.proj.bias"] = state_dict.pop("x_embedder.proj.bias") + converted_state_dict["patch_embed.proj.weight"] = state_dict.pop("x_embedder.proj.weight") + converted_state_dict["patch_embed.proj.bias"] = state_dict.pop("x_embedder.proj.bias") # Caption projection. converted_state_dict["caption_projection.linear_1.weight"] = state_dict.pop("y_embedder.y_proj.fc1.weight") @@ -42,28 +70,23 @@ def main(args): converted_state_dict["caption_projection.linear_2.bias"] = state_dict.pop("y_embedder.y_proj.fc2.bias") # AdaLN-single LN - converted_state_dict["adaln_single.emb.timestep_embedder.linear_1.weight"] = state_dict.pop( - "t_embedder.mlp.0.weight" - ) - converted_state_dict["adaln_single.emb.timestep_embedder.linear_1.bias"] = state_dict.pop("t_embedder.mlp.0.bias") - converted_state_dict["adaln_single.emb.timestep_embedder.linear_2.weight"] = state_dict.pop( - "t_embedder.mlp.2.weight" - ) - converted_state_dict["adaln_single.emb.timestep_embedder.linear_2.bias"] = state_dict.pop("t_embedder.mlp.2.bias") + converted_state_dict["time_embed.emb.timestep_embedder.linear_1.weight"] = state_dict.pop("t_embedder.mlp.0.weight") + converted_state_dict["time_embed.emb.timestep_embedder.linear_1.bias"] = state_dict.pop("t_embedder.mlp.0.bias") + converted_state_dict["time_embed.emb.timestep_embedder.linear_2.weight"] = state_dict.pop("t_embedder.mlp.2.weight") + converted_state_dict["time_embed.emb.timestep_embedder.linear_2.bias"] = state_dict.pop("t_embedder.mlp.2.bias") # Shared norm. - converted_state_dict["adaln_single.linear.weight"] = state_dict.pop("t_block.1.weight") - converted_state_dict["adaln_single.linear.bias"] = state_dict.pop("t_block.1.bias") + converted_state_dict["time_embed.linear.weight"] = state_dict.pop("t_block.1.weight") + converted_state_dict["time_embed.linear.bias"] = state_dict.pop("t_block.1.bias") # y norm converted_state_dict["caption_norm.weight"] = state_dict.pop("attention_y_norm.weight") + flow_shift = 3.0 if args.model_type == "SanaMS_1600M_P1_D20": layer_num = 20 - flow_shift = 3.0 elif args.model_type == "SanaMS_600M_P1_D28": layer_num = 28 - flow_shift = 4.0 else: raise ValueError(f"{args.model_type} is not supported.") @@ -72,8 +95,8 @@ def main(args): converted_state_dict[f"transformer_blocks.{depth}.scale_shift_table"] = state_dict.pop( f"blocks.{depth}.scale_shift_table" ) - # Linear Attention is all you need 🀘 + # Linear Attention is all you need 🀘 # Self attention. q, k, v = torch.chunk(state_dict.pop(f"blocks.{depth}.attn.qkv.weight"), 3, dim=0) converted_state_dict[f"transformer_blocks.{depth}.attn1.to_q.weight"] = q @@ -88,19 +111,19 @@ def main(args): ) # Feed-forward. 
- converted_state_dict[f"transformer_blocks.{depth}.ff.inverted_conv.conv.weight"] = state_dict.pop( + converted_state_dict[f"transformer_blocks.{depth}.ff.conv_inverted.weight"] = state_dict.pop( f"blocks.{depth}.mlp.inverted_conv.conv.weight" ) - converted_state_dict[f"transformer_blocks.{depth}.ff.inverted_conv.conv.bias"] = state_dict.pop( + converted_state_dict[f"transformer_blocks.{depth}.ff.conv_inverted.bias"] = state_dict.pop( f"blocks.{depth}.mlp.inverted_conv.conv.bias" ) - converted_state_dict[f"transformer_blocks.{depth}.ff.depth_conv.conv.weight"] = state_dict.pop( + converted_state_dict[f"transformer_blocks.{depth}.ff.conv_depth.weight"] = state_dict.pop( f"blocks.{depth}.mlp.depth_conv.conv.weight" ) - converted_state_dict[f"transformer_blocks.{depth}.ff.depth_conv.conv.bias"] = state_dict.pop( + converted_state_dict[f"transformer_blocks.{depth}.ff.conv_depth.bias"] = state_dict.pop( f"blocks.{depth}.mlp.depth_conv.conv.bias" ) - converted_state_dict[f"transformer_blocks.{depth}.ff.point_conv.conv.weight"] = state_dict.pop( + converted_state_dict[f"transformer_blocks.{depth}.ff.conv_point.weight"] = state_dict.pop( f"blocks.{depth}.mlp.point_conv.conv.weight" ) @@ -132,47 +155,41 @@ def main(args): # Transformer with CTX(): transformer = SanaTransformer2DModel( + in_channels=32, + out_channels=32, num_attention_heads=model_kwargs[args.model_type]["num_attention_heads"], attention_head_dim=model_kwargs[args.model_type]["attention_head_dim"], + num_layers=model_kwargs[args.model_type]["num_layers"], num_cross_attention_heads=model_kwargs[args.model_type]["num_cross_attention_heads"], cross_attention_head_dim=model_kwargs[args.model_type]["cross_attention_head_dim"], - in_channels=32, - out_channels=32, - num_layers=model_kwargs[args.model_type]["num_layers"], cross_attention_dim=model_kwargs[args.model_type]["cross_attention_dim"], + caption_channels=2304, + mlp_ratio=2.5, attention_bias=False, - sample_size=32, + sample_size=args.image_size // 32, patch_size=1, - activation_fn=("silu", "silu", None), - upcast_attention=False, - norm_type="ada_norm_single", norm_elementwise_affine=False, norm_eps=1e-6, - use_additional_conditions=False, - caption_channels=2304, - use_caption_norm=True, - caption_norm_scale_factor=0.1, - attention_type="default", - use_pe=False, - expand_ratio=2.5, - ff_bias=(True, True, False), - ff_norm=(None, None, None), ) + if is_accelerate_available(): load_model_dict_into_meta(transformer, converted_state_dict) else: - transformer.load_state_dict(converted_state_dict, strict=True) + transformer.load_state_dict(converted_state_dict, strict=True, assign=True) try: state_dict.pop("y_embedder.y_embedding") state_dict.pop("pos_embed") - except: - pass + except KeyError: + print("y_embedder.y_embedding or pos_embed not found in the state_dict") + assert len(state_dict) == 0, f"State dict is not empty, {state_dict.keys()}" num_model_params = sum(p.numel() for p in transformer.parameters()) print(f"Total number of transformer parameters: {num_model_params}") + transformer = transformer.to(weight_dtype) + if not args.save_full_pipeline: print( colored( @@ -182,71 +199,55 @@ def main(args): attrs=["bold"], ) ) - transformer.to(weight_dtype).save_pretrained(os.path.join(args.dump_path, "transformer")) + transformer.save_pretrained( + os.path.join(args.dump_path, "transformer"), safe_serialization=True, max_shard_size="5GB", variant=variant + ) else: print(colored(f"Saving the whole SanaPipeline containing {args.model_type}", "green", attrs=["bold"])) # VAE - dc_ae 
= DCAE_HF.from_pretrained(f"mit-han-lab/dc-ae-f32c32-sana-1.0") - dc_ae_state_dict = dc_ae.state_dict() - dc_ae = DCAE( - in_channels=3, - latent_channels=32, - encoder_width_list=[128, 256, 512, 512, 1024, 1024], - encoder_depth_list=[2, 2, 2, 3, 3, 3], - encoder_block_type=["ResBlock", "ResBlock", "ResBlock", "EViTS5_GLU", "EViTS5_GLU", "EViTS5_GLU"], - encoder_norm="rms2d", - encoder_act="silu", - downsample_block_type="Conv", - decoder_width_list=[128, 256, 512, 512, 1024, 1024], - decoder_depth_list=[3, 3, 3, 3, 3, 3], - decoder_block_type=["ResBlock", "ResBlock", "ResBlock", "EViTS5_GLU", "EViTS5_GLU", "EViTS5_GLU"], - decoder_norm="rms2d", - decoder_act="silu", - upsample_block_type="InterpolateConv", - scaling_factor=0.41407, - ) - dc_ae.load_state_dict(dc_ae_state_dict, strict=True) - dc_ae.to(torch.float32).to(device) + ae = AutoencoderDC.from_pretrained("mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers", torch_dtype=torch.float32) # Text Encoder text_encoder_model_path = "google/gemma-2-2b-it" tokenizer = AutoTokenizer.from_pretrained(text_encoder_model_path) tokenizer.padding_side = "right" - text_encoder = ( - AutoModelForCausalLM.from_pretrained(text_encoder_model_path, torch_dtype=torch.bfloat16) - .get_decoder() - .to(device) - ) + text_encoder = AutoModelForCausalLM.from_pretrained( + text_encoder_model_path, torch_dtype=torch.bfloat16 + ).get_decoder() # Scheduler if args.scheduler_type == "flow-dpm_solver": - scheduler = FlowDPMSolverMultistepScheduler(flow_shift=flow_shift) + scheduler = DPMSolverMultistepScheduler( + flow_shift=flow_shift, + use_flow_sigmas=True, + prediction_type="flow_prediction", + ) elif args.scheduler_type == "flow-euler": scheduler = FlowMatchEulerDiscreteScheduler(shift=flow_shift) else: raise ValueError(f"Scheduler type {args.scheduler_type} is not supported") - # transformer - transformer.to(device).to(weight_dtype) - pipe = SanaPipeline( tokenizer=tokenizer, text_encoder=text_encoder, transformer=transformer, - vae=dc_ae, + vae=ae, scheduler=scheduler, ) + pipe.save_pretrained(args.dump_path, safe_serialization=True, max_shard_size="5GB", variant=variant) - image = pipe( - "a dog", - height=1024, - width=1024, - guidance_scale=5.0, - )[0] - image[0].save("sana.png") +DTYPE_MAPPING = { + "fp32": torch.float32, + "fp16": torch.float16, + "bf16": torch.bfloat16, +} - pipe.save_pretrained(args.dump_path) +VARIANT_MAPPING = { + "fp32": None, + "fp16": "fp16", + "bf16": "bf16", +} if __name__ == "__main__": @@ -259,9 +260,9 @@ def main(args): "--image_size", default=1024, type=int, - choices=[512, 1024], + choices=[512, 1024, 2048], required=False, - help="Image size of pretrained model, 512 or 1024.", + help="Image size of pretrained model, 512 or 1024 or 2048.", ) parser.add_argument( "--model_type", default="SanaMS_1600M_P1_D20", type=str, choices=["SanaMS_1600M_P1_D20", "SanaMS_600M_P1_D28"] @@ -271,6 +272,7 @@ def main(args): ) parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output pipeline.") parser.add_argument("--save_full_pipeline", action="store_true", help="save all the pipelien elemets in one.") + parser.add_argument("--dtype", default="fp32", type=str, choices=["fp32", "fp16", "bf16"], help="Weight dtype.") args = parser.parse_args() @@ -294,6 +296,7 @@ def main(args): } device = "cuda" if torch.cuda.is_available() else "cpu" - weight_dtype = torch.float16 + weight_dtype = DTYPE_MAPPING[args.dtype] + variant = VARIANT_MAPPING[args.dtype] main(args)
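Note: the refactored `tools/convert_sana_to_diffusers.py` no longer runs the in-script "a dog" sanity generation before `save_pretrained`. For a quick manual check of a converted checkpoint, something along these lines should work — a minimal sketch, assuming the full pipeline was exported with `--save_full_pipeline --dtype bf16` to a hypothetical `output/Sana_1600M_1024px_diffusers` directory; the prompt, resolution, and guidance scale mirror the removed in-script check, and the 20 inference steps match the Gradio app's new default.

```python
# Minimal smoke test for a converted Sana pipeline (sketch, not part of the converter).
# "output/Sana_1600M_1024px_diffusers" is a hypothetical --dump_path; adjust to yours.
import torch
from diffusers import SanaPipeline

pipe = SanaPipeline.from_pretrained(
    "output/Sana_1600M_1024px_diffusers",
    variant="bf16",              # only if the pipeline was saved with --dtype bf16
    torch_dtype=torch.bfloat16,  # drop both for the default fp32 export
)
pipe.to("cuda")

image = pipe(
    prompt="a dog",              # same prompt/size/guidance as the removed sanity check
    height=1024,
    width=1024,
    guidance_scale=5.0,
    num_inference_steps=20,      # matches the Gradio app's new default
).images[0]
image.save("sana_smoke_test.png")
```

Loading one of the published `*_diffusers` repos from the model zoo table (e.g. `Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers`) works the same way.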