# Variants to try, in order, when the requested HF variant is unavailable.
HF_VARIANT_FALLBACKS = [None, "fp16"]


class PipelineVersionEnum(Enum):
    SD = "SD"
    SDXL = "SDXL"


def load_pipeline(
    logger: logging.Logger,
    model_name_or_path: str,
    pipeline_version: PipelineVersionEnum,
    variant: str | None = None,
) -> typing.Union[StableDiffusionPipeline, StableDiffusionXLPipeline]:
    """Load a Stable Diffusion pipeline from disk.

    Args:
        logger: Logger used to report variant-fallback attempts and failures.
        model_name_or_path: Either a HF Hub model name / diffusers directory, or a
            path to a single checkpoint file on disk.
        pipeline_version: Selects the pipeline class (SD or SDXL).
        variant: Preferred HF variant (e.g. "fp16"). If loading with this variant
            fails because the variant's files are missing, the entries of
            HF_VARIANT_FALLBACKS are tried in order.

    Returns:
        The loaded pipeline.

    Raises:
        ValueError: If `pipeline_version` is not a supported value.
        RuntimeError: If the pipeline could not be loaded with any variant.
    """
    # NOTE(review): this class selection is reconstructed from hunk context not
    # shown in the patch — confirm it matches the original unchanged lines.
    if pipeline_version == PipelineVersionEnum.SD:
        pipeline_class = StableDiffusionPipeline
    elif pipeline_version == PipelineVersionEnum.SDXL:
        pipeline_class = StableDiffusionXLPipeline
    else:
        raise ValueError(f"Unsupported pipeline_version: '{pipeline_version}'.")

    # A single file is treated as a standalone checkpoint; variants do not apply.
    if os.path.isfile(model_name_or_path):
        return pipeline_class.from_single_file(model_name_or_path, load_safety_checker=False)

    # Try the requested variant first, then the fallbacks (deduplicated,
    # order-preserving).
    variants_to_try = [variant] + [v for v in HF_VARIANT_FALLBACKS if v != variant]

    for variant_to_try in variants_to_try:
        if variant_to_try != variant:
            logger.warning("Trying fallback variant '%s'.", variant_to_try)
        try:
            # Returning directly avoids the pipeline-sentinel/break dance.
            return pipeline_class.from_pretrained(
                model_name_or_path,
                safety_checker=None,
                variant=variant_to_try,
                requires_safety_checker=False,
            )
        except OSError as e:
            # diffusers raises OSError with "no file named ..." when the variant's
            # weight files are absent; only that case is retried with a fallback.
            if "no file named" not in str(e):
                raise
            logger.warning(
                "Failed to load pipeline '%s' with variant '%s'. Error: %s.",
                model_name_or_path,
                variant_to_try,
                e,
            )

    raise RuntimeError(f"Failed to load pipeline '{model_name_or_path}'.")
a/src/invoke_training/pipelines/stable_diffusion/lora/train.py b/src/invoke_training/pipelines/stable_diffusion/lora/train.py index 4e154cb2..9d44bd7b 100644 --- a/src/invoke_training/pipelines/stable_diffusion/lora/train.py +++ b/src/invoke_training/pipelines/stable_diffusion/lora/train.py @@ -292,6 +292,7 @@ def train(config: SdLoraConfig, callbacks: list[PipelineCallbacks] | None = None logger.info("Loading models.") tokenizer, noise_scheduler, text_encoder, vae, unet = load_models_sd( + logger=logger, model_name_or_path=config.model, hf_variant=config.hf_variant, base_embeddings=config.base_embeddings, diff --git a/src/invoke_training/pipelines/stable_diffusion/textual_inversion/train.py b/src/invoke_training/pipelines/stable_diffusion/textual_inversion/train.py index 34c7ecbe..7ac43c90 100644 --- a/src/invoke_training/pipelines/stable_diffusion/textual_inversion/train.py +++ b/src/invoke_training/pipelines/stable_diffusion/textual_inversion/train.py @@ -165,7 +165,7 @@ def train(config: SdTextualInversionConfig, callbacks: list[PipelineCallbacks] | logger.info("Loading models.") tokenizer, noise_scheduler, text_encoder, vae, unet = load_models_sd( - model_name_or_path=config.model, hf_variant=config.hf_variant, dtype=weight_dtype + logger=logger, model_name_or_path=config.model, hf_variant=config.hf_variant, dtype=weight_dtype ) placeholder_tokens, placeholder_token_ids = _initialize_placeholder_tokens( diff --git a/src/invoke_training/pipelines/stable_diffusion_xl/lora/train.py b/src/invoke_training/pipelines/stable_diffusion_xl/lora/train.py index b762a0fe..3bec7c08 100644 --- a/src/invoke_training/pipelines/stable_diffusion_xl/lora/train.py +++ b/src/invoke_training/pipelines/stable_diffusion_xl/lora/train.py @@ -360,6 +360,7 @@ def train(config: SdxlLoraConfig, callbacks: list[PipelineCallbacks] | None = No logger.info("Loading models.") tokenizer_1, tokenizer_2, noise_scheduler, text_encoder_1, text_encoder_2, vae, unet = load_models_sdxl( + 
logger=logger, model_name_or_path=config.model, hf_variant=config.hf_variant, vae_model=config.vae_model, diff --git a/src/invoke_training/pipelines/stable_diffusion_xl/lora_and_textual_inversion/train.py b/src/invoke_training/pipelines/stable_diffusion_xl/lora_and_textual_inversion/train.py index 27f3125c..53309bff 100644 --- a/src/invoke_training/pipelines/stable_diffusion_xl/lora_and_textual_inversion/train.py +++ b/src/invoke_training/pipelines/stable_diffusion_xl/lora_and_textual_inversion/train.py @@ -153,6 +153,7 @@ def train(config: SdxlLoraAndTextualInversionConfig, callbacks: list[PipelineCal logger.info("Loading models.") tokenizer_1, tokenizer_2, noise_scheduler, text_encoder_1, text_encoder_2, vae, unet = load_models_sdxl( + logger=logger, model_name_or_path=config.model, hf_variant=config.hf_variant, vae_model=config.vae_model, diff --git a/src/invoke_training/pipelines/stable_diffusion_xl/textual_inversion/train.py b/src/invoke_training/pipelines/stable_diffusion_xl/textual_inversion/train.py index a9ffa507..1bf30237 100644 --- a/src/invoke_training/pipelines/stable_diffusion_xl/textual_inversion/train.py +++ b/src/invoke_training/pipelines/stable_diffusion_xl/textual_inversion/train.py @@ -191,7 +191,11 @@ def train(config: SdxlTextualInversionConfig, callbacks: list[PipelineCallbacks] logger.info("Loading models.") tokenizer_1, tokenizer_2, noise_scheduler, text_encoder_1, text_encoder_2, vae, unet = load_models_sdxl( - model_name_or_path=config.model, hf_variant=config.hf_variant, vae_model=config.vae_model, dtype=weight_dtype + logger=logger, + model_name_or_path=config.model, + hf_variant=config.hf_variant, + vae_model=config.vae_model, + dtype=weight_dtype, ) placeholder_tokens, placeholder_token_ids_1, placeholder_token_ids_2 = _initialize_placeholder_tokens( diff --git a/tests/invoke_training/_shared/stable_diffusion/test_model_loading_utils.py b/tests/invoke_training/_shared/stable_diffusion/test_model_loading_utils.py index 
1e125e6c..4bca504b 100644 --- a/tests/invoke_training/_shared/stable_diffusion/test_model_loading_utils.py +++ b/tests/invoke_training/_shared/stable_diffusion/test_model_loading_utils.py @@ -1,3 +1,4 @@ +import logging from pathlib import Path import pytest @@ -6,7 +7,10 @@ from invoke_training._shared.stable_diffusion.model_loading_utils import load_models_sd, load_models_sdxl -from .ti_embedding_checkpoint_fixture import sdv1_embedding_path, sdxl_embedding_path # noqa: F401 +from .ti_embedding_checkpoint_fixture import ( # noqa: F401 + sdv1_embedding_path, + sdxl_embedding_path, +) @pytest.mark.loads_model @@ -14,6 +18,7 @@ def test_load_models_sd(sdv1_embedding_path): # noqa: F811 model_name = "runwayml/stable-diffusion-v1-5" tokenizer, noise_scheduler, text_encoder, vae, unet = load_models_sd( + logger=logging.getLogger(__name__), model_name_or_path=model_name, hf_variant="fp16", base_embeddings={"special_test_token": str(sdv1_embedding_path)}, @@ -34,6 +39,7 @@ def test_load_models_sdxl(sdxl_embedding_path: Path): # noqa: F811 model_name = "stabilityai/stable-diffusion-xl-base-1.0" tokenizer_1, tokenizer_2, noise_scheduler, text_encoder_1, text_encoder_2, vae, unet = load_models_sdxl( + logger=logging.getLogger(__name__), model_name_or_path=model_name, hf_variant="fp16", base_embeddings={"special_test_token": str(sdxl_embedding_path)}, diff --git a/tests/invoke_training/_shared/stable_diffusion/test_textual_inversion.py b/tests/invoke_training/_shared/stable_diffusion/test_textual_inversion.py index fd6d910f..29117696 100644 --- a/tests/invoke_training/_shared/stable_diffusion/test_textual_inversion.py +++ b/tests/invoke_training/_shared/stable_diffusion/test_textual_inversion.py @@ -31,7 +31,7 @@ def test_expand_placeholder_token_raises_on_invalid_num_vectors(): @pytest.mark.loads_model def test_initialize_placeholder_tokens_from_initializer_token(): tokenizer, noise_scheduler, text_encoder, vae, unet = load_models_sd( - 
model_name_or_path="runwayml/stable-diffusion-v1-5", hf_variant="fp16" + logger=logging.getLogger(__name__), model_name_or_path="runwayml/stable-diffusion-v1-5", hf_variant="fp16" ) initializer_token = "dog" @@ -59,7 +59,7 @@ def test_initialize_placeholder_tokens_from_initializer_token(): @pytest.mark.loads_model def test_initialize_placeholder_tokens_from_initial_phrase(): tokenizer, noise_scheduler, text_encoder, vae, unet = load_models_sd( - model_name_or_path="runwayml/stable-diffusion-v1-5", hf_variant="fp16" + logger=logging.getLogger(__name__), model_name_or_path="runwayml/stable-diffusion-v1-5", hf_variant="fp16" ) initial_phrase = "little brown dog" @@ -86,7 +86,7 @@ def test_initialize_placeholder_tokens_from_initial_phrase(): @pytest.mark.loads_model def test_initialize_placeholder_tokens_from_initial_embedding(sdv1_embedding_path: Path): # noqa: F811 tokenizer, noise_scheduler, text_encoder, vae, unet = load_models_sd( - model_name_or_path="runwayml/stable-diffusion-v1-5", hf_variant="fp16" + logger=logging.getLogger(__name__), model_name_or_path="runwayml/stable-diffusion-v1-5", hf_variant="fp16" ) placeholder_token = "custom_token"