diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py
index fffb09e6549..85de3039c65 100644
--- a/invokeai/app/invocations/compel.py
+++ b/invokeai/app/invocations/compel.py
@@ -26,6 +26,7 @@
     ConditioningFieldData,
     SDXLConditioningInfo,
 )
+from invokeai.backend.stable_diffusion.extensions import LoRAPatcherExt
 from invokeai.backend.util.devices import TorchDevice
 
 # unconditioned: Optional[torch.Tensor]
@@ -82,9 +83,10 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
             # apply all patches while the model is on the target device
             text_encoder_info.model_on_device() as (model_state_dict, text_encoder),
             tokenizer_info as tokenizer,
-            ModelPatcher.apply_lora_text_encoder(
-                text_encoder,
+            LoRAPatcherExt.static_patch_model(
+                model=text_encoder,
                 loras=_lora_loader(),
+                prefix="lora_te_",
                 model_state_dict=model_state_dict,
             ),
             # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
@@ -177,8 +179,8 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
             # apply all patches while the model is on the target device
             text_encoder_info.model_on_device() as (state_dict, text_encoder),
             tokenizer_info as tokenizer,
-            ModelPatcher.apply_lora(
-                text_encoder,
+            LoRAPatcherExt.static_patch_model(
+                model=text_encoder,
                 loras=_lora_loader(),
                 prefix=lora_prefix,
                 model_state_dict=state_dict,
diff --git a/invokeai/app/invocations/denoise_latents.py b/invokeai/app/invocations/denoise_latents.py
index 7ccf9068939..b14250d672d 100644
--- a/invokeai/app/invocations/denoise_latents.py
+++ b/invokeai/app/invocations/denoise_latents.py
@@ -1,23 +1,20 @@
 # Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
 import inspect
 from contextlib import ExitStack
-from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
 import torchvision
 import torchvision.transforms as T
 from diffusers.configuration_utils import ConfigMixin
-from diffusers.models.adapter import T2IAdapter
 from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
 from diffusers.schedulers.scheduling_dpmsolver_sde import DPMSolverSDEScheduler
 from diffusers.schedulers.scheduling_tcd import TCDScheduler
 from diffusers.schedulers.scheduling_utils import SchedulerMixin as Scheduler
 from pydantic import field_validator
 from torchvision.transforms.functional import resize as tv_resize
-from transformers import CLIPVisionModelWithProjection
 
 from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
-from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
 from invokeai.app.invocations.controlnet_image_processors import ControlField
 from invokeai.app.invocations.fields import (
     ConditioningField,
@@ -33,30 +30,32 @@
 from invokeai.app.invocations.primitives import LatentsOutput
 from invokeai.app.invocations.t2i_adapter import T2IAdapterField
 from invokeai.app.services.shared.invocation_context import InvocationContext
-from invokeai.app.util.controlnet_utils import prepare_control_image
-from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
-from invokeai.backend.lora import LoRAModelRaw
-from invokeai.backend.model_manager import BaseModelType
-from invokeai.backend.model_patcher import ModelPatcher
-from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless
-from invokeai.backend.stable_diffusion.diffusers_pipeline import (
-    ControlNetData,
-    StableDiffusionGeneratorPipeline,
-    T2IAdapterData,
-)
+from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
+from invokeai.backend.stable_diffusion.diffusers_pipeline import StableDiffusionBackend
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
     BasicConditioningInfo,
-    IPAdapterConditioningInfo,
-    IPAdapterData,
     Range,
     SDXLConditioningInfo,
     TextConditioningData,
     TextConditioningRegions,
 )
+from invokeai.backend.stable_diffusion.diffusion.custom_atttention import CustomAttnProcessor2_0
+from invokeai.backend.stable_diffusion.extensions import (
+    ControlNetExt,
+    FreeUExt,
+    InpaintExt,
+    IPAdapterExt,
+    LoRAPatcherExt,
+    PipelineIntermediateState,
+    PreviewExt,
+    RescaleCFGExt,
+    SeamlessExt,
+    T2IAdapterExt,
+)
+from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
 from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP
 from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
 from invokeai.backend.util.devices import TorchDevice
-from invokeai.backend.util.hotfixes import ControlNetModel
 from invokeai.backend.util.mask import to_standard_float_mask
 from invokeai.backend.util.silence_warnings import SilenceWarnings
 
@@ -88,10 +87,6 @@ def get_scheduler(
         scheduler_config["noise_sampler_seed"] = seed
 
     scheduler = scheduler_class.from_config(scheduler_config)
-
-    # hack copied over from generate.py
-    if not hasattr(scheduler, "uses_inpainting_model"):
-        scheduler.uses_inpainting_model = lambda: False
     assert isinstance(scheduler, Scheduler)
     return scheduler
 
@@ -314,12 +309,12 @@ def get_conditioning_data(
         context: InvocationContext,
         positive_conditioning_field: Union[ConditioningField, list[ConditioningField]],
         negative_conditioning_field: Union[ConditioningField, list[ConditioningField]],
-        unet: UNet2DConditionModel,
         latent_height: int,
         latent_width: int,
+        device: torch.device,
+        dtype: torch.dtype,
         cfg_scale: float | list[float],
         steps: int,
-        cfg_rescale_multiplier: float,
     ) -> TextConditioningData:
         # Normalize positive_conditioning_field and negative_conditioning_field to lists.
         cond_list = positive_conditioning_field
@@ -330,10 +325,10 @@ get_conditioning_data(
             uncond_list = [uncond_list]
 
         cond_text_embeddings, cond_text_embedding_masks = DenoiseLatentsInvocation._get_text_embeddings_and_masks(
-            cond_list, context, unet.device, unet.dtype
+            cond_list, context, device, dtype
         )
         uncond_text_embeddings, uncond_text_embedding_masks = DenoiseLatentsInvocation._get_text_embeddings_and_masks(
-            uncond_list, context, unet.device, unet.dtype
+            uncond_list, context, device, dtype
         )
 
         cond_text_embedding, cond_regions = DenoiseLatentsInvocation._concat_regional_text_embeddings(
@@ -341,14 +336,14 @@ get_conditioning_data(
             masks=cond_text_embedding_masks,
             latent_height=latent_height,
             latent_width=latent_width,
-            dtype=unet.dtype,
+            dtype=dtype,
         )
         uncond_text_embedding, uncond_regions = DenoiseLatentsInvocation._concat_regional_text_embeddings(
             text_conditionings=uncond_text_embeddings,
             masks=uncond_text_embedding_masks,
             latent_height=latent_height,
             latent_width=latent_width,
-            dtype=unet.dtype,
+            dtype=dtype,
         )
 
         if isinstance(cfg_scale, list):
@@ -360,42 +355,16 @@ get_conditioning_data(
             uncond_regions=uncond_regions,
             cond_regions=cond_regions,
             guidance_scale=cfg_scale,
-            guidance_rescale_multiplier=cfg_rescale_multiplier,
         )
         return conditioning_data
 
     @staticmethod
-    def create_pipeline(
-        unet: UNet2DConditionModel,
-        scheduler: Scheduler,
-    ) -> StableDiffusionGeneratorPipeline:
-        class FakeVae:
-            class FakeVaeConfig:
-                def __init__(self) -> None:
-                    self.block_out_channels = [0]
-
-            def __init__(self) -> None:
-                self.config = FakeVae.FakeVaeConfig()
-
-        return StableDiffusionGeneratorPipeline(
-            vae=FakeVae(),  # TODO: oh...
-            text_encoder=None,
-            tokenizer=None,
-            unet=unet,
-            scheduler=scheduler,
-            safety_checker=None,
-            feature_extractor=None,
-            requires_safety_checker=False,
-        )
-
-    @staticmethod
-    def prep_control_data(
+    def parse_controlnet_field(
+        exit_stack: ExitStack,
         context: InvocationContext,
         control_input: ControlField | list[ControlField] | None,
-        latents_shape: List[int],
-        exit_stack: ExitStack,
-        do_classifier_free_guidance: bool = True,
-    ) -> list[ControlNetData] | None:
+        ext_manager: ExtensionsManager,
+    ) -> None:
         # Normalize control_input to a list.
         control_list: list[ControlField]
         if isinstance(control_input, ControlField):
@@ -407,191 +376,92 @@ def prep_control_data(
         else:
             raise ValueError(f"Unexpected control_input type: {type(control_input)}")
 
-        if len(control_list) == 0:
-            return None
-
-        # Assuming fixed dimensional scaling of LATENT_SCALE_FACTOR.
-        _, _, latent_height, latent_width = latents_shape
-        control_height_resize = latent_height * LATENT_SCALE_FACTOR
-        control_width_resize = latent_width * LATENT_SCALE_FACTOR
-
-        controlnet_data: list[ControlNetData] = []
         for control_info in control_list:
-            control_model = exit_stack.enter_context(context.models.load(control_info.control_model))
-            assert isinstance(control_model, ControlNetModel)
-
-            control_image_field = control_info.image
-            input_image = context.images.get_pil(control_image_field.image_name)
-            # self.image.image_type, self.image.image_name
-            # FIXME: still need to test with different widths, heights, devices, dtypes
-            #        and add in batch_size, num_images_per_prompt?
-            #        and do real check for classifier_free_guidance?
-            # prepare_control_image should return torch.Tensor of shape(batch_size, 3, height, width)
-            control_image = prepare_control_image(
-                image=input_image,
-                do_classifier_free_guidance=do_classifier_free_guidance,
-                width=control_width_resize,
-                height=control_height_resize,
-                # batch_size=batch_size * num_images_per_prompt,
-                # num_images_per_prompt=num_images_per_prompt,
-                device=control_model.device,
-                dtype=control_model.dtype,
-                control_mode=control_info.control_mode,
-                resize_mode=control_info.resize_mode,
-            )
-            control_item = ControlNetData(
-                model=control_model,
-                image_tensor=control_image,
-                weight=control_info.control_weight,
-                begin_step_percent=control_info.begin_step_percent,
-                end_step_percent=control_info.end_step_percent,
-                control_mode=control_info.control_mode,
-                # any resizing needed should currently be happening in prepare_control_image(),
-                # but adding resize_mode to ControlNetData in case needed in the future
-                resize_mode=control_info.resize_mode,
+            model = exit_stack.enter_context(context.models.load(control_info.control_model))
+            ext_manager.add_extension(
+                ControlNetExt(
+                    model=model,
+                    image=context.images.get_pil(control_info.image.image_name),
+                    weight=control_info.control_weight,
+                    begin_step_percent=control_info.begin_step_percent,
+                    end_step_percent=control_info.end_step_percent,
+                    control_mode=control_info.control_mode,
+                    resize_mode=control_info.resize_mode,
+                    priority=100,
+                )
             )
-            controlnet_data.append(control_item)
         # MultiControlNetModel has been refactored out, just need list[ControlNetData]
-        return controlnet_data
-
-    def prep_ip_adapter_image_prompts(
-        self,
+    @staticmethod
+    def parse_ip_adapter_field(
+        exit_stack: ExitStack,
         context: InvocationContext,
         ip_adapters: List[IPAdapterField],
-    ) -> List[Tuple[torch.Tensor, torch.Tensor]]:
-        """Run the IPAdapter CLIPVisionModel, returning image prompt embeddings."""
-        image_prompts = []
-        for single_ip_adapter in ip_adapters:
-            with context.models.load(single_ip_adapter.ip_adapter_model) as ip_adapter_model:
-                assert isinstance(ip_adapter_model, IPAdapter)
-                image_encoder_model_info = context.models.load(single_ip_adapter.image_encoder_model)
-                # `single_ip_adapter.image` could be a list or a single ImageField. Normalize to a list here.
-                single_ipa_image_fields = single_ip_adapter.image
-                if not isinstance(single_ipa_image_fields, list):
-                    single_ipa_image_fields = [single_ipa_image_fields]
-
-                single_ipa_images = [context.images.get_pil(image.image_name) for image in single_ipa_image_fields]
-                with image_encoder_model_info as image_encoder_model:
-                    assert isinstance(image_encoder_model, CLIPVisionModelWithProjection)
-                    # Get image embeddings from CLIP and ImageProjModel.
-                    image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter_model.get_image_embeds(
-                        single_ipa_images, image_encoder_model
-                    )
-                    image_prompts.append((image_prompt_embeds, uncond_image_prompt_embeds))
+        ext_manager: ExtensionsManager,
+    ) -> None:
+        if ip_adapters is None:
+            return
-        return image_prompts
+        if not isinstance(ip_adapters, list):
+            ip_adapters = [ip_adapters]
-    def prep_ip_adapter_data(
-        self,
-        context: InvocationContext,
-        ip_adapters: List[IPAdapterField],
-        image_prompts: List[Tuple[torch.Tensor, torch.Tensor]],
-        exit_stack: ExitStack,
-        latent_height: int,
-        latent_width: int,
-        dtype: torch.dtype,
-    ) -> Optional[List[IPAdapterData]]:
-        """If IP-Adapter is enabled, then this function loads the requisite models and adds the image prompt conditioning data."""
-        ip_adapter_data_list = []
-        for single_ip_adapter, (image_prompt_embeds, uncond_image_prompt_embeds) in zip(
-            ip_adapters, image_prompts, strict=True
-        ):
-            ip_adapter_model = exit_stack.enter_context(context.models.load(single_ip_adapter.ip_adapter_model))
+        for single_ip_adapter in ip_adapters:
+            # `single_ip_adapter.image` could be a list or a single ImageField. Normalize to a list here.
+            single_ipa_image_fields = single_ip_adapter.image
+            if not isinstance(single_ipa_image_fields, list):
+                single_ipa_image_fields = [single_ipa_image_fields]
+
+            single_ipa_images = [context.images.get_pil(image.image_name) for image in single_ipa_image_fields]
             mask_field = single_ip_adapter.mask
             mask = context.tensors.load(mask_field.tensor_name) if mask_field is not None else None
-            mask = self._preprocess_regional_prompt_mask(mask, latent_height, latent_width, dtype=dtype)
-            ip_adapter_data_list.append(
-                IPAdapterData(
-                    ip_adapter_model=ip_adapter_model,
+            ext_manager.add_extension(
+                IPAdapterExt(
+                    node_context=context,
+                    exit_stack=exit_stack,
+                    model_id=single_ip_adapter.ip_adapter_model,
+                    image_encoder_model_id=single_ip_adapter.image_encoder_model,
+                    images=single_ipa_images,
                     weight=single_ip_adapter.weight,
                     target_blocks=single_ip_adapter.target_blocks,
                     begin_step_percent=single_ip_adapter.begin_step_percent,
                     end_step_percent=single_ip_adapter.end_step_percent,
-                    ip_adapter_conditioning=IPAdapterConditioningInfo(image_prompt_embeds, uncond_image_prompt_embeds),
                     mask=mask,
+                    priority=100,
                 )
             )
-        return ip_adapter_data_list if len(ip_adapter_data_list) > 0 else None
-
-    def run_t2i_adapters(
-        self,
+    @staticmethod
+    def parse_t2i_field(
+        exit_stack: ExitStack,
         context: InvocationContext,
-        t2i_adapter: Optional[Union[T2IAdapterField, list[T2IAdapterField]]],
-        latents_shape: list[int],
-        do_classifier_free_guidance: bool,
-    ) -> Optional[list[T2IAdapterData]]:
-        if t2i_adapter is None:
-            return None
-
-        # Handle the possibility that t2i_adapter could be a list or a single T2IAdapterField.
-        if isinstance(t2i_adapter, T2IAdapterField):
-            t2i_adapter = [t2i_adapter]
-
-        if len(t2i_adapter) == 0:
-            return None
-
-        t2i_adapter_data = []
-        for t2i_adapter_field in t2i_adapter:
-            t2i_adapter_model_config = context.models.get_config(t2i_adapter_field.t2i_adapter_model.key)
-            t2i_adapter_loaded_model = context.models.load(t2i_adapter_field.t2i_adapter_model)
-            image = context.images.get_pil(t2i_adapter_field.image.image_name)
-
-            # The max_unet_downscale is the maximum amount that the UNet model downscales the latent image internally.
-            if t2i_adapter_model_config.base == BaseModelType.StableDiffusion1:
-                max_unet_downscale = 8
-            elif t2i_adapter_model_config.base == BaseModelType.StableDiffusionXL:
-                max_unet_downscale = 4
-            else:
-                raise ValueError(f"Unexpected T2I-Adapter base model type: '{t2i_adapter_model_config.base}'.")
-
-            t2i_adapter_model: T2IAdapter
-            with t2i_adapter_loaded_model as t2i_adapter_model:
-                total_downscale_factor = t2i_adapter_model.total_downscale_factor
-
-                # Resize the T2I-Adapter input image.
-                # We select the resize dimensions so that after the T2I-Adapter's total_downscale_factor is applied, the
-                # result will match the latent image's dimensions after max_unet_downscale is applied.
-                t2i_input_height = latents_shape[2] // max_unet_downscale * total_downscale_factor
-                t2i_input_width = latents_shape[3] // max_unet_downscale * total_downscale_factor
-
-                # Note: We have hard-coded `do_classifier_free_guidance=False`. This is because we only want to prepare
-                # a single image. If CFG is enabled, we will duplicate the resultant tensor after applying the
-                # T2I-Adapter model.
-                #
-                # Note: We re-use the `prepare_control_image(...)` from ControlNet for T2I-Adapter, because it has many
-                # of the same requirements (e.g. preserving binary masks during resize).
-                t2i_image = prepare_control_image(
-                    image=image,
-                    do_classifier_free_guidance=False,
-                    width=t2i_input_width,
-                    height=t2i_input_height,
-                    num_channels=t2i_adapter_model.config["in_channels"],  # mypy treats this as a FrozenDict
-                    device=t2i_adapter_model.device,
-                    dtype=t2i_adapter_model.dtype,
-                    resize_mode=t2i_adapter_field.resize_mode,
-                )
-
-                adapter_state = t2i_adapter_model(t2i_image)
-
-            if do_classifier_free_guidance:
-                for idx, value in enumerate(adapter_state):
-                    adapter_state[idx] = torch.cat([value] * 2, dim=0)
+        t2i_adapters: Optional[Union[T2IAdapterField, list[T2IAdapterField]]],
+        ext_manager: ExtensionsManager,
+    ) -> None:
+        if t2i_adapters is None:
+            return
-            t2i_adapter_data.append(
-                T2IAdapterData(
-                    adapter_state=adapter_state,
+        # Handle the possibility that t2i_adapters could be a list or a single T2IAdapterField.
+        if isinstance(t2i_adapters, T2IAdapterField):
+            t2i_adapters = [t2i_adapters]
+
+        for t2i_adapter_field in t2i_adapters:
+            ext_manager.add_extension(
+                T2IAdapterExt(
+                    node_context=context,
+                    exit_stack=exit_stack,
+                    model_id=t2i_adapter_field.t2i_adapter_model,
+                    image=context.images.get_pil(t2i_adapter_field.image.image_name),
+                    adapter_state=None,
                     weight=t2i_adapter_field.weight,
                     begin_step_percent=t2i_adapter_field.begin_step_percent,
                     end_step_percent=t2i_adapter_field.end_step_percent,
+                    resize_mode=t2i_adapter_field.resize_mode,
+                    priority=100,
                 )
             )
-        return t2i_adapter_data
-
     # original idea by https://github.com/AmericanPresidentJimmyCarter
     # TODO: research more for second order schedulers timesteps
     @staticmethod
@@ -710,139 +580,134 @@ def prepare_noise_and_latents(
     @torch.no_grad()
     @SilenceWarnings()  # This quenches the NSFW nag from diffusers.
     def invoke(self, context: InvocationContext) -> LatentsOutput:
-        seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents)
+        with ExitStack() as exit_stack:
+            ext_manager = ExtensionsManager()
-        mask, masked_latents, gradient_mask = self.prep_inpaint_mask(context, latents)
+            device = TorchDevice.choose_torch_device()
+            dtype = TorchDevice.choose_torch_dtype()
-        # TODO(ryand): I have hard-coded `do_classifier_free_guidance=True` to mirror the behaviour of ControlNets,
-        # below. Investigate whether this is appropriate.
-        t2i_adapter_data = self.run_t2i_adapters(
-            context,
-            self.t2i_adapter,
-            latents.shape,
-            do_classifier_free_guidance=True,
-        )
-
-        ip_adapters: List[IPAdapterField] = []
-        if self.ip_adapter is not None:
-            # ip_adapter could be a list or a single IPAdapterField. Normalize to a list here.
-            if isinstance(self.ip_adapter, list):
-                ip_adapters = self.ip_adapter
-            else:
-                ip_adapters = [self.ip_adapter]
-
-        # If there are IP adapters, the following line runs the adapters' CLIPVision image encoders to return
-        # a series of image conditioning embeddings. This is being done here rather than in the
-        # big model context below in order to use less VRAM on low-VRAM systems.
-        # The image prompts are then passed to prep_ip_adapter_data().
-        image_prompts = self.prep_ip_adapter_image_prompts(context=context, ip_adapters=ip_adapters)
-
-        # get the unet's config so that we can pass the base to sd_step_callback()
-        unet_config = context.models.get_config(self.unet.unet.key)
-
-        def step_callback(state: PipelineIntermediateState) -> None:
-            context.util.sd_step_callback(state, unet_config.base)
-
-        def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
-            for lora in self.unet.loras:
-                lora_info = context.models.load(lora.lora)
-                assert isinstance(lora_info.model, LoRAModelRaw)
-                yield (lora_info.model, lora.weight)
-                del lora_info
-            return
-
-        unet_info = context.models.load(self.unet.unet)
-        assert isinstance(unet_info.model, UNet2DConditionModel)
-        with (
-            ExitStack() as exit_stack,
-            unet_info.model_on_device() as (model_state_dict, unet),
-            ModelPatcher.apply_freeu(unet, self.unet.freeu_config),
-            set_seamless(unet, self.unet.seamless_axes),  # FIXME
-            # Apply the LoRA after unet has been moved to its target device for faster patching.
-            ModelPatcher.apply_lora_unet(
-                unet,
-                loras=_lora_loader(),
-                model_state_dict=model_state_dict,
-            ),
-        ):
-            assert isinstance(unet, UNet2DConditionModel)
-            latents = latents.to(device=unet.device, dtype=unet.dtype)
+            seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents)
+            latents = latents.to(device=device, dtype=dtype)
             if noise is not None:
-                noise = noise.to(device=unet.device, dtype=unet.dtype)
-            if mask is not None:
-                mask = mask.to(device=unet.device, dtype=unet.dtype)
-            if masked_latents is not None:
-                masked_latents = masked_latents.to(device=unet.device, dtype=unet.dtype)
-
-            scheduler = get_scheduler(
-                context=context,
-                scheduler_info=self.unet.scheduler,
-                scheduler_name=self.scheduler,
-                seed=seed,
-            )
-
-            pipeline = self.create_pipeline(unet, scheduler)
+                noise = noise.to(device=device, dtype=dtype)
             _, _, latent_height, latent_width = latents.shape
+
             conditioning_data = self.get_conditioning_data(
                 context=context,
                 positive_conditioning_field=self.positive_conditioning,
                 negative_conditioning_field=self.negative_conditioning,
-                unet=unet,
-                latent_height=latent_height,
-                latent_width=latent_width,
                 cfg_scale=self.cfg_scale,
                 steps=self.steps,
-                cfg_rescale_multiplier=self.cfg_rescale_multiplier,
-            )
-
-            controlnet_data = self.prep_control_data(
-                context=context,
-                control_input=self.control,
-                latents_shape=latents.shape,
-                # do_classifier_free_guidance=(self.cfg_scale >= 1.0))
-                do_classifier_free_guidance=True,
-                exit_stack=exit_stack,
+                latent_height=latent_height,
+                latent_width=latent_width,
+                device=device,
+                dtype=dtype,
             )
-            ip_adapter_data = self.prep_ip_adapter_data(
+            scheduler = get_scheduler(
                 context=context,
-                ip_adapters=ip_adapters,
-                image_prompts=image_prompts,
-                exit_stack=exit_stack,
-                latent_height=latent_height,
-                latent_width=latent_width,
-                dtype=unet.dtype,
+                scheduler_info=self.unet.scheduler,
+                scheduler_name=self.scheduler,
+                seed=seed,
             )
             timesteps, init_timestep, scheduler_step_kwargs = self.init_scheduler(
                 scheduler,
-                device=unet.device,
+                seed=seed,
+                device=device,
                 steps=self.steps,
                 denoising_start=self.denoising_start,
                 denoising_end=self.denoising_end,
-                seed=seed,
             )
-            result_latents = pipeline.latents_from_embeddings(
+            denoise_ctx = DenoiseContext(
                 latents=latents,
                 timesteps=timesteps,
                 init_timestep=init_timestep,
                 noise=noise,
                 seed=seed,
-                mask=mask,
-                masked_latents=masked_latents,
-                is_gradient_mask=gradient_mask,
                 scheduler_step_kwargs=scheduler_step_kwargs,
                 conditioning_data=conditioning_data,
-                control_data=controlnet_data,
-                ip_adapter_data=ip_adapter_data,
-                t2i_adapter_data=t2i_adapter_data,
-                callback=step_callback,
+                unet=None,
+                scheduler=scheduler,
             )
+            # get the unet's config so that we can pass the base to sd_step_callback()
+            unet_config = context.models.get_config(self.unet.unet.key)
+
+            ### inpaint
+            mask, masked_latents, is_gradient_mask = self.prep_inpaint_mask(context, latents)
+            if (
+                mask is not None or unet_config.variant == "inpaint"  # ModelVariantType.Inpaint
+            ):
+                ext_manager.add_extension(InpaintExt(mask, masked_latents, is_gradient_mask, priority=200))
+
+            ### preview
+            def step_callback(state: PipelineIntermediateState) -> None:
+                context.util.sd_step_callback(state, unet_config.base)
+
+            ext_manager.add_extension(PreviewExt(step_callback, priority=99999))
+
+            ### cfg rescale
+            if self.cfg_rescale_multiplier > 0:
+                ext_manager.add_extension(RescaleCFGExt(self.cfg_rescale_multiplier, priority=100))
+
+            ### seamless
+            if self.unet.seamless_axes:
+                ext_manager.add_extension(SeamlessExt(self.unet.seamless_axes, priority=100))
+
+            ### freeu
+            if self.unet.freeu_config:
+                ext_manager.add_extension(FreeUExt(self.unet.freeu_config, priority=100))
+
+            ### lora
+            if self.unet.loras:
+                ext_manager.add_extension(
+                    LoRAPatcherExt(
+                        node_context=context,
+                        loras=self.unet.loras,
+                        prefix="lora_unet_",
+                        priority=100,
+                    )
+                )
+
+            ### tiled denoise
+            # ext_manager.add_extension(
+            #     TiledDenoiseExt(
+            #         tile_width=1024,
+            #         tile_height=1024,
+            #         tile_overlap=32,
+            #         priority=100,
+            #     )
+            # )
+
+            # later will be like:
+            # for extension_field in self.extensions:
+            #     ext = extension_field.to_extension(exit_stack, context)
+            #     ext_manager.add_extension(ext)
+            self.parse_t2i_field(exit_stack, context, self.t2i_adapter, ext_manager)
+            self.parse_controlnet_field(exit_stack, context, self.control, ext_manager)
+            self.parse_ip_adapter_field(exit_stack, context, self.ip_adapter, ext_manager)
+
+            # ext: t2i/ip adapter
+            ext_manager.modifiers.pre_unet_load(denoise_ctx, ext_manager)
+
+            unet_info = context.models.load(self.unet.unet)
+            assert isinstance(unet_info.model, UNet2DConditionModel)
+            with (
+                unet_info.model_on_device() as (model_state_dict, unet),
+                # ext: controlnet
+                ext_manager.patch_attention_processor(unet, CustomAttnProcessor2_0),
+                # ext: freeu, seamless, ip adapter, lora
+                ext_manager.patch_unet(model_state_dict, unet),
+            ):
+                sd_backend = StableDiffusionBackend(unet, scheduler)
+                denoise_ctx.unet = unet
+                result_latents = sd_backend.latents_from_embeddings(denoise_ctx, ext_manager)
+
         # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
-        result_latents = result_latents.to("cpu")
+            result_latents = result_latents.to("cpu")  # TODO: detach?
         TorchDevice.empty_cache()
 
         name = context.tensors.save(tensor=result_latents)
diff --git a/invokeai/app/invocations/latents_to_image.py b/invokeai/app/invocations/latents_to_image.py
index cc8a9c44a3f..38d0cf017ab 100644
--- a/invokeai/app/invocations/latents_to_image.py
+++ b/invokeai/app/invocations/latents_to_image.py
@@ -24,7 +24,7 @@
 from invokeai.app.invocations.model import VAEField
 from invokeai.app.invocations.primitives import ImageOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
-from invokeai.backend.stable_diffusion import set_seamless
+from invokeai.backend.stable_diffusion.extensions import SeamlessExt
 from invokeai.backend.stable_diffusion.vae_tiling import patch_vae_tiling_params
 from invokeai.backend.util.devices import TorchDevice
 
@@ -59,7 +59,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput:
         vae_info = context.models.load(self.vae.vae)
         assert isinstance(vae_info.model, (AutoencoderKL, AutoencoderTiny))
-        with set_seamless(vae_info.model, self.vae.seamless_axes), vae_info as vae:
+        with SeamlessExt.static_patch_model(vae_info.model, self.vae.seamless_axes), vae_info as vae:
             assert isinstance(vae, (AutoencoderKL, AutoencoderTiny))
             latents = latents.to(vae.device)
             if self.fp32:
diff --git a/invokeai/app/invocations/tiled_multi_diffusion_denoise_latents.py b/invokeai/app/invocations/tiled_multi_diffusion_denoise_latents.py
index 5d408a4df7c..868dc8d88bc 100644
--- a/invokeai/app/invocations/tiled_multi_diffusion_denoise_latents.py
+++ b/invokeai/app/invocations/tiled_multi_diffusion_denoise_latents.py
@@ -1,10 +1,8 @@
-import copy
 from contextlib import ExitStack
-from typing import Iterator, Tuple
+from typing import Optional, Union
 
 import torch
 from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
-from diffusers.schedulers.scheduling_utils import SchedulerMixin
 from pydantic import field_validator
 
 from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
@@ -19,38 +17,28 @@
     LatentsField,
     UIType,
 )
+from invokeai.app.invocations.ip_adapter import IPAdapterField
 from invokeai.app.invocations.model import UNetField
 from invokeai.app.invocations.primitives import LatentsOutput
+from invokeai.app.invocations.t2i_adapter import T2IAdapterField
 from invokeai.app.services.shared.invocation_context import InvocationContext
-from invokeai.backend.lora import LoRAModelRaw
-from invokeai.backend.model_patcher import ModelPatcher
-from invokeai.backend.stable_diffusion.diffusers_pipeline import ControlNetData, PipelineIntermediateState
-from invokeai.backend.stable_diffusion.multi_diffusion_pipeline import (
-    MultiDiffusionPipeline,
-    MultiDiffusionRegionConditioning,
+from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
+from invokeai.backend.stable_diffusion.diffusers_pipeline import StableDiffusionBackend
+from invokeai.backend.stable_diffusion.diffusion.custom_atttention import CustomAttnProcessor2_0
+from invokeai.backend.stable_diffusion.extensions import (
+    FreeUExt,
+    LoRAPatcherExt,
+    PipelineIntermediateState,
+    PreviewExt,
+    RescaleCFGExt,
+    SeamlessExt,
+    TiledDenoiseExt,
 )
+from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
 from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
-from invokeai.backend.tiles.tiles import (
-    calc_tiles_min_overlap,
-)
-from invokeai.backend.tiles.utils import TBLR
 from invokeai.backend.util.devices import TorchDevice
 
 
-def crop_controlnet_data(control_data: ControlNetData, latent_region: TBLR) -> ControlNetData:
-    """Crop a ControlNetData object to a region."""
-    # Create a shallow copy of the control_data object.
-    control_data_copy = copy.copy(control_data)
-    # The ControlNet reference image is the only attribute that needs to be cropped.
-    control_data_copy.image_tensor = control_data.image_tensor[
-        :,
-        :,
-        latent_region.top * LATENT_SCALE_FACTOR : latent_region.bottom * LATENT_SCALE_FACTOR,
-        latent_region.left * LATENT_SCALE_FACTOR : latent_region.right * LATENT_SCALE_FACTOR,
-    ]
-    return control_data_copy
-
-
 @invocation(
     "tiled_multi_diffusion_denoise_latents",
     title="Tiled Multi-Diffusion Denoise Latents",
@@ -126,6 +114,18 @@ class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
         default=None,
         input=Input.Connection,
     )
+    t2i_adapter: Optional[Union[T2IAdapterField, list[T2IAdapterField]]] = InputField(
+        description=FieldDescriptions.t2i_adapter,
+        title="T2I-Adapter",
+        default=None,
+        input=Input.Connection,
+    )
+    ip_adapter: Optional[Union[IPAdapterField, list[IPAdapterField]]] = InputField(
+        description=FieldDescriptions.ip_adapter,
+        title="IP-Adapter",
+        default=None,
+        input=Input.Connection,
+    )
 
     @field_validator("cfg_scale")
     def ge_one(cls, v: list[float] | float) -> list[float] | float:
@@ -139,141 +139,133 @@ def ge_one(cls, v: list[float] | float) -> list[float] | float:
                 raise ValueError("cfg_scale must be greater than 1")
         return v
 
-    @staticmethod
-    def create_pipeline(
-        unet: UNet2DConditionModel,
-        scheduler: SchedulerMixin,
-    ) -> MultiDiffusionPipeline:
-        # TODO(ryand): Get rid of this FakeVae hack.
-        class FakeVae:
-            class FakeVaeConfig:
-                def __init__(self) -> None:
-                    self.block_out_channels = [0]
-
-            def __init__(self) -> None:
-                self.config = FakeVae.FakeVaeConfig()
-
-        return MultiDiffusionPipeline(
-            vae=FakeVae(),
-            text_encoder=None,
-            tokenizer=None,
-            unet=unet,
-            scheduler=scheduler,
-            safety_checker=None,
-            feature_extractor=None,
-            requires_safety_checker=False,
-        )
-
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> LatentsOutput:
-        # Convert tile image-space dimensions to latent-space dimensions.
-        latent_tile_height = self.tile_height // LATENT_SCALE_FACTOR
-        latent_tile_width = self.tile_width // LATENT_SCALE_FACTOR
-        latent_tile_overlap = self.tile_overlap // LATENT_SCALE_FACTOR
-
-        seed, noise, latents = DenoiseLatentsInvocation.prepare_noise_and_latents(context, self.noise, self.latents)
-        _, _, latent_height, latent_width = latents.shape
-
-        # Calculate the tile locations to cover the latent-space image.
-        tiles = calc_tiles_min_overlap(
-            image_height=latent_height,
-            image_width=latent_width,
-            tile_height=latent_tile_height,
-            tile_width=latent_tile_width,
-            min_overlap=latent_tile_overlap,
-        )
+        with ExitStack() as exit_stack:
+            ext_manager = ExtensionsManager()
-        # Get the unet's config so that we can pass the base to sd_step_callback().
-        unet_config = context.models.get_config(self.unet.unet.key)
+            device = TorchDevice.choose_torch_device()
+            dtype = TorchDevice.choose_torch_dtype()
-        def step_callback(state: PipelineIntermediateState) -> None:
-            context.util.sd_step_callback(state, unet_config.base)
-
-        # Prepare an iterator that yields the UNet's LoRA models and their weights.
-        def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
-            for lora in self.unet.loras:
-                lora_info = context.models.load(lora.lora)
-                assert isinstance(lora_info.model, LoRAModelRaw)
-                yield (lora_info.model, lora.weight)
-                del lora_info
-
-        # Load the UNet model.
-        unet_info = context.models.load(self.unet.unet)
-
-        with ExitStack() as exit_stack, unet_info as unet, ModelPatcher.apply_lora_unet(unet, _lora_loader()):
-            assert isinstance(unet, UNet2DConditionModel)
-            latents = latents.to(device=unet.device, dtype=unet.dtype)
+            seed, noise, latents = DenoiseLatentsInvocation.prepare_noise_and_latents(context, self.noise, self.latents)
+            latents = latents.to(device=device, dtype=dtype)
             if noise is not None:
-                noise = noise.to(device=unet.device, dtype=unet.dtype)
-            scheduler = get_scheduler(
-                context=context,
-                scheduler_info=self.unet.scheduler,
-                scheduler_name=self.scheduler,
-                seed=seed,
-            )
-            pipeline = self.create_pipeline(unet=unet, scheduler=scheduler)
+                noise = noise.to(device=device, dtype=dtype)
+
+            _, _, latent_height, latent_width = latents.shape
-            # Prepare the prompt conditioning data. The same prompt conditioning is applied to all tiles.
             conditioning_data = DenoiseLatentsInvocation.get_conditioning_data(
                 context=context,
                 positive_conditioning_field=self.positive_conditioning,
                 negative_conditioning_field=self.negative_conditioning,
-                unet=unet,
-                latent_height=latent_tile_height,
-                latent_width=latent_tile_width,
                 cfg_scale=self.cfg_scale,
                 steps=self.steps,
-                cfg_rescale_multiplier=self.cfg_rescale_multiplier,
+                latent_height=latent_height,
+                latent_width=latent_width,
+                device=device,
+                dtype=dtype,
             )
-            controlnet_data = DenoiseLatentsInvocation.prep_control_data(
+            scheduler = get_scheduler(
                 context=context,
-                control_input=self.control,
-                latents_shape=list(latents.shape),
-                # do_classifier_free_guidance=(self.cfg_scale >= 1.0))
-                do_classifier_free_guidance=True,
-                exit_stack=exit_stack,
+                scheduler_info=self.unet.scheduler,
+                scheduler_name=self.scheduler,
+                seed=seed,
             )
-            # Split the controlnet_data into tiles.
-            # controlnet_data_tiles[t][c] is the c'th control data for the t'th tile.
-            controlnet_data_tiles: list[list[ControlNetData]] = []
-            for tile in tiles:
-                tile_controlnet_data = [crop_controlnet_data(cn, tile.coords) for cn in controlnet_data or []]
-                controlnet_data_tiles.append(tile_controlnet_data)
-
-            # Prepare the MultiDiffusionRegionConditioning list.
-            multi_diffusion_conditioning: list[MultiDiffusionRegionConditioning] = []
-            for tile, tile_controlnet_data in zip(tiles, controlnet_data_tiles, strict=True):
-                multi_diffusion_conditioning.append(
-                    MultiDiffusionRegionConditioning(
-                        region=tile,
-                        text_conditioning_data=conditioning_data,
-                        control_data=tile_controlnet_data,
-                    )
-                )
-
             timesteps, init_timestep, scheduler_step_kwargs = DenoiseLatentsInvocation.init_scheduler(
                 scheduler,
-                device=unet.device,
+                seed=seed,
+                device=device,
                 steps=self.steps,
                 denoising_start=self.denoising_start,
                 denoising_end=self.denoising_end,
-                seed=seed,
             )
-            # Run Multi-Diffusion denoising.
-            result_latents = pipeline.multi_diffusion_denoise(
-                multi_diffusion_conditioning=multi_diffusion_conditioning,
-                target_overlap=latent_tile_overlap,
+            denoise_ctx = DenoiseContext(
                 latents=latents,
-                scheduler_step_kwargs=scheduler_step_kwargs,
-                noise=noise,
                 timesteps=timesteps,
                 init_timestep=init_timestep,
-                callback=step_callback,
+                noise=noise,
+                seed=seed,
+                scheduler_step_kwargs=scheduler_step_kwargs,
+                conditioning_data=conditioning_data,
+                unet=None,
+                scheduler=scheduler,
             )
+            # get the unet's config so that we can pass the base to sd_step_callback()
+            unet_config = context.models.get_config(self.unet.unet.key)
+
+            ### inpaint
+            # mask, masked_latents, is_gradient_mask = self.prep_inpaint_mask(context, latents)
+            # if mask is not None or unet_config.variant == "inpaint":  # ModelVariantType.Inpaint:  # is_inpainting_model(unet):
+            #     ext_manager.add_extension(InpaintExt(mask, masked_latents, is_gradient_mask, priority=200))
+
+            ### preview
+            def step_callback(state: PipelineIntermediateState) -> None:
+                context.util.sd_step_callback(state, unet_config.base)
+
+            ext_manager.add_extension(PreviewExt(step_callback, priority=99999))
+
+            ### cfg rescale
+            if self.cfg_rescale_multiplier > 0:
+                ext_manager.add_extension(RescaleCFGExt(self.cfg_rescale_multiplier, priority=100))
+
+            ### seamless
+            if self.unet.seamless_axes:
+                ext_manager.add_extension(SeamlessExt(self.unet.seamless_axes, priority=100))
+
+            ### freeu
+            if self.unet.freeu_config:
+                ext_manager.add_extension(FreeUExt(self.unet.freeu_config, priority=100))
+
+            ### lora
+            if self.unet.loras:
+                ext_manager.add_extension(
+                    LoRAPatcherExt(
+                        node_context=context,
+                        loras=self.unet.loras,
+                        prefix="lora_unet_",
+                        priority=100,
+                    )
+                )
+
+            ### tiled denoise
+            ext_manager.add_extension(
+                TiledDenoiseExt(
+                    tile_width=self.tile_width,
+                    tile_height=self.tile_height,
+                    tile_overlap=self.tile_overlap,
+                    priority=100,
+                )
+            )
+
+            # later will be like:
+            # for extension_field in self.extensions:
+            #     ext = extension_field.to_extension(exit_stack, context)
+            #     ext_manager.add_extension(ext)
+            DenoiseLatentsInvocation.parse_t2i_field(exit_stack, context, self.t2i_adapter, ext_manager)
+            DenoiseLatentsInvocation.parse_controlnet_field(exit_stack, context, self.control, ext_manager)
+            # TODO: works fine with tiled too?
+            DenoiseLatentsInvocation.parse_ip_adapter_field(exit_stack, context, self.ip_adapter, ext_manager)
+
+            # ext: t2i/ip adapter
+            ext_manager.modifiers.pre_unet_load(denoise_ctx, ext_manager)
+
+            unet_info = context.models.load(self.unet.unet)
+            assert isinstance(unet_info.model, UNet2DConditionModel)
+            with (
+                unet_info.model_on_device() as (model_state_dict, unet),
+                # ext: controlnet
+                ext_manager.patch_attention_processor(unet, CustomAttnProcessor2_0),
+                # ext: freeu, seamless, ip adapter, lora
+                ext_manager.patch_unet(model_state_dict, unet),
+            ):
+                sd_backend = StableDiffusionBackend(unet, scheduler)
+                denoise_ctx.unet = unet
+                result_latents = sd_backend.latents_from_embeddings(denoise_ctx, ext_manager)
+
+            result_latents = result_latents.to("cpu")
         # TODO(ryand): I copied this from DenoiseLatentsInvocation. I'm not sure if it's actually important.
         TorchDevice.empty_cache()
diff --git a/invokeai/backend/model_patcher.py b/invokeai/backend/model_patcher.py
index 8c7a62c3719..a02f7d924f6 100644
--- a/invokeai/backend/model_patcher.py
+++ b/invokeai/backend/model_patcher.py
@@ -5,20 +5,17 @@
 import pickle
 from contextlib import contextmanager
-from typing import Any, Dict, Generator, Iterator, List, Optional, Tuple, Union
+from typing import Any, Dict, Iterator, List, Tuple, Union
 
 import numpy as np
 import torch
-from diffusers import OnnxRuntimeModel, UNet2DConditionModel
+from diffusers import OnnxRuntimeModel
 from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
 
-from invokeai.app.shared.models import FreeUConfig
 from invokeai.backend.lora import LoRAModelRaw
-from invokeai.backend.model_manager import AnyModel
 from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
 from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
 from invokeai.backend.textual_inversion import TextualInversionManager, TextualInversionModelRaw
-from invokeai.backend.util.devices import TorchDevice
 
 """
 loras = [
@@ -34,140 +31,6 @@
 # TODO: rename smth like ModelPatcher and add TI method?
 class ModelPatcher:
-    @staticmethod
-    def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]:
-        assert "." not in lora_key
-
-        if not lora_key.startswith(prefix):
-            raise Exception(f"lora_key with invalid prefix: {lora_key}, {prefix}")
-
-        module = model
-        module_key = ""
-        key_parts = lora_key[len(prefix) :].split("_")
-
-        submodule_name = key_parts.pop(0)
-
-        while len(key_parts) > 0:
-            try:
-                module = module.get_submodule(submodule_name)
-                module_key += "." + submodule_name
-                submodule_name = key_parts.pop(0)
-            except Exception:
-                submodule_name += "_" + key_parts.pop(0)
-
-        module = module.get_submodule(submodule_name)
-        module_key = (module_key + "." + submodule_name).lstrip(".")
-
-        return (module_key, module)
-
-    @classmethod
-    @contextmanager
-    def apply_lora_unet(
-        cls,
-        unet: UNet2DConditionModel,
-        loras: Iterator[Tuple[LoRAModelRaw, float]],
-        model_state_dict: Optional[Dict[str, torch.Tensor]] = None,
-    ) -> Generator[None, None, None]:
-        with cls.apply_lora(
-            unet,
-            loras=loras,
-            prefix="lora_unet_",
-            model_state_dict=model_state_dict,
-        ):
-            yield
-
-    @classmethod
-    @contextmanager
-    def apply_lora_text_encoder(
-        cls,
-        text_encoder: CLIPTextModel,
-        loras: Iterator[Tuple[LoRAModelRaw, float]],
-        model_state_dict: Optional[Dict[str, torch.Tensor]] = None,
-    ) -> Generator[None, None, None]:
-        with cls.apply_lora(text_encoder, loras=loras, prefix="lora_te_", model_state_dict=model_state_dict):
-            yield
-
-    @classmethod
-    @contextmanager
-    def apply_lora(
-        cls,
-        model: AnyModel,
-        loras: Iterator[Tuple[LoRAModelRaw, float]],
-        prefix: str,
-        model_state_dict: Optional[Dict[str, torch.Tensor]] = None,
-    ) -> Generator[None, None, None]:
-        """
-        Apply one or more LoRAs to a model.
-
-        :param model: The model to patch.
-        :param loras: An iterator that returns the LoRA to patch in and its patch weight.
-        :param prefix: A string prefix that precedes keys used in the LoRAs weight layers.
-        :model_state_dict: Read-only copy of the model's state dict in CPU, for unpatching purposes.
-        """
-        original_weights = {}
-        try:
-            with torch.no_grad():
-                for lora, lora_weight in loras:
-                    # assert lora.device.type == "cpu"
-                    for layer_key, layer in lora.layers.items():
-                        if not layer_key.startswith(prefix):
-                            continue
-
-                        # TODO(ryand): A non-negligible amount of time is currently spent resolving LoRA keys. This
-                        # should be improved in the following ways:
-                        # 1. The key mapping could be more-efficiently pre-computed. This would save time every time a
-                        #    LoRA model is applied.
-                        # 2. From an API perspective, there's no reason that the `ModelPatcher` should be aware of the
-                        #    intricacies of Stable Diffusion key resolution. It should just expect the input LoRA
-                        #    weights to have valid keys.
-                        assert isinstance(model, torch.nn.Module)
-                        module_key, module = cls._resolve_lora_key(model, layer_key, prefix)
-
-                        # All of the LoRA weight calculations will be done on the same device as the module weight.
-                        # (Performance will be best if this is a CUDA device.)
-                        device = module.weight.device
-                        dtype = module.weight.dtype
-
-                        if module_key not in original_weights:
-                            if model_state_dict is not None:  # we were provided with the CPU copy of the state dict
-                                original_weights[module_key] = model_state_dict[module_key + ".weight"]
-                            else:
-                                original_weights[module_key] = module.weight.detach().to(device="cpu", copy=True)
-
-                        layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
-
-                        # We intentionally move to the target device first, then cast. Experimentally, this was found to
-                        # be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
-                        # same thing in a single call to '.to(...)'.
-                        layer.to(device=device, non_blocking=TorchDevice.get_non_blocking(device))
-                        layer.to(dtype=torch.float32, non_blocking=TorchDevice.get_non_blocking(device))
-                        # TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
-                        # devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
-                        layer_weight = layer.get_weight(module.weight) * (lora_weight * layer_scale)
-                        layer.to(
-                            device=TorchDevice.CPU_DEVICE,
-                            non_blocking=TorchDevice.get_non_blocking(TorchDevice.CPU_DEVICE),
-                        )
-
-                        assert isinstance(layer_weight, torch.Tensor)  # mypy thinks layer_weight is a float|Any ??!
-                        if module.weight.shape != layer_weight.shape:
-                            # TODO: debug on lycoris
-                            assert hasattr(layer_weight, "reshape")
-                            layer_weight = layer_weight.reshape(module.weight.shape)
-
-                        assert isinstance(layer_weight, torch.Tensor)  # mypy thinks layer_weight is a float|Any ??!
-                        module.weight += layer_weight.to(dtype=dtype, non_blocking=TorchDevice.get_non_blocking(device))
-
-            yield  # wait for context manager exit
-
-        finally:
-            assert hasattr(model, "get_submodule")  # mypy not picking up fact that torch.nn.Module has get_submodule()
-            with torch.no_grad():
-                for module_key, weight in original_weights.items():
-                    model.get_submodule(module_key).weight.copy_(
-                        weight, non_blocking=TorchDevice.get_non_blocking(weight.device)
-                    )
-
     @classmethod
     @contextmanager
     def apply_ti(
@@ -284,27 +147,6 @@ def apply_clip_skip(
             while len(skipped_layers) > 0:
                 text_encoder.text_model.encoder.layers.append(skipped_layers.pop())
 
-    @classmethod
-    @contextmanager
-    def apply_freeu(
-        cls,
-        unet: UNet2DConditionModel,
-        freeu_config: Optional[FreeUConfig] = None,
-    ) -> None:
-        did_apply_freeu = False
-        try:
-            assert hasattr(unet, "enable_freeu")  # mypy doesn't pick up this attribute?
-            if freeu_config is not None:
-                unet.enable_freeu(b1=freeu_config.b1, b2=freeu_config.b2, s1=freeu_config.s1, s2=freeu_config.s2)
-                did_apply_freeu = True
-
-            yield
-
-        finally:
-            assert hasattr(unet, "disable_freeu")  # mypy doesn't pick up this attribute?
-            if did_apply_freeu:
-                unet.disable_freeu()
-
 
 class ONNXModelPatcher:
     @classmethod
diff --git a/invokeai/backend/stable_diffusion/__init__.py b/invokeai/backend/stable_diffusion/__init__.py
index 440cb4410ba..b0af9e91b2c 100644
--- a/invokeai/backend/stable_diffusion/__init__.py
+++ b/invokeai/backend/stable_diffusion/__init__.py
@@ -2,16 +2,11 @@
 Initialization file for the invokeai.backend.stable_diffusion package
 """
 
-from invokeai.backend.stable_diffusion.diffusers_pipeline import (  # noqa: F401
+# TODO: rename/move
+from invokeai.backend.stable_diffusion.extensions.preview import (  # noqa: F401
     PipelineIntermediateState,
-    StableDiffusionGeneratorPipeline,
 )
-from invokeai.backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent  # noqa: F401
-from invokeai.backend.stable_diffusion.seamless import set_seamless  # noqa: F401
 
 __all__ = [
     "PipelineIntermediateState",
-    "StableDiffusionGeneratorPipeline",
-    "InvokeAIDiffuserComponent",
-    "set_seamless",
 ]
diff --git a/invokeai/backend/stable_diffusion/denoise_context.py b/invokeai/backend/stable_diffusion/denoise_context.py
new file mode 100644
index 00000000000..b56f0959481
--- /dev/null
+++ b/invokeai/backend/stable_diffusion/denoise_context.py
@@ -0,0 +1,60 @@
+from __future__ import annotations
+
+from dataclasses import dataclass, field
+from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
+
+import torch
+from diffusers import UNet2DConditionModel
+from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
+
+if TYPE_CHECKING:
+    from invokeai.backend.stable_diffusion.diffusion.conditioning_data import TextConditioningData
+
+
+@dataclass
+class UNetKwargs:
+    sample: torch.Tensor
+    timestep: Union[torch.Tensor, float, int]
+    encoder_hidden_states: torch.Tensor
+
+    class_labels: Optional[torch.Tensor] = None
+    timestep_cond: Optional[torch.Tensor] = None
+    attention_mask: Optional[torch.Tensor] = None
+    cross_attention_kwargs: Optional[Dict[str, Any]] = None
+    added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None
+    down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None
+    mid_block_additional_residual: Optional[torch.Tensor] = None
+    down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None
+    encoder_attention_mask: Optional[torch.Tensor] = None
+    # return_dict: bool = True
+
+
+@dataclass
+class DenoiseContext:
+    latents: torch.Tensor
+    scheduler_step_kwargs: dict[str, Any]
+    conditioning_data: TextConditioningData
+    noise: Optional[torch.Tensor]
+    seed: int
+    timesteps: torch.Tensor
+    init_timestep: torch.Tensor
+
+    scheduler: SchedulerMixin
+    unet: Optional[UNet2DConditionModel] = None
+
+    orig_latents: Optional[torch.Tensor] = None
+    step_index: Optional[int] = None
+    timestep: Optional[torch.Tensor] = None
+    unet_kwargs: Optional[UNetKwargs] = None
+    step_output: Optional[SchedulerOutput] = None
+
+    latent_model_input: Optional[torch.Tensor] = None
+    conditioning_mode: Optional[str] = None
+    negative_noise_pred: Optional[torch.Tensor] = None
+    positive_noise_pred: Optional[torch.Tensor] = None
+    noise_pred: Optional[torch.Tensor] = None
+
+    extra: dict = field(default_factory=dict)
+
+    def __delattr__(self, name: str):
+        setattr(self, name, None)
diff --git a/invokeai/backend/stable_diffusion/diffusers_pipeline.py b/invokeai/backend/stable_diffusion/diffusers_pipeline.py
index ee464f73e1f..c26c29a86b7 100644
--- a/invokeai/backend/stable_diffusion/diffusers_pipeline.py
+++ b/invokeai/backend/stable_diffusion/diffusers_pipeline.py
@@ -1,74 +1,18 @@
 from __future__ import annotations
 
-import math
-from contextlib import nullcontext
-from dataclasses import dataclass
-from typing import Any, Callable, List, Optional, Union
-
-import einops
 import PIL.Image
-import psutil
 import torch
 import torchvision.transforms as T
-from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
 from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
-from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline
-from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
-from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
-from diffusers.utils.import_utils import is_xformers_available
-from pydantic import Field
-from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
+from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
+from tqdm.auto import tqdm
 
 from invokeai.app.services.config.config_default import get_config
-from invokeai.backend.stable_diffusion.diffusion.conditioning_data import IPAdapterData, TextConditioningData
-from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
-from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher, UNetIPAdapterData
-from invokeai.backend.util.attention import auto_detect_slice_size
-from invokeai.backend.util.devices import TorchDevice
-from invokeai.backend.util.hotfixes import ControlNetModel
-
-
-@dataclass
-class PipelineIntermediateState:
-    step: int
-    order: int
-    total_steps: int
-    timestep: int
-    latents: torch.Tensor
-    predicted_original: Optional[torch.Tensor] = None
-
-
-@dataclass
-class AddsMaskGuidance:
-    mask: torch.Tensor
-    mask_latents: torch.Tensor
-    scheduler: SchedulerMixin
-    noise: torch.Tensor
-    is_gradient_mask: bool
-
-    def __call__(self, latents: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
-        return self.apply_mask(latents, t)
-
-    def apply_mask(self, latents: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
-        batch_size = latents.size(0)
-        mask = einops.repeat(self.mask, "b c h w -> (repeat b) c h w", repeat=batch_size)
-        if t.dim() == 0:
-            # some schedulers expect t to be one-dimensional.
-            # TODO: file diffusers bug about inconsistency?
-            t = einops.repeat(t, "-> batch", batch=batch_size)
-        # Noise shouldn't be re-randomized between steps here. The multistep schedulers
-        # get very confused about what is happening from step to step when we do that.
-        mask_latents = self.scheduler.add_noise(self.mask_latents, self.noise, t)
-        # TODO: Do we need to also apply scheduler.scale_model_input? Or is add_noise appropriately scaled already?
-        # mask_latents = self.scheduler.scale_model_input(mask_latents, t)
-        mask_latents = einops.repeat(mask_latents, "b c h w -> (repeat b) c h w", repeat=batch_size)
-        if self.is_gradient_mask:
-            threshhold = (t.item()) / self.scheduler.config.num_train_timesteps
-            mask_bool = mask > threshhold  # I don't know when mask got inverted, but it did
-            masked_input = torch.where(mask_bool, latents, mask_latents)
-        else:
-            masked_input = torch.lerp(mask_latents.to(dtype=latents.dtype), latents, mask.to(dtype=latents.dtype))
-        return masked_input
+from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext, UNetKwargs
+
+# TODO: remove and fix imports
+from invokeai.backend.stable_diffusion.extensions import PipelineIntermediateState  # noqa: F401
+from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
 
 
 def trim_to_multiple_of(*args, multiple_of=8):
@@ -95,498 +39,185 @@ def image_resized_to_grid_as_tensor(image: PIL.Image.Image, normalize: bool = Tr
     return tensor
 
 
-def is_inpainting_model(unet: UNet2DConditionModel):
-    return unet.conv_in.in_channels == 9
-
-
-@dataclass
-class ControlNetData:
-    model: ControlNetModel = Field(default=None)
-    image_tensor: torch.Tensor = Field(default=None)
-    weight: Union[float, List[float]] = Field(default=1.0)
-    begin_step_percent: float = Field(default=0.0)
-    end_step_percent: float = Field(default=1.0)
-    control_mode: str = Field(default="balanced")
-    resize_mode: str = Field(default="just_resize")
-
-
-@dataclass
-class T2IAdapterData:
-    """A structure containing the information required to apply conditioning from a single T2I-Adapter model."""
-
-    adapter_state: dict[torch.Tensor] = Field()
-    weight: Union[float, list[float]] = Field(default=1.0)
-    begin_step_percent: float = Field(default=0.0)
-    end_step_percent: float = Field(default=1.0)
-
-
-class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
-    r"""
-    Pipeline for text-to-image generation using Stable Diffusion.
-
-    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
-    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
-
-    Implementation note: This class started as a refactored copy of diffusers.StableDiffusionPipeline.
-    Hopefully future versions of diffusers provide access to more of these functions so that we don't
-    need to duplicate them here: https://github.com/huggingface/diffusers/issues/551#issuecomment-1281508384
-
-    Args:
-        vae ([`AutoencoderKL`]):
-            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
-        text_encoder ([`CLIPTextModel`]):
-            Frozen text-encoder. Stable Diffusion uses the text portion of
-            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
-            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
-        tokenizer (`CLIPTokenizer`):
-            Tokenizer of class
-            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
-        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
-        scheduler ([`SchedulerMixin`]):
-            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
-            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
- safety_checker ([`StableDiffusionSafetyChecker`]): - Classification module that estimates whether generated images could be considered offensive or harmful. - Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details. - feature_extractor ([`CLIPFeatureExtractor`]): - Model that extracts features from generated images to be used as inputs for the `safety_checker`. - """ - +class StableDiffusionBackend: def __init__( self, - vae: AutoencoderKL, - text_encoder: CLIPTextModel, - tokenizer: CLIPTokenizer, unet: UNet2DConditionModel, - scheduler: KarrasDiffusionSchedulers, - safety_checker: Optional[StableDiffusionSafetyChecker], - feature_extractor: Optional[CLIPFeatureExtractor], - requires_safety_checker: bool = False, + scheduler: SchedulerMixin, ): - super().__init__( - vae=vae, - text_encoder=text_encoder, - tokenizer=tokenizer, - unet=unet, - scheduler=scheduler, - safety_checker=safety_checker, - feature_extractor=feature_extractor, - requires_safety_checker=requires_safety_checker, - ) - - self.invokeai_diffuser = InvokeAIDiffuserComponent(self.unet, self._unet_forward) - - def _adjust_memory_efficient_attention(self, latents: torch.Tensor): - """ - if xformers is available, use it, otherwise use sliced attention. - """ + self.unet = unet + self.scheduler = scheduler config = get_config() - if config.attention_type == "xformers": - self.enable_xformers_memory_efficient_attention() - return - elif config.attention_type == "sliced": - slice_size = config.attention_slice_size - if slice_size == "auto": - slice_size = auto_detect_slice_size(latents) - elif slice_size == "balanced": - slice_size = "auto" - self.enable_attention_slicing(slice_size=slice_size) - return - elif config.attention_type == "normal": - self.disable_attention_slicing() - return - elif config.attention_type == "torch-sdp": - if hasattr(torch.nn.functional, "scaled_dot_product_attention"): - # diffusers enables sdp automatically - return - else: - raise Exception("torch-sdp attention slicing not available") - - # the remainder if this code is called when attention_type=='auto' - if self.unet.device.type == "cuda": - if is_xformers_available(): - self.enable_xformers_memory_efficient_attention() - return - elif hasattr(torch.nn.functional, "scaled_dot_product_attention"): - # diffusers enables sdp automatically - return - - if self.unet.device.type == "cpu" or self.unet.device.type == "mps": - mem_free = psutil.virtual_memory().free - elif self.unet.device.type == "cuda": - mem_free, _ = torch.cuda.mem_get_info(TorchDevice.normalize(self.unet.device)) - else: - raise ValueError(f"unrecognized device {self.unet.device}") - # input tensor of [1, 4, h/8, w/8] - # output tensor of [16, (h/8 * w/8), (h/8 * w/8)] - bytes_per_element_needed_for_baddbmm_duplication = latents.element_size() + 4 - max_size_required_for_baddbmm = ( - 16 - * latents.size(dim=2) - * latents.size(dim=3) - * latents.size(dim=2) - * latents.size(dim=3) - * bytes_per_element_needed_for_baddbmm_duplication - ) - if max_size_required_for_baddbmm > (mem_free * 3.0 / 4.0): # 3.3 / 4.0 is from old Invoke code - self.enable_attention_slicing(slice_size="max") - elif torch.backends.mps.is_available(): - # diffusers recommends always enabling for mps - self.enable_attention_slicing(slice_size="max") - else: - self.disable_attention_slicing() + self.sequential_guidance = config.sequential_guidance - def to(self, torch_device: Optional[Union[str, torch.device]] = None, silence_dtype_warnings=False): - raise 
Exception("Should not be called") + def latents_from_embeddings(self, ctx: DenoiseContext, ext_manager: ExtensionsManager): + if ctx.init_timestep.shape[0] == 0: + return ctx.latents - def add_inpainting_channels_to_latents( - self, latents: torch.Tensor, masked_ref_image_latents: torch.Tensor, inpainting_mask: torch.Tensor - ): - """Given a `latents` tensor, adds the mask and image latents channels required for inpainting. + ctx.orig_latents = ctx.latents.clone() - Standard (non-inpainting) SD UNet models expect an input with shape (N, 4, H, W). Inpainting models expect an - input of shape (N, 9, H, W). The 9 channels are defined as follows: - - Channel 0-3: The latents being denoised. - - Channel 4: The mask indicating which parts of the image are being inpainted. - - Channel 5-8: The latent representation of the masked reference image being inpainted. + if ctx.noise is not None: + batch_size = ctx.latents.shape[0] + # latents = noise * self.scheduler.init_noise_sigma # it's like in t2l according to diffusers + ctx.latents = ctx.scheduler.add_noise(ctx.latents, ctx.noise, ctx.init_timestep.expand(batch_size)) - This function assumes that the same mask and base image should apply to all items in the batch. - """ - # Validate assumptions about input tensor shapes. - batch_size, latent_channels, latent_height, latent_width = latents.shape - assert latent_channels == 4 - assert list(masked_ref_image_latents.shape) == [1, 4, latent_height, latent_width] - assert list(inpainting_mask.shape) == [1, 1, latent_height, latent_width] + # if no work to do, return latents + if ctx.timesteps.shape[0] == 0: + return ctx.latents - # Repeat original_image_latents and inpainting_mask to match the latents batch size. - original_image_latents = masked_ref_image_latents.expand(batch_size, -1, -1, -1) - inpainting_mask = inpainting_mask.expand(batch_size, -1, -1, -1) + # ext: inpaint[pre_denoise_loop, priority=normal] (maybe init, but not sure if it needed) + # ext: preview[pre_denoise_loop, priority=low] + ext_manager.modifiers.pre_denoise_loop(ctx) - # Concatenate along the channel dimension. - return torch.cat([latents, inpainting_mask, original_image_latents], dim=1) + for ctx.step_index, ctx.timestep in enumerate(tqdm(ctx.timesteps)): # noqa: B020 + # ext: inpaint (apply mask to latents on non-inpaint models) + ext_manager.modifiers.pre_step(ctx) - def latents_from_embeddings( - self, - latents: torch.Tensor, - scheduler_step_kwargs: dict[str, Any], - conditioning_data: TextConditioningData, - noise: Optional[torch.Tensor], - seed: int, - timesteps: torch.Tensor, - init_timestep: torch.Tensor, - callback: Callable[[PipelineIntermediateState], None], - control_data: list[ControlNetData] | None = None, - ip_adapter_data: Optional[list[IPAdapterData]] = None, - t2i_adapter_data: Optional[list[T2IAdapterData]] = None, - mask: Optional[torch.Tensor] = None, - masked_latents: Optional[torch.Tensor] = None, - is_gradient_mask: bool = False, - ) -> torch.Tensor: - """Denoise the latents. - - Args: - latents: The latent-space image to denoise. - - If we are inpainting, this is the initial latent image before noise has been added. - - If we are generating a new image, this should be initialized to zeros. - - In some cases, this may be a partially-noised latent image (e.g. when running the SDXL refiner). - scheduler_step_kwargs: kwargs forwarded to the scheduler.step() method. - conditioning_data: Text conditionging data. - noise: Noise used for two purposes: - 1. 
Used by the scheduler to noise the initial `latents` before denoising. - 2. Used to noise the `masked_latents` when inpainting. - `noise` should be None if the `latents` tensor has already been noised. - seed: The seed used to generate the noise for the denoising process. - HACK(ryand): seed is only used in a particular case when `noise` is None, but we need to re-generate the - same noise used earlier in the pipeline. This should really be handled in a clearer way. - timesteps: The timestep schedule for the denoising process. - init_timestep: The first timestep in the schedule. This is used to determine the initial noise level, so - should be populated if you want noise applied *even* if timesteps is empty. - callback: A callback function that is called to report progress during the denoising process. - control_data: ControlNet data. - ip_adapter_data: IP-Adapter data. - t2i_adapter_data: T2I-Adapter data. - mask: A mask indicating which parts of the image are being inpainted. The presence of mask is used to - determine whether we are inpainting or not. `mask` should have the same spatial dimensions as the - `latents` tensor. - TODO(ryand): Check and document the expected dtype, range, and values used to represent - foreground/background. - masked_latents: A latent-space representation of a masked inpainting reference image. This tensor is only - used if an *inpainting* model is being used i.e. this tensor is not used when inpainting with a standard - SD UNet model. - is_gradient_mask: A flag indicating whether `mask` is a gradient mask or not. - """ - if init_timestep.shape[0] == 0: - return latents + # ext: tiles? [override: step] + ctx.step_output = ext_manager.overrides.step(self.step, ctx, ext_manager) - orig_latents = latents.clone() + # ext: inpaint[post_step, priority=high] (apply mask to preview on non-inpaint models) + # ext: preview[post_step, priority=low] + ext_manager.modifiers.post_step(ctx) - batch_size = latents.shape[0] - batched_init_timestep = init_timestep.expand(batch_size) + ctx.latents = ctx.step_output.prev_sample - # noise can be None if the latents have already been noised (e.g. when running the SDXL refiner). - if noise is not None: - # TODO(ryand): I'm pretty sure we should be applying init_noise_sigma in cases where we are starting with - # full noise. Investigate the history of why this got commented out. - # latents = noise * self.scheduler.init_noise_sigma # it's like in t2l according to diffusers - latents = self.scheduler.add_noise(latents, noise, batched_init_timestep) - - self._adjust_memory_efficient_attention(latents) - - # Handle mask guidance (a.k.a. inpainting). - mask_guidance: AddsMaskGuidance | None = None - if mask is not None and not is_inpainting_model(self.unet): - # We are doing inpainting, since a mask is provided, but we are not using an inpainting model, so we will - # apply mask guidance to the latents. - - # 'noise' might be None if the latents have already been noised (e.g. when running the SDXL refiner). - # We still need noise for inpainting, so we generate it from the seed here. 
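Note on the replacement control flow above: the new `latents_from_embeddings()` no longer hard-codes inpainting, preview, ControlNet, or IP-Adapter handling. It only fires named hook points on the `ExtensionsManager` (`pre_denoise_loop`, `pre_step`, `post_step`, `post_denoise_loop`) and routes per-step work through `overrides.step`. A minimal sketch of how a hook is declared with the `ExtensionBase`/`@modifier` machinery introduced later in this diff; registration of extensions with the manager is not part of this hunk, so only the declaration side is shown:

import torch

from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, modifier


class LatentStatsExt(ExtensionBase):
    """Toy extension: records the latent std-dev before every denoising step."""

    def __init__(self, priority: int = 100):
        super().__init__(priority=priority)
        self.history: list[float] = []

    @modifier("pre_step")
    def record_stats(self, ctx: DenoiseContext):
        # Invoked by ext_manager.modifiers.pre_step(ctx) at the top of each loop iteration.
        self.history.append(float(torch.std(ctx.latents)))


# ExtensionBase.__init__ scans for methods tagged by @modifier/@override and collects
# them as InjectionInfo entries, which is what the manager dispatches on:
ext = LatentStatsExt()
assert ext.injections[0].name == "pre_step"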
- if noise is None: - noise = torch.randn( - orig_latents.shape, - dtype=torch.float32, - device="cpu", - generator=torch.Generator(device="cpu").manual_seed(seed), - ).to(device=orig_latents.device, dtype=orig_latents.dtype) - - mask_guidance = AddsMaskGuidance( - mask=mask, - mask_latents=orig_latents, - scheduler=self.scheduler, - noise=noise, - is_gradient_mask=is_gradient_mask, - ) - - use_ip_adapter = ip_adapter_data is not None - use_regional_prompting = ( - conditioning_data.cond_regions is not None or conditioning_data.uncond_regions is not None - ) - unet_attention_patcher = None - attn_ctx = nullcontext() - - if use_ip_adapter or use_regional_prompting: - ip_adapters: Optional[List[UNetIPAdapterData]] = ( - [{"ip_adapter": ipa.ip_adapter_model, "target_blocks": ipa.target_blocks} for ipa in ip_adapter_data] - if use_ip_adapter - else None - ) - unet_attention_patcher = UNetAttentionPatcher(ip_adapters) - attn_ctx = unet_attention_patcher.apply_ip_adapter_attention(self.invokeai_diffuser.model) - - with attn_ctx: - callback( - PipelineIntermediateState( - step=-1, - order=self.scheduler.order, - total_steps=len(timesteps), - timestep=self.scheduler.config.num_train_timesteps, - latents=latents, - ) - ) - - for i, t in enumerate(self.progress_bar(timesteps)): - batched_t = t.expand(batch_size) - step_output = self.step( - t=batched_t, - latents=latents, - conditioning_data=conditioning_data, - step_index=i, - total_step_count=len(timesteps), - scheduler_step_kwargs=scheduler_step_kwargs, - mask_guidance=mask_guidance, - mask=mask, - masked_latents=masked_latents, - control_data=control_data, - ip_adapter_data=ip_adapter_data, - t2i_adapter_data=t2i_adapter_data, - ) - latents = step_output.prev_sample - predicted_original = getattr(step_output, "pred_original_sample", None) - - callback( - PipelineIntermediateState( - step=i, - order=self.scheduler.order, - total_steps=len(timesteps), - timestep=int(t), - latents=latents, - predicted_original=predicted_original, - ) - ) - - # restore unmasked part after the last step is completed - # in-process masking happens before each step - if mask is not None: - if is_gradient_mask: - latents = torch.where(mask > 0, latents, orig_latents) - else: - latents = torch.lerp( - orig_latents, latents.to(dtype=orig_latents.dtype), mask.to(dtype=orig_latents.dtype) - ) - - return latents + # ext: inpaint[post_denoise_loop] (restore unmasked part) + ext_manager.modifiers.post_denoise_loop(ctx) + return ctx.latents @torch.inference_mode() - def step( - self, - t: torch.Tensor, - latents: torch.Tensor, - conditioning_data: TextConditioningData, - step_index: int, - total_step_count: int, - scheduler_step_kwargs: dict[str, Any], - mask_guidance: AddsMaskGuidance | None, - mask: torch.Tensor | None, - masked_latents: torch.Tensor | None, - control_data: list[ControlNetData] | None = None, - ip_adapter_data: Optional[list[IPAdapterData]] = None, - t2i_adapter_data: Optional[list[T2IAdapterData]] = None, - ): - # invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value - timestep = t[0] - - # Handle masked image-to-image (a.k.a inpainting). - if mask_guidance is not None: - # NOTE: This is intentionally done *before* self.scheduler.scale_model_input(...). - latents = mask_guidance(latents, timestep) - - # TODO: should this scaling happen here or inside self._unet_forward? - # i.e. 
before or after passing it to InvokeAIDiffuserComponent - latent_model_input = self.scheduler.scale_model_input(latents, timestep) - - # Handle ControlNet(s) - down_block_additional_residuals = None - mid_block_additional_residual = None - if control_data is not None: - down_block_additional_residuals, mid_block_additional_residual = self.invokeai_diffuser.do_controlnet_step( - control_data=control_data, - sample=latent_model_input, - timestep=timestep, - step_index=step_index, - total_step_count=total_step_count, - conditioning_data=conditioning_data, - ) - - # Handle T2I-Adapter(s) - down_intrablock_additional_residuals = None - if t2i_adapter_data is not None: - accum_adapter_state = None - for single_t2i_adapter_data in t2i_adapter_data: - # Determine the T2I-Adapter weights for the current denoising step. - first_t2i_adapter_step = math.floor(single_t2i_adapter_data.begin_step_percent * total_step_count) - last_t2i_adapter_step = math.ceil(single_t2i_adapter_data.end_step_percent * total_step_count) - t2i_adapter_weight = ( - single_t2i_adapter_data.weight[step_index] - if isinstance(single_t2i_adapter_data.weight, list) - else single_t2i_adapter_data.weight - ) - if step_index < first_t2i_adapter_step or step_index > last_t2i_adapter_step: - # If the current step is outside of the T2I-Adapter's begin/end step range, then set its weight to 0 - # so it has no effect. - t2i_adapter_weight = 0.0 - - # Apply the t2i_adapter_weight, and accumulate. - if accum_adapter_state is None: - # Handle the first T2I-Adapter. - accum_adapter_state = [val * t2i_adapter_weight for val in single_t2i_adapter_data.adapter_state] - else: - # Add to the previous adapter states. - for idx, value in enumerate(single_t2i_adapter_data.adapter_state): - accum_adapter_state[idx] += value * t2i_adapter_weight - - down_intrablock_additional_residuals = accum_adapter_state - - # Handle inpainting models. - if is_inpainting_model(self.unet): - # NOTE: These calls to add_inpainting_channels_to_latents(...) are intentionally done *after* - # self.scheduler.scale_model_input(...) so that the scaling is not applied to the mask or reference image - # latents. - if mask is not None: - if masked_latents is None: - raise ValueError("Source image required for inpaint mask when inpaint model used!") - latent_model_input = self.add_inpainting_channels_to_latents( - latents=latent_model_input, masked_ref_image_latents=masked_latents, inpainting_mask=mask - ) - else: - # We are using an inpainting model, but no mask was provided, so we are not really "inpainting". - # We generate a global mask and empty original image so that we can still generate in this - # configuration. - # TODO(ryand): Should we just raise an exception here instead? I can't think of a use case for wanting - # to do this. - # TODO(ryand): If we decide that there is a good reason to keep this, then we should generate the 'fake' - # mask and original image once rather than on every denoising step. 
- latent_model_input = self.add_inpainting_channels_to_latents( - latents=latent_model_input, - masked_ref_image_latents=torch.zeros_like(latent_model_input[:1]), - inpainting_mask=torch.ones_like(latent_model_input[:1, :1]), - ) - - uc_noise_pred, c_noise_pred = self.invokeai_diffuser.do_unet_step( - sample=latent_model_input, - timestep=t, # TODO: debug how handled batched and non batched timesteps - step_index=step_index, - total_step_count=total_step_count, - conditioning_data=conditioning_data, - ip_adapter_data=ip_adapter_data, - down_block_additional_residuals=down_block_additional_residuals, # for ControlNet - mid_block_additional_residual=mid_block_additional_residual, # for ControlNet - down_intrablock_additional_residuals=down_intrablock_additional_residuals, # for T2I-Adapter - ) + def step(self, ctx: DenoiseContext, ext_manager: ExtensionsManager) -> SchedulerOutput: + ctx.latent_model_input = ctx.scheduler.scale_model_input(ctx.latents, ctx.timestep) - guidance_scale = conditioning_data.guidance_scale - if isinstance(guidance_scale, list): - guidance_scale = guidance_scale[step_index] + if self.sequential_guidance: + conditioning_call = self._apply_standard_conditioning_sequentially + else: + conditioning_call = self._apply_standard_conditioning + + # not sure if here needed override + ctx.negative_noise_pred, ctx.positive_noise_pred = conditioning_call(ctx, ext_manager) - noise_pred = self.invokeai_diffuser._combine(uc_noise_pred, c_noise_pred, guidance_scale) - guidance_rescale_multiplier = conditioning_data.guidance_rescale_multiplier - if guidance_rescale_multiplier > 0: - noise_pred = self._rescale_cfg( - noise_pred, - c_noise_pred, - guidance_rescale_multiplier, - ) + # ext: override combine_noise + ctx.noise_pred = ext_manager.overrides.combine_noise(self.combine_noise, ctx) + + # ext: cfg_rescale [modify_noise_prediction] + ext_manager.modifiers.modify_noise_prediction(ctx) # compute the previous noisy sample x_t -> x_t-1 - step_output = self.scheduler.step(noise_pred, timestep, latents, **scheduler_step_kwargs) - - # TODO: discuss injection point options. For now this is a patch to get progress images working with inpainting - # again. - if mask_guidance is not None: - # Apply the mask to any "denoised" or "pred_original_sample" fields. 
- if hasattr(step_output, "denoised"): - step_output.pred_original_sample = mask_guidance(step_output.denoised, self.scheduler.timesteps[-1]) - elif hasattr(step_output, "pred_original_sample"): - step_output.pred_original_sample = mask_guidance( - step_output.pred_original_sample, self.scheduler.timesteps[-1] - ) - else: - step_output.pred_original_sample = mask_guidance(latents, self.scheduler.timesteps[-1]) + step_output = ctx.scheduler.step(ctx.noise_pred, ctx.timestep, ctx.latents, **ctx.scheduler_step_kwargs) + + # del locals + del ctx.latent_model_input + del ctx.negative_noise_pred + del ctx.positive_noise_pred + del ctx.noise_pred return step_output @staticmethod - def _rescale_cfg(total_noise_pred, pos_noise_pred, multiplier=0.7): - """Implementation of Algorithm 2 from https://arxiv.org/pdf/2305.08891.pdf.""" - ro_pos = torch.std(pos_noise_pred, dim=(1, 2, 3), keepdim=True) - ro_cfg = torch.std(total_noise_pred, dim=(1, 2, 3), keepdim=True) + def combine_noise(ctx: DenoiseContext) -> torch.Tensor: + guidance_scale = ctx.conditioning_data.guidance_scale + if isinstance(guidance_scale, list): + guidance_scale = guidance_scale[ctx.step_index] - x_rescaled = total_noise_pred * (ro_pos / ro_cfg) - x_final = multiplier * x_rescaled + (1.0 - multiplier) * total_noise_pred - return x_final + return torch.lerp(ctx.negative_noise_pred, ctx.positive_noise_pred, guidance_scale) + # return ctx.negative_noise_pred + guidance_scale * (ctx.positive_noise_pred - ctx.negative_noise_pred) - def _unet_forward( - self, - latents, - t, - text_embeddings, - cross_attention_kwargs: Optional[dict[str, Any]] = None, - **kwargs, - ): - """predict the noise residual""" - # First three args should be positional, not keywords, so torch hooks can see them. - return self.unet( - latents, - t, - text_embeddings, - cross_attention_kwargs=cross_attention_kwargs, - **kwargs, - ).sample + def _apply_standard_conditioning( + self, ctx: DenoiseContext, ext_manager: ExtensionsManager + ) -> tuple[torch.Tensor, torch.Tensor]: + """Runs the conditioned and unconditioned UNet forward passes in a single batch for faster inference speed at + the cost of higher memory usage. + """ + + ctx.unet_kwargs = UNetKwargs( + sample=torch.cat([ctx.latent_model_input] * 2), + timestep=ctx.timestep, + encoder_hidden_states=None, # set later by conditoning + cross_attention_kwargs=dict( # noqa: C408 + percent_through=ctx.step_index / len(ctx.timesteps), # ctx.total_steps, + ), + ) + + ctx.conditioning_mode = "both" + ctx.conditioning_data.to_unet_kwargs(ctx.unet_kwargs, ctx.conditioning_mode) + + # ext: controlnet/ip/t2i [pre_unet_forward] + ext_manager.modifiers.pre_unet_forward(ctx) + + # ext: inpaint [pre_unet_forward, priority=low] + # or + # ext: inpaint [override: unet_forward] + both_results = self._unet_forward(**vars(ctx.unet_kwargs)) + negative_next_x, positive_next_x = both_results.chunk(2) + # del locals + del ctx.unet_kwargs + del ctx.conditioning_mode + return negative_next_x, positive_next_x + + def _apply_standard_conditioning_sequentially(self, ctx: DenoiseContext, ext_manager: ExtensionsManager): + """Runs the conditioned and unconditioned UNet forward passes sequentially for lower memory usage at the cost of + slower execution speed. 
+ """ + + ################### + # Negative pass + ################### + + ctx.unet_kwargs = UNetKwargs( + sample=ctx.latent_model_input, + timestep=ctx.timestep, + encoder_hidden_states=None, # set later by conditoning + cross_attention_kwargs=dict( # noqa: C408 + percent_through=ctx.step_index / len(ctx.timesteps), # ctx.total_steps, + ), + ) + + ctx.conditioning_mode = "negative" + ctx.conditioning_data.to_unet_kwargs(ctx.unet_kwargs, "negative") + + # ext: controlnet/ip/t2i [pre_unet_forward] + ext_manager.modifiers.pre_unet_forward(ctx) + + # ext: inpaint [pre_unet_forward, priority=low] + # or + # ext: inpaint [override: unet_forward] + negative_next_x = self._unet_forward(**vars(ctx.unet_kwargs)) + + del ctx.unet_kwargs + del ctx.conditioning_mode + # TODO: gc.collect() ? + + ################### + # Positive pass + ################### + + ctx.unet_kwargs = UNetKwargs( + sample=ctx.latent_model_input, + timestep=ctx.timestep, + encoder_hidden_states=None, # set later by conditoning + cross_attention_kwargs=dict( # noqa: C408 + percent_through=ctx.step_index / len(ctx.timesteps), # ctx.total_steps, + ), + ) + + ctx.conditioning_mode = "positive" + ctx.conditioning_data.to_unet_kwargs(ctx.unet_kwargs, "positive") + + # ext: controlnet/ip/t2i [pre_unet_forward] + ext_manager.modifiers.pre_unet_forward(ctx) + + # ext: inpaint [pre_unet_forward, priority=low] + # or + # ext: inpaint [override: unet_forward] + positive_next_x = self._unet_forward(**vars(ctx.unet_kwargs)) + + del ctx.unet_kwargs + del ctx.conditioning_mode + # TODO: gc.collect() ? + + return negative_next_x, positive_next_x + + def _unet_forward(self, **kwargs) -> torch.Tensor: + return self.unet(**kwargs).sample diff --git a/invokeai/backend/stable_diffusion/diffusion/__init__.py b/invokeai/backend/stable_diffusion/diffusion/__init__.py index 712542f79cf..fec1068f073 100644 --- a/invokeai/backend/stable_diffusion/diffusion/__init__.py +++ b/invokeai/backend/stable_diffusion/diffusion/__init__.py @@ -1,7 +1,3 @@ """ Initialization file for invokeai.models.diffusion """ - -from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import ( - InvokeAIDiffuserComponent, # noqa: F401 -) diff --git a/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py b/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py index 85950a01df5..ad1683780ac 100644 --- a/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py +++ b/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py @@ -5,6 +5,7 @@ import torch from invokeai.backend.ip_adapter.ip_adapter import IPAdapter +from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import RegionalPromptData @dataclass @@ -103,7 +104,6 @@ def __init__( uncond_regions: Optional[TextConditioningRegions], cond_regions: Optional[TextConditioningRegions], guidance_scale: Union[float, List[float]], - guidance_rescale_multiplier: float = 0, ): self.uncond_text = uncond_text self.cond_text = cond_text @@ -114,10 +114,131 @@ def __init__( # Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate # images that are closely linked to the text `prompt`, usually at the expense of lower image quality. self.guidance_scale = guidance_scale - # For models trained using zero-terminal SNR ("ztsnr"), it's suggested to use guidance_rescale_multiplier of 0.7. - # See [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). 
- self.guidance_rescale_multiplier = guidance_rescale_multiplier def is_sdxl(self): assert isinstance(self.uncond_text, SDXLConditioningInfo) == isinstance(self.cond_text, SDXLConditioningInfo) return isinstance(self.cond_text, SDXLConditioningInfo) + + def to_unet_kwargs(self, unet_kwargs, conditioning_mode): + if conditioning_mode == "both": + encoder_hidden_states, encoder_attention_mask = self._concat_conditionings_for_batch( + self.uncond_text.embeds, self.cond_text.embeds + ) + elif conditioning_mode == "positive": + encoder_hidden_states = self.cond_text.embeds + encoder_attention_mask = None + else: # elif conditioning_mode == "negative": + encoder_hidden_states = self.uncond_text.embeds + encoder_attention_mask = None + + unet_kwargs.encoder_hidden_states = encoder_hidden_states + unet_kwargs.encoder_attention_mask = encoder_attention_mask + + if self.is_sdxl(): + if conditioning_mode == "negative": + added_cond_kwargs = dict( # noqa: C408 + text_embeds=self.cond_text.pooled_embeds, + time_ids=self.cond_text.add_time_ids, + ) + elif conditioning_mode == "positive": + added_cond_kwargs = dict( # noqa: C408 + text_embeds=self.uncond_text.pooled_embeds, + time_ids=self.uncond_text.add_time_ids, + ) + else: # elif conditioning_mode == "both": + added_cond_kwargs = dict( # noqa: C408 + text_embeds=torch.cat( + [ + # TODO: how to pad? just by zeros? or even truncate? + self.uncond_text.pooled_embeds, + self.cond_text.pooled_embeds, + ], + ), + time_ids=torch.cat( + [ + self.uncond_text.add_time_ids, + self.cond_text.add_time_ids, + ], + ), + ) + + unet_kwargs.added_cond_kwargs = added_cond_kwargs + + if self.cond_regions is not None or self.uncond_regions is not None: + # TODO(ryand): We currently initialize RegionalPromptData for every denoising step. The text conditionings + # and masks are not changing from step-to-step, so this really only needs to be done once. While this seems + # painfully inefficient, the time spent is typically negligible compared to the forward inference pass of + # the UNet. The main reason that this hasn't been moved up to eliminate redundancy is that it is slightly + # awkward to handle both standard conditioning and sequential conditioning further up the stack. + + _tmp_regions = self.cond_regions if self.cond_regions is not None else self.uncond_regions + _, _, h, w = _tmp_regions.masks.shape + dtype = self.cond_text.embeds.dtype + device = self.cond_text.embeds.device + + regions = [] + for c, r in [ + (self.uncond_text, self.uncond_regions), + (self.cond_text, self.cond_regions), + ]: + if r is None: + # Create a dummy mask and range for text conditioning that doesn't have region masks. 
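One detail worth flagging in the SDXL branch of `to_unet_kwargs()` above: the "negative" mode pulls `pooled_embeds`/`add_time_ids` from `self.cond_text` and the "positive" mode from `self.uncond_text`, which is the opposite of the `encoder_hidden_states` selection a few lines earlier and of the deleted `InvokeAIDiffuserComponent` further down. The mapping presumably intended is sketched here as a standalone helper (attribute names follow the hunk above; the fix in place would simply swap the two branches):

import torch


def sdxl_added_cond_kwargs(uncond_text, cond_text, conditioning_mode: str) -> dict:
    """Illustrative: negative -> uncond_text, positive -> cond_text, both -> uncond then cond."""
    if conditioning_mode == "negative":
        return {"text_embeds": uncond_text.pooled_embeds, "time_ids": uncond_text.add_time_ids}
    if conditioning_mode == "positive":
        return {"text_embeds": cond_text.pooled_embeds, "time_ids": cond_text.add_time_ids}
    # "both": unconditioned first, conditioned second, matching the embeds concatenation order.
    return {
        "text_embeds": torch.cat([uncond_text.pooled_embeds, cond_text.pooled_embeds]),
        "time_ids": torch.cat([uncond_text.add_time_ids, cond_text.add_time_ids]),
    }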
+ r = TextConditioningRegions( + masks=torch.ones((1, 1, h, w), dtype=dtype), + ranges=[Range(start=0, end=c.embeds.shape[1])], + ) + regions.append(r) + + if unet_kwargs.cross_attention_kwargs is None: + unet_kwargs.cross_attention_kwargs = {} + + unet_kwargs.cross_attention_kwargs.update( + regional_prompt_data=RegionalPromptData(regions=regions, device=device, dtype=dtype), + ) + + def _concat_conditionings_for_batch(self, unconditioning, conditioning): + def _pad_conditioning(cond, target_len, encoder_attention_mask): + conditioning_attention_mask = torch.ones( + (cond.shape[0], cond.shape[1]), device=cond.device, dtype=cond.dtype + ) + + if cond.shape[1] < max_len: + conditioning_attention_mask = torch.cat( + [ + conditioning_attention_mask, + torch.zeros((cond.shape[0], max_len - cond.shape[1]), device=cond.device, dtype=cond.dtype), + ], + dim=1, + ) + + cond = torch.cat( + [ + cond, + torch.zeros( + (cond.shape[0], max_len - cond.shape[1], cond.shape[2]), + device=cond.device, + dtype=cond.dtype, + ), + ], + dim=1, + ) + + if encoder_attention_mask is None: + encoder_attention_mask = conditioning_attention_mask + else: + encoder_attention_mask = torch.cat( + [ + encoder_attention_mask, + conditioning_attention_mask, + ] + ) + + return cond, encoder_attention_mask + + encoder_attention_mask = None + if unconditioning.shape[1] != conditioning.shape[1]: + max_len = max(unconditioning.shape[1], conditioning.shape[1]) + unconditioning, encoder_attention_mask = _pad_conditioning(unconditioning, max_len, encoder_attention_mask) + conditioning, encoder_attention_mask = _pad_conditioning(conditioning, max_len, encoder_attention_mask) + + return torch.cat([unconditioning, conditioning]), encoder_attention_mask diff --git a/invokeai/backend/stable_diffusion/diffusion/custom_atttention.py b/invokeai/backend/stable_diffusion/diffusion/custom_atttention.py index 1334313fe6e..464b871990d 100644 --- a/invokeai/backend/stable_diffusion/diffusion/custom_atttention.py +++ b/invokeai/backend/stable_diffusion/diffusion/custom_atttention.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import List, Optional, cast +from typing import Optional, cast import torch import torch.nn.functional as F @@ -25,19 +25,14 @@ class CustomAttnProcessor2_0(AttnProcessor2_0): - Regional prompt attention """ - def __init__( - self, - ip_adapter_attention_weights: Optional[List[IPAdapterAttentionWeights]] = None, - ): - """Initialize a CustomAttnProcessor2_0. - Note: Arguments that are the same for all attention layers are passed to __call__(). Arguments that are - layer-specific are passed to __init__(). - Args: - ip_adapter_weights: The IP-Adapter attention weights. ip_adapter_weights[i] contains the attention weights - for the i'th IP-Adapter. 
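For context on the change that follows: `CustomAttnProcessor2_0` moves from receiving all IP-Adapter weights in its constructor to registering them one at a time via `add_ip_adapter()`. A usage sketch, where `ip_adapter` stands in for a loaded IPAdapter model and `block_idx` for the attention-block index (both placeholders; the weights accessor is the one used by the `UNetAttentionPatcher` code deleted further down):

from invokeai.backend.stable_diffusion.diffusion.custom_atttention import (
    CustomAttnProcessor2_0,
    IPAdapterAttentionWeights,
)

proc = CustomAttnProcessor2_0()  # starts with an empty IP-Adapter weight list

# Placeholder inputs: ip_adapter is a loaded IPAdapter, block_idx the attention block index.
weights = ip_adapter.attn_weights.get_attention_processor_weights(block_idx)
idx = proc.add_ip_adapter(IPAdapterAttentionWeights(ip_adapter_weights=weights, skip=False))
# idx is this adapter's position in the processor's internal list.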
- """ + def __init__(self): + """Initialize a CustomAttnProcessor2_0.""" super().__init__() - self._ip_adapter_attention_weights = ip_adapter_attention_weights + self._ip_adapter_attention_weights = [] + + def add_ip_adapter(self, ip_adapter: IPAdapterAttentionWeights) -> int: + self._ip_adapter_attention_weights.append(ip_adapter) + return len(self._ip_adapter_attention_weights) - 1 # idx def __call__( self, diff --git a/invokeai/backend/stable_diffusion/diffusion/regional_ip_data.py b/invokeai/backend/stable_diffusion/diffusion/regional_ip_data.py index 792c97114da..79456a0bf8e 100644 --- a/invokeai/backend/stable_diffusion/diffusion/regional_ip_data.py +++ b/invokeai/backend/stable_diffusion/diffusion/regional_ip_data.py @@ -25,11 +25,16 @@ def __init__( # scales[i] contains the attention scale for the i'th IP-Adapter. self.scales = scales + self.masks = masks + self.dtype = dtype + self.device = device + self.max_downscale_factor = max_downscale_factor + # The IP-Adapter masks. # self._masks_by_seq_len[s] contains the spatial masks for the downsampling level with query sequence length of # s. It has shape (batch_size, num_ip_images, query_seq_len, 1). The masks have values of 1.0 for included # regions and 0.0 for excluded regions. - self._masks_by_seq_len = self._prepare_masks(masks, max_downscale_factor, device, dtype) + self._masks_by_seq_len = None # self._prepare_masks(masks, max_downscale_factor, device, dtype) def _prepare_masks( self, masks: list[torch.Tensor], max_downscale_factor: int, device: torch.device, dtype: torch.dtype @@ -69,4 +74,13 @@ def _prepare_masks( def get_masks(self, query_seq_len: int) -> torch.Tensor: """Get the mask for the given query sequence length.""" + if self._masks_by_seq_len is None: + self._masks_by_seq_len = self._prepare_masks(self.masks, self.max_downscale_factor, self.device, self.dtype) return self._masks_by_seq_len[query_seq_len] + + def add(self, embeds: torch.Tensor, scale: float, mask: torch.Tensor): + if self._masks_by_seq_len is not None: + self._masks_by_seq_len = None + self.image_prompt_embeds.append(embeds) + self.scales.append(scale) + self.masks.append(mask) diff --git a/invokeai/backend/stable_diffusion/diffusion/regional_prompt_data.py b/invokeai/backend/stable_diffusion/diffusion/regional_prompt_data.py index f09cc0a0d21..eddd31f0c42 100644 --- a/invokeai/backend/stable_diffusion/diffusion/regional_prompt_data.py +++ b/invokeai/backend/stable_diffusion/diffusion/regional_prompt_data.py @@ -1,9 +1,14 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + import torch import torch.nn.functional as F -from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ( - TextConditioningRegions, -) +if TYPE_CHECKING: + from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ( + TextConditioningRegions, + ) class RegionalPromptData: diff --git a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py deleted file mode 100644 index f418133e49f..00000000000 --- a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py +++ /dev/null @@ -1,496 +0,0 @@ -from __future__ import annotations - -import math -from typing import Any, Callable, Optional, Union - -import torch -from typing_extensions import TypeAlias - -from invokeai.app.services.config.config_default import get_config -from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ( - IPAdapterData, - Range, - 
TextConditioningData, - TextConditioningRegions, -) -from invokeai.backend.stable_diffusion.diffusion.regional_ip_data import RegionalIPData -from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import RegionalPromptData - -ModelForwardCallback: TypeAlias = Union[ - # x, t, conditioning, Optional[cross-attention kwargs] - Callable[ - [torch.Tensor, torch.Tensor, torch.Tensor, Optional[dict[str, Any]]], - torch.Tensor, - ], - Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor], -] - - -class InvokeAIDiffuserComponent: - """ - The aim of this component is to provide a single place for code that can be applied identically to - all InvokeAI diffusion procedures. - - At the moment it includes the following features: - * Cross attention control ("prompt2prompt") - * Hybrid conditioning (used for inpainting) - """ - - debug_thresholding = False - sequential_guidance = False - - def __init__( - self, - model, - model_forward_callback: ModelForwardCallback, - ): - """ - :param model: the unet model to pass through to cross attention control - :param model_forward_callback: a lambda with arguments (x, sigma, conditioning_to_apply). will be called repeatedly. most likely, this should simply call model.forward(x, sigma, conditioning) - """ - config = get_config() - self.conditioning = None - self.model = model - self.model_forward_callback = model_forward_callback - self.sequential_guidance = config.sequential_guidance - - def do_controlnet_step( - self, - control_data, - sample: torch.Tensor, - timestep: torch.Tensor, - step_index: int, - total_step_count: int, - conditioning_data: TextConditioningData, - ): - down_block_res_samples, mid_block_res_sample = None, None - - # control_data should be type List[ControlNetData] - # this loop covers both ControlNet (one ControlNetData in list) - # and MultiControlNet (multiple ControlNetData in list) - for _i, control_datum in enumerate(control_data): - control_mode = control_datum.control_mode - # soft_injection and cfg_injection are the two ControlNet control_mode booleans - # that are combined at higher level to make control_mode enum - # soft_injection determines whether to do per-layer re-weighting adjustment (if True) - # or default weighting (if False) - soft_injection = control_mode == "more_prompt" or control_mode == "more_control" - # cfg_injection = determines whether to apply ControlNet to only the conditional (if True) - # or the default both conditional and unconditional (if False) - cfg_injection = control_mode == "more_control" or control_mode == "unbalanced" - - first_control_step = math.floor(control_datum.begin_step_percent * total_step_count) - last_control_step = math.ceil(control_datum.end_step_percent * total_step_count) - # only apply controlnet if current step is within the controlnet's begin/end step range - if step_index >= first_control_step and step_index <= last_control_step: - if cfg_injection: - sample_model_input = sample - else: - # expand the latents input to control model if doing classifier free guidance - # (which I think for now is always true, there is conditional elsewhere that stops execution if - # classifier_free_guidance is <= 1.0 ?) 
- sample_model_input = torch.cat([sample] * 2) - - added_cond_kwargs = None - - if cfg_injection: # only applying ControlNet to conditional instead of in unconditioned - if conditioning_data.is_sdxl(): - added_cond_kwargs = { - "text_embeds": conditioning_data.cond_text.pooled_embeds, - "time_ids": conditioning_data.cond_text.add_time_ids, - } - encoder_hidden_states = conditioning_data.cond_text.embeds - encoder_attention_mask = None - else: - if conditioning_data.is_sdxl(): - added_cond_kwargs = { - "text_embeds": torch.cat( - [ - # TODO: how to pad? just by zeros? or even truncate? - conditioning_data.uncond_text.pooled_embeds, - conditioning_data.cond_text.pooled_embeds, - ], - dim=0, - ), - "time_ids": torch.cat( - [ - conditioning_data.uncond_text.add_time_ids, - conditioning_data.cond_text.add_time_ids, - ], - dim=0, - ), - } - ( - encoder_hidden_states, - encoder_attention_mask, - ) = self._concat_conditionings_for_batch( - conditioning_data.uncond_text.embeds, - conditioning_data.cond_text.embeds, - ) - if isinstance(control_datum.weight, list): - # if controlnet has multiple weights, use the weight for the current step - controlnet_weight = control_datum.weight[step_index] - else: - # if controlnet has a single weight, use it for all steps - controlnet_weight = control_datum.weight - - # controlnet(s) inference - down_samples, mid_sample = control_datum.model( - sample=sample_model_input, - timestep=timestep, - encoder_hidden_states=encoder_hidden_states, - controlnet_cond=control_datum.image_tensor, - conditioning_scale=controlnet_weight, # controlnet specific, NOT the guidance scale - encoder_attention_mask=encoder_attention_mask, - added_cond_kwargs=added_cond_kwargs, - guess_mode=soft_injection, # this is still called guess_mode in diffusers ControlNetModel - return_dict=False, - ) - if cfg_injection: - # Inferred ControlNet only for the conditional batch. 
- # To apply the output of ControlNet to both the unconditional and conditional batches, - # prepend zeros for unconditional batch - down_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_samples] - mid_sample = torch.cat([torch.zeros_like(mid_sample), mid_sample]) - - if down_block_res_samples is None and mid_block_res_sample is None: - down_block_res_samples, mid_block_res_sample = down_samples, mid_sample - else: - # add controlnet outputs together if have multiple controlnets - down_block_res_samples = [ - samples_prev + samples_curr - for samples_prev, samples_curr in zip(down_block_res_samples, down_samples, strict=True) - ] - mid_block_res_sample += mid_sample - - return down_block_res_samples, mid_block_res_sample - - def do_unet_step( - self, - sample: torch.Tensor, - timestep: torch.Tensor, - conditioning_data: TextConditioningData, - ip_adapter_data: Optional[list[IPAdapterData]], - step_index: int, - total_step_count: int, - down_block_additional_residuals: Optional[torch.Tensor] = None, # for ControlNet - mid_block_additional_residual: Optional[torch.Tensor] = None, # for ControlNet - down_intrablock_additional_residuals: Optional[torch.Tensor] = None, # for T2I-Adapter - ): - if self.sequential_guidance: - ( - unconditioned_next_x, - conditioned_next_x, - ) = self._apply_standard_conditioning_sequentially( - x=sample, - sigma=timestep, - conditioning_data=conditioning_data, - ip_adapter_data=ip_adapter_data, - step_index=step_index, - total_step_count=total_step_count, - down_block_additional_residuals=down_block_additional_residuals, - mid_block_additional_residual=mid_block_additional_residual, - down_intrablock_additional_residuals=down_intrablock_additional_residuals, - ) - else: - ( - unconditioned_next_x, - conditioned_next_x, - ) = self._apply_standard_conditioning( - x=sample, - sigma=timestep, - conditioning_data=conditioning_data, - ip_adapter_data=ip_adapter_data, - step_index=step_index, - total_step_count=total_step_count, - down_block_additional_residuals=down_block_additional_residuals, - mid_block_additional_residual=mid_block_additional_residual, - down_intrablock_additional_residuals=down_intrablock_additional_residuals, - ) - - return unconditioned_next_x, conditioned_next_x - - def _concat_conditionings_for_batch(self, unconditioning, conditioning): - def _pad_conditioning(cond, target_len, encoder_attention_mask): - conditioning_attention_mask = torch.ones( - (cond.shape[0], cond.shape[1]), device=cond.device, dtype=cond.dtype - ) - - if cond.shape[1] < max_len: - conditioning_attention_mask = torch.cat( - [ - conditioning_attention_mask, - torch.zeros((cond.shape[0], max_len - cond.shape[1]), device=cond.device, dtype=cond.dtype), - ], - dim=1, - ) - - cond = torch.cat( - [ - cond, - torch.zeros( - (cond.shape[0], max_len - cond.shape[1], cond.shape[2]), - device=cond.device, - dtype=cond.dtype, - ), - ], - dim=1, - ) - - if encoder_attention_mask is None: - encoder_attention_mask = conditioning_attention_mask - else: - encoder_attention_mask = torch.cat( - [ - encoder_attention_mask, - conditioning_attention_mask, - ] - ) - - return cond, encoder_attention_mask - - encoder_attention_mask = None - if unconditioning.shape[1] != conditioning.shape[1]: - max_len = max(unconditioning.shape[1], conditioning.shape[1]) - unconditioning, encoder_attention_mask = _pad_conditioning(unconditioning, max_len, encoder_attention_mask) - conditioning, encoder_attention_mask = _pad_conditioning(conditioning, max_len, encoder_attention_mask) - - return 
torch.cat([unconditioning, conditioning]), encoder_attention_mask - - # methods below are called from do_diffusion_step and should be considered private to this class. - - def _apply_standard_conditioning( - self, - x: torch.Tensor, - sigma: torch.Tensor, - conditioning_data: TextConditioningData, - ip_adapter_data: Optional[list[IPAdapterData]], - step_index: int, - total_step_count: int, - down_block_additional_residuals: Optional[torch.Tensor] = None, # for ControlNet - mid_block_additional_residual: Optional[torch.Tensor] = None, # for ControlNet - down_intrablock_additional_residuals: Optional[torch.Tensor] = None, # for T2I-Adapter - ) -> tuple[torch.Tensor, torch.Tensor]: - """Runs the conditioned and unconditioned UNet forward passes in a single batch for faster inference speed at - the cost of higher memory usage. - """ - x_twice = torch.cat([x] * 2) - sigma_twice = torch.cat([sigma] * 2) - - cross_attention_kwargs = {} - if ip_adapter_data is not None: - ip_adapter_conditioning = [ipa.ip_adapter_conditioning for ipa in ip_adapter_data] - # Note that we 'stack' to produce tensors of shape (batch_size, num_ip_images, seq_len, token_len). - image_prompt_embeds = [ - torch.stack([ipa_conditioning.uncond_image_prompt_embeds, ipa_conditioning.cond_image_prompt_embeds]) - for ipa_conditioning in ip_adapter_conditioning - ] - scales = [ipa.scale_for_step(step_index, total_step_count) for ipa in ip_adapter_data] - ip_masks = [ipa.mask for ipa in ip_adapter_data] - regional_ip_data = RegionalIPData( - image_prompt_embeds=image_prompt_embeds, scales=scales, masks=ip_masks, dtype=x.dtype, device=x.device - ) - cross_attention_kwargs["regional_ip_data"] = regional_ip_data - - added_cond_kwargs = None - if conditioning_data.is_sdxl(): - added_cond_kwargs = { - "text_embeds": torch.cat( - [ - # TODO: how to pad? just by zeros? or even truncate? - conditioning_data.uncond_text.pooled_embeds, - conditioning_data.cond_text.pooled_embeds, - ], - dim=0, - ), - "time_ids": torch.cat( - [ - conditioning_data.uncond_text.add_time_ids, - conditioning_data.cond_text.add_time_ids, - ], - dim=0, - ), - } - - if conditioning_data.cond_regions is not None or conditioning_data.uncond_regions is not None: - # TODO(ryand): We currently initialize RegionalPromptData for every denoising step. The text conditionings - # and masks are not changing from step-to-step, so this really only needs to be done once. While this seems - # painfully inefficient, the time spent is typically negligible compared to the forward inference pass of - # the UNet. The main reason that this hasn't been moved up to eliminate redundancy is that it is slightly - # awkward to handle both standard conditioning and sequential conditioning further up the stack. - regions = [] - for c, r in [ - (conditioning_data.uncond_text, conditioning_data.uncond_regions), - (conditioning_data.cond_text, conditioning_data.cond_regions), - ]: - if r is None: - # Create a dummy mask and range for text conditioning that doesn't have region masks. 
- _, _, h, w = x.shape - r = TextConditioningRegions( - masks=torch.ones((1, 1, h, w), dtype=x.dtype), - ranges=[Range(start=0, end=c.embeds.shape[1])], - ) - regions.append(r) - - cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData( - regions=regions, device=x.device, dtype=x.dtype - ) - cross_attention_kwargs["percent_through"] = step_index / total_step_count - - both_conditionings, encoder_attention_mask = self._concat_conditionings_for_batch( - conditioning_data.uncond_text.embeds, conditioning_data.cond_text.embeds - ) - both_results = self.model_forward_callback( - x_twice, - sigma_twice, - both_conditionings, - cross_attention_kwargs=cross_attention_kwargs, - encoder_attention_mask=encoder_attention_mask, - down_block_additional_residuals=down_block_additional_residuals, - mid_block_additional_residual=mid_block_additional_residual, - down_intrablock_additional_residuals=down_intrablock_additional_residuals, - added_cond_kwargs=added_cond_kwargs, - ) - unconditioned_next_x, conditioned_next_x = both_results.chunk(2) - return unconditioned_next_x, conditioned_next_x - - def _apply_standard_conditioning_sequentially( - self, - x: torch.Tensor, - sigma, - conditioning_data: TextConditioningData, - ip_adapter_data: Optional[list[IPAdapterData]], - step_index: int, - total_step_count: int, - down_block_additional_residuals: Optional[torch.Tensor] = None, # for ControlNet - mid_block_additional_residual: Optional[torch.Tensor] = None, # for ControlNet - down_intrablock_additional_residuals: Optional[torch.Tensor] = None, # for T2I-Adapter - ): - """Runs the conditioned and unconditioned UNet forward passes sequentially for lower memory usage at the cost of - slower execution speed. - """ - # Since we are running the conditioned and unconditioned passes sequentially, we need to split the ControlNet - # and T2I-Adapter residuals into two chunks. - uncond_down_block, cond_down_block = None, None - if down_block_additional_residuals is not None: - uncond_down_block, cond_down_block = [], [] - for down_block in down_block_additional_residuals: - _uncond_down, _cond_down = down_block.chunk(2) - uncond_down_block.append(_uncond_down) - cond_down_block.append(_cond_down) - - uncond_down_intrablock, cond_down_intrablock = None, None - if down_intrablock_additional_residuals is not None: - uncond_down_intrablock, cond_down_intrablock = [], [] - for down_intrablock in down_intrablock_additional_residuals: - _uncond_down, _cond_down = down_intrablock.chunk(2) - uncond_down_intrablock.append(_uncond_down) - cond_down_intrablock.append(_cond_down) - - uncond_mid_block, cond_mid_block = None, None - if mid_block_additional_residual is not None: - uncond_mid_block, cond_mid_block = mid_block_additional_residual.chunk(2) - - ##################### - # Unconditioned pass - ##################### - - cross_attention_kwargs = {} - - # Prepare IP-Adapter cross-attention kwargs for the unconditioned pass. - if ip_adapter_data is not None: - ip_adapter_conditioning = [ipa.ip_adapter_conditioning for ipa in ip_adapter_data] - # Note that we 'unsqueeze' to produce tensors of shape (batch_size=1, num_ip_images, seq_len, token_len). 
- image_prompt_embeds = [ - torch.unsqueeze(ipa_conditioning.uncond_image_prompt_embeds, dim=0) - for ipa_conditioning in ip_adapter_conditioning - ] - - scales = [ipa.scale_for_step(step_index, total_step_count) for ipa in ip_adapter_data] - ip_masks = [ipa.mask for ipa in ip_adapter_data] - regional_ip_data = RegionalIPData( - image_prompt_embeds=image_prompt_embeds, scales=scales, masks=ip_masks, dtype=x.dtype, device=x.device - ) - cross_attention_kwargs["regional_ip_data"] = regional_ip_data - - # Prepare SDXL conditioning kwargs for the unconditioned pass. - added_cond_kwargs = None - if conditioning_data.is_sdxl(): - added_cond_kwargs = { - "text_embeds": conditioning_data.uncond_text.pooled_embeds, - "time_ids": conditioning_data.uncond_text.add_time_ids, - } - - # Prepare prompt regions for the unconditioned pass. - if conditioning_data.uncond_regions is not None: - cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData( - regions=[conditioning_data.uncond_regions], device=x.device, dtype=x.dtype - ) - cross_attention_kwargs["percent_through"] = step_index / total_step_count - - # Run unconditioned UNet denoising (i.e. negative prompt). - unconditioned_next_x = self.model_forward_callback( - x, - sigma, - conditioning_data.uncond_text.embeds, - cross_attention_kwargs=cross_attention_kwargs, - down_block_additional_residuals=uncond_down_block, - mid_block_additional_residual=uncond_mid_block, - down_intrablock_additional_residuals=uncond_down_intrablock, - added_cond_kwargs=added_cond_kwargs, - ) - - ################### - # Conditioned pass - ################### - - cross_attention_kwargs = {} - - if ip_adapter_data is not None: - ip_adapter_conditioning = [ipa.ip_adapter_conditioning for ipa in ip_adapter_data] - # Note that we 'unsqueeze' to produce tensors of shape (batch_size=1, num_ip_images, seq_len, token_len). - image_prompt_embeds = [ - torch.unsqueeze(ipa_conditioning.cond_image_prompt_embeds, dim=0) - for ipa_conditioning in ip_adapter_conditioning - ] - - scales = [ipa.scale_for_step(step_index, total_step_count) for ipa in ip_adapter_data] - ip_masks = [ipa.mask for ipa in ip_adapter_data] - regional_ip_data = RegionalIPData( - image_prompt_embeds=image_prompt_embeds, scales=scales, masks=ip_masks, dtype=x.dtype, device=x.device - ) - cross_attention_kwargs["regional_ip_data"] = regional_ip_data - - # Prepare SDXL conditioning kwargs for the conditioned pass. - added_cond_kwargs = None - if conditioning_data.is_sdxl(): - added_cond_kwargs = { - "text_embeds": conditioning_data.cond_text.pooled_embeds, - "time_ids": conditioning_data.cond_text.add_time_ids, - } - - # Prepare prompt regions for the conditioned pass. - if conditioning_data.cond_regions is not None: - cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData( - regions=[conditioning_data.cond_regions], device=x.device, dtype=x.dtype - ) - cross_attention_kwargs["percent_through"] = step_index / total_step_count - - # Run conditioned UNet denoising (i.e. positive prompt). 
- conditioned_next_x = self.model_forward_callback( - x, - sigma, - conditioning_data.cond_text.embeds, - cross_attention_kwargs=cross_attention_kwargs, - down_block_additional_residuals=cond_down_block, - mid_block_additional_residual=cond_mid_block, - down_intrablock_additional_residuals=cond_down_intrablock, - added_cond_kwargs=added_cond_kwargs, - ) - return unconditioned_next_x, conditioned_next_x - - def _combine(self, unconditioned_next_x, conditioned_next_x, guidance_scale): - # to scale how much effect conditioning has, calculate the changes it does and then scale that - scaled_delta = (conditioned_next_x - unconditioned_next_x) * guidance_scale - combined_next_x = unconditioned_next_x + scaled_delta - return combined_next_x diff --git a/invokeai/backend/stable_diffusion/diffusion/unet_attention_patcher.py b/invokeai/backend/stable_diffusion/diffusion/unet_attention_patcher.py deleted file mode 100644 index ac00a8e06ea..00000000000 --- a/invokeai/backend/stable_diffusion/diffusion/unet_attention_patcher.py +++ /dev/null @@ -1,68 +0,0 @@ -from contextlib import contextmanager -from typing import List, Optional, TypedDict - -from diffusers.models import UNet2DConditionModel - -from invokeai.backend.ip_adapter.ip_adapter import IPAdapter -from invokeai.backend.stable_diffusion.diffusion.custom_atttention import ( - CustomAttnProcessor2_0, - IPAdapterAttentionWeights, -) - - -class UNetIPAdapterData(TypedDict): - ip_adapter: IPAdapter - target_blocks: List[str] - - -class UNetAttentionPatcher: - """A class for patching a UNet with CustomAttnProcessor2_0 attention layers.""" - - def __init__(self, ip_adapter_data: Optional[List[UNetIPAdapterData]]): - self._ip_adapters = ip_adapter_data - - def _prepare_attention_processors(self, unet: UNet2DConditionModel): - """Prepare a dict of attention processors that can be injected into a unet, and load the IP-Adapter attention - weights into them (if IP-Adapters are being applied). - Note that the `unet` param is only used to determine attention block dimensions and naming. - """ - # Construct a dict of attention processors based on the UNet's architecture. - attn_procs = {} - for idx, name in enumerate(unet.attn_processors.keys()): - if name.endswith("attn1.processor") or self._ip_adapters is None: - # "attn1" processors do not use IP-Adapters. - attn_procs[name] = CustomAttnProcessor2_0() - else: - # Collect the weights from each IP Adapter for the idx'th attention processor. - ip_adapter_attention_weights_collection: list[IPAdapterAttentionWeights] = [] - - for ip_adapter in self._ip_adapters: - ip_adapter_weights = ip_adapter["ip_adapter"].attn_weights.get_attention_processor_weights(idx) - skip = True - for block in ip_adapter["target_blocks"]: - if block in name: - skip = False - break - ip_adapter_attention_weights: IPAdapterAttentionWeights = IPAdapterAttentionWeights( - ip_adapter_weights=ip_adapter_weights, skip=skip - ) - ip_adapter_attention_weights_collection.append(ip_adapter_attention_weights) - - attn_procs[name] = CustomAttnProcessor2_0(ip_adapter_attention_weights_collection) - - return attn_procs - - @contextmanager - def apply_ip_adapter_attention(self, unet: UNet2DConditionModel): - """A context manager that patches `unet` with CustomAttnProcessor2_0 attention layers.""" - attn_procs = self._prepare_attention_processors(unet) - orig_attn_processors = unet.attn_processors - - try: - # Note to future devs: set_attn_processor(...) does something slightly unexpected - it pops elements from - # the passed dict. 
So, if you wanted to keep the dict for future use, you'd have to make a - # moderately-shallow copy of it. E.g. `attn_procs_copy = {k: v for k, v in attn_procs.items()}`. - unet.set_attn_processor(attn_procs) - yield None - finally: - unet.set_attn_processor(orig_attn_processors) diff --git a/invokeai/backend/stable_diffusion/extensions/__init__.py b/invokeai/backend/stable_diffusion/extensions/__init__.py new file mode 100644 index 00000000000..5812b2874e3 --- /dev/null +++ b/invokeai/backend/stable_diffusion/extensions/__init__.py @@ -0,0 +1,30 @@ +""" +Initialization file for the invokeai.backend.stable_diffusion.extensions package +""" + +from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase +from invokeai.backend.stable_diffusion.extensions.controlnet import ControlNetExt +from invokeai.backend.stable_diffusion.extensions.freeu import FreeUExt +from invokeai.backend.stable_diffusion.extensions.inpaint import InpaintExt +from invokeai.backend.stable_diffusion.extensions.ip_adapter import IPAdapterExt +from invokeai.backend.stable_diffusion.extensions.lora_patcher import LoRAPatcherExt +from invokeai.backend.stable_diffusion.extensions.preview import PipelineIntermediateState, PreviewExt +from invokeai.backend.stable_diffusion.extensions.rescale import RescaleCFGExt +from invokeai.backend.stable_diffusion.extensions.seamless import SeamlessExt +from invokeai.backend.stable_diffusion.extensions.t2i_adapter import T2IAdapterExt +from invokeai.backend.stable_diffusion.extensions.tiled_denoise import TiledDenoiseExt + +__all__ = [ + "PipelineIntermediateState", + "ExtensionBase", + "InpaintExt", + "PreviewExt", + "RescaleCFGExt", + "T2IAdapterExt", + "ControlNetExt", + "IPAdapterExt", + "TiledDenoiseExt", + "SeamlessExt", + "FreeUExt", + "LoRAPatcherExt", +] diff --git a/invokeai/backend/stable_diffusion/extensions/base.py b/invokeai/backend/stable_diffusion/extensions/base.py new file mode 100644 index 00000000000..d3414eea6f0 --- /dev/null +++ b/invokeai/backend/stable_diffusion/extensions/base.py @@ -0,0 +1,58 @@ +from contextlib import contextmanager +from dataclasses import dataclass +from typing import Callable, Dict, List, Optional + +import torch +from diffusers import UNet2DConditionModel + + +@dataclass +class InjectionInfo: + type: str + name: str + order: Optional[str] + function: Callable + + +def modifier(name: str, order: str = "any"): + def _decorator(func): + func.__inj_info__ = { + "type": "modifier", + "name": name, + "order": order, + } + return func + + return _decorator + + +def override(name: str): + def _decorator(func): + func.__inj_info__ = { + "type": "override", + "name": name, + "order": None, + } + return func + + return _decorator + + +class ExtensionBase: + def __init__(self, priority: int): + self.priority = priority + self.injections: List[InjectionInfo] = [] + for func_name in dir(self): + func = getattr(self, func_name) + if not callable(func) or not hasattr(func, "__inj_info__"): + continue + + self.injections.append(InjectionInfo(**func.__inj_info__, function=func)) + + @contextmanager + def patch_attention_processor(self, attention_processor_cls: object): + yield None + + @contextmanager + def patch_unet(self, state_dict: Dict[str, torch.Tensor], unet: UNet2DConditionModel): + yield None diff --git a/invokeai/backend/stable_diffusion/extensions/controlnet.py b/invokeai/backend/stable_diffusion/extensions/controlnet.py new file mode 100644 index 00000000000..ee6fd100dc0 --- /dev/null +++ 
b/invokeai/backend/stable_diffusion/extensions/controlnet.py @@ -0,0 +1,163 @@ +from __future__ import annotations + +import math +from contextlib import contextmanager +from typing import TYPE_CHECKING, List, Optional, Union + +import torch +from PIL.Image import Image + +from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR +from invokeai.app.util.controlnet_utils import prepare_control_image +from invokeai.backend.stable_diffusion.denoise_context import UNetKwargs +from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, modifier + +if TYPE_CHECKING: + from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext + from invokeai.backend.util.hotfixes import ControlNetModel + + +class ControlNetExt(ExtensionBase): + def __init__( + self, + model: ControlNetModel, + image: Image, + weight: Union[float, List[float]], + begin_step_percent: float, + end_step_percent: float, + control_mode: str, + resize_mode: str, + priority: int, + ): + super().__init__(priority=priority) + self.model = model + self.image = image + self.weight = weight + self.begin_step_percent = begin_step_percent + self.end_step_percent = end_step_percent + self.control_mode = control_mode + self.resize_mode = resize_mode + + self.image_tensor: Optional[torch.Tensor] = None + + @contextmanager + def patch_attention_processor(self, attention_processor_cls): + try: + original_processors = self.model.attn_processors + self.model.set_attn_processor(attention_processor_cls()) + + yield None + finally: + self.model.set_attn_processor(original_processors) + + @modifier("pre_denoise_loop") + def resize_image(self, ctx: DenoiseContext): + _, _, latent_height, latent_width = ctx.latents.shape + image_height = latent_height * LATENT_SCALE_FACTOR + image_width = latent_width * LATENT_SCALE_FACTOR + + self.image_tensor = prepare_control_image( + image=self.image, + do_classifier_free_guidance=True, + width=image_width, + height=image_height, + # batch_size=batch_size * num_images_per_prompt, + # num_images_per_prompt=num_images_per_prompt, + device=ctx.latents.device, + dtype=ctx.latents.dtype, + control_mode=self.control_mode, + resize_mode=self.resize_mode, + ) + + @modifier("pre_unet_forward") + def pre_unet_step(self, ctx: DenoiseContext): + # skip if model not active in current step + total_steps = len(ctx.timesteps) + first_step = math.floor(self.begin_step_percent * total_steps) + last_step = math.ceil(self.end_step_percent * total_steps) + if ctx.step_index < first_step or ctx.step_index > last_step: + return + + # convert mode to internal flags + soft_injection = self.control_mode in ["more_prompt", "more_control"] + cfg_injection = self.control_mode in ["more_control", "unbalanced"] + + # no negative conditioning in cfg_injection mode + if cfg_injection: + if ctx.conditioning_mode == "negative": + return + down_samples, mid_sample = self._run(ctx, soft_injection, "positive") + + if ctx.conditioning_mode == "both": + # add zeros as samples for negative conditioning + down_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_samples] + mid_sample = torch.cat([torch.zeros_like(mid_sample), mid_sample]) + + else: + down_samples, mid_sample = self._run(ctx, soft_injection, ctx.conditioning_mode) + + if ( + ctx.unet_kwargs.down_block_additional_residuals is None + and ctx.unet_kwargs.mid_block_additional_residual is None + ): + ctx.unet_kwargs.down_block_additional_residuals, ctx.unet_kwargs.mid_block_additional_residual = ( + down_samples, + mid_sample, + ) + else: + # add 
controlnet outputs together if there are multiple controlnets
+                ctx.unet_kwargs.down_block_additional_residuals = [
+                    samples_prev + samples_curr
+                    for samples_prev, samples_curr in zip(
+                        ctx.unet_kwargs.down_block_additional_residuals, down_samples, strict=True
+                    )
+                ]
+                ctx.unet_kwargs.mid_block_additional_residual += mid_sample
+
+    def _run(self, ctx: DenoiseContext, soft_injection, conditioning_mode):
+        total_steps = len(ctx.timesteps)
+        model_input = ctx.latent_model_input
+        if conditioning_mode == "both":
+            model_input = torch.cat([model_input] * 2)
+
+        cn_unet_kwargs = UNetKwargs(
+            sample=model_input,
+            timestep=ctx.timestep,
+            encoder_hidden_states=None,  # set later by conditioning
+            cross_attention_kwargs=dict(  # noqa: C408
+                percent_through=ctx.step_index / total_steps,
+            ),
+        )
+
+        ctx.conditioning_data.to_unet_kwargs(cn_unet_kwargs, conditioning_mode=conditioning_mode)
+
+        # get static weight, or weight corresponding to current step
+        weight = self.weight
+        if isinstance(weight, list):
+            weight = weight[ctx.step_index]
+
+        tmp_kwargs = vars(cn_unet_kwargs)
+        tmp_kwargs.pop("down_block_additional_residuals", None)
+        tmp_kwargs.pop("mid_block_additional_residual", None)
+        tmp_kwargs.pop("down_intrablock_additional_residuals", None)
+
+        image_tensor = self.image_tensor
+        tile_coords = ctx.extra.get("tile_coords", None)
+        if tile_coords is not None:
+            image_tensor = image_tensor[
+                :,
+                :,
+                tile_coords.top * LATENT_SCALE_FACTOR : tile_coords.bottom * LATENT_SCALE_FACTOR,
+                tile_coords.left * LATENT_SCALE_FACTOR : tile_coords.right * LATENT_SCALE_FACTOR,
+            ]
+
+        # controlnet(s) inference
+        down_samples, mid_sample = self.model(
+            controlnet_cond=image_tensor,
+            conditioning_scale=weight,  # controlnet specific, NOT the guidance scale
+            guess_mode=soft_injection,  # this is still called guess_mode in diffusers ControlNetModel
+            return_dict=False,
+            **vars(cn_unet_kwargs),
+        )
+
+        return down_samples, mid_sample
diff --git a/invokeai/backend/stable_diffusion/extensions/freeu.py b/invokeai/backend/stable_diffusion/extensions/freeu.py
new file mode 100644
index 00000000000..f57692dc574
--- /dev/null
+++ b/invokeai/backend/stable_diffusion/extensions/freeu.py
@@ -0,0 +1,42 @@
+from __future__ import annotations
+
+from contextlib import contextmanager
+from typing import TYPE_CHECKING, Optional
+
+from diffusers import UNet2DConditionModel
+
+from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase
+
+if TYPE_CHECKING:
+    from invokeai.app.shared.models import FreeUConfig
+
+
+class FreeUExt(ExtensionBase):
+    def __init__(
+        self,
+        freeu_config: Optional[FreeUConfig],
+        priority: int,
+    ):
+        super().__init__(priority=priority)
+        self.freeu_config = freeu_config
+
+    @contextmanager
+    def patch_unet(self, state_dict: dict, unet: UNet2DConditionModel):
+        did_apply_freeu = False
+        try:
+            assert hasattr(unet, "enable_freeu")  # mypy doesn't pick up this attribute?
+            if self.freeu_config is not None:
+                unet.enable_freeu(
+                    b1=self.freeu_config.b1,
+                    b2=self.freeu_config.b2,
+                    s1=self.freeu_config.s1,
+                    s2=self.freeu_config.s2,
+                )
+                did_apply_freeu = True
+
+            yield
+
+        finally:
+            assert hasattr(unet, "disable_freeu")  # mypy doesn't pick up this attribute?
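+            # Only call disable_freeu() if FreeU was actually enabled above; otherwise leave the UNet untouched.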
+ if did_apply_freeu: + unet.disable_freeu() diff --git a/invokeai/backend/stable_diffusion/extensions/inpaint.py b/invokeai/backend/stable_diffusion/extensions/inpaint.py new file mode 100644 index 00000000000..e385107c979 --- /dev/null +++ b/invokeai/backend/stable_diffusion/extensions/inpaint.py @@ -0,0 +1,122 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Optional + +import einops +import torch +from diffusers import UNet2DConditionModel + +from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, modifier + +if TYPE_CHECKING: + from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext + + +class InpaintExt(ExtensionBase): + def __init__( + self, + mask: Optional[torch.Tensor], + masked_latents: Optional[torch.Tensor], + is_gradient_mask: bool, + priority: int, + ): + super().__init__(priority=priority) + self.mask = mask + self.masked_latents = masked_latents + self.is_gradient_mask = is_gradient_mask + self.noise = None + + def _is_inpaint_model(self, unet: UNet2DConditionModel): + return unet.conv_in.in_channels == 9 + + def _apply_mask(self, ctx: DenoiseContext, latents: torch.Tensor, t: torch.Tensor) -> torch.Tensor: + batch_size = latents.size(0) + mask = einops.repeat(self.mask, "b c h w -> (repeat b) c h w", repeat=batch_size) + if t.dim() == 0: + # some schedulers expect t to be one-dimensional. + # TODO: file diffusers bug about inconsistency? + t = einops.repeat(t, "-> batch", batch=batch_size) + # Noise shouldn't be re-randomized between steps here. The multistep schedulers + # get very confused about what is happening from step to step when we do that. + mask_latents = ctx.scheduler.add_noise(ctx.orig_latents, self.noise, t) + # TODO: Do we need to also apply scheduler.scale_model_input? Or is add_noise appropriately scaled already? 
+        # mask_latents = self.scheduler.scale_model_input(mask_latents, t)
+        mask_latents = einops.repeat(mask_latents, "b c h w -> (repeat b) c h w", repeat=batch_size)
+        if self.is_gradient_mask:
+            threshold = (t.item()) / ctx.scheduler.config.num_train_timesteps
+            mask_bool = mask > threshold  # I don't know when mask got inverted, but it did
+            masked_input = torch.where(mask_bool, latents, mask_latents)
+        else:
+            masked_input = torch.lerp(mask_latents.to(dtype=latents.dtype), latents, mask.to(dtype=latents.dtype))
+        return masked_input
+
+    @modifier("pre_denoise_loop")
+    def init_tensors(self, ctx: DenoiseContext):
+        if self._is_inpaint_model(ctx.unet):
+            if self.mask is None:
+                self.mask = torch.ones_like(ctx.latents[:1, :1])
+            self.mask = self.mask.to(device=ctx.latents.device, dtype=ctx.latents.dtype)
+
+            if self.masked_latents is None:
+                self.masked_latents = torch.zeros_like(ctx.latents[:1])
+            self.masked_latents = self.masked_latents.to(device=ctx.latents.device, dtype=ctx.latents.dtype)
+
+        else:
+            self.noise = ctx.noise
+            if self.noise is None:
+                self.noise = torch.randn(
+                    ctx.orig_latents.shape,
+                    dtype=torch.float32,
+                    device="cpu",
+                    generator=torch.Generator(device="cpu").manual_seed(ctx.seed),
+                ).to(device=ctx.orig_latents.device, dtype=ctx.orig_latents.dtype)
+
+    # run first so that other pre_step extensions see the latents after the mask has been applied
+    @modifier("pre_step", order="first")
+    def apply_mask_to_latents(self, ctx: DenoiseContext):
+        if self._is_inpaint_model(ctx.unet) or self.mask is None:
+            return
+        ctx.latents = self._apply_mask(ctx, ctx.latents, ctx.timestep)
+
+    # run last so that other pre_unet_forward extensions see the sample before the inpaint channels are appended
+    @modifier("pre_unet_forward", order="last")
+    def append_inpaint_layers(self, ctx: DenoiseContext):
+        if not self._is_inpaint_model(ctx.unet):
+            return
+
+        batch_size = ctx.unet_kwargs.sample.shape[0]
+        b_mask = torch.cat([self.mask] * batch_size)
+        b_masked_latents = torch.cat([self.masked_latents] * batch_size)
+        ctx.unet_kwargs.sample = torch.cat(
+            [ctx.unet_kwargs.sample, b_mask, b_masked_latents],
+            dim=1,
+        )
+
+    @modifier("post_step", order="first")
+    def apply_mask_to_preview(self, ctx: DenoiseContext):
+        if self._is_inpaint_model(ctx.unet) or self.mask is None:
+            return
+
+        timestep = ctx.scheduler.timesteps[-1]
+        if hasattr(ctx.step_output, "denoised"):
+            ctx.step_output.denoised = self._apply_mask(ctx, ctx.step_output.denoised, timestep)
+        elif hasattr(ctx.step_output, "pred_original_sample"):
+            ctx.step_output.pred_original_sample = self._apply_mask(ctx, ctx.step_output.pred_original_sample, timestep)
+        else:
+            ctx.step_output.pred_original_sample = self._apply_mask(ctx, ctx.step_output.prev_sample, timestep)
+
+    @modifier("post_denoise_loop")  # last?
+ def restore_unmasked(self, ctx: DenoiseContext): + if self.mask is None: + return + + # restore unmasked part after the last step is completed + # in-process masking happens before each step + if self.is_gradient_mask: + ctx.latents = torch.where(self.mask > 0, ctx.latents, ctx.orig_latents) + else: + ctx.latents = torch.lerp( + ctx.orig_latents, + ctx.latents.to(dtype=ctx.orig_latents.dtype), + self.mask.to(dtype=ctx.orig_latents.dtype), + ) diff --git a/invokeai/backend/stable_diffusion/extensions/ip_adapter.py b/invokeai/backend/stable_diffusion/extensions/ip_adapter.py new file mode 100644 index 00000000000..c6324de4e6a --- /dev/null +++ b/invokeai/backend/stable_diffusion/extensions/ip_adapter.py @@ -0,0 +1,217 @@ +from __future__ import annotations + +import math +from contextlib import ExitStack, contextmanager +from typing import TYPE_CHECKING, Dict, List, Optional, Union + +import torch +import torchvision +from diffusers import UNet2DConditionModel +from PIL.Image import Image +from transformers import CLIPVisionModelWithProjection + +from invokeai.backend.ip_adapter.ip_adapter import IPAdapter +from invokeai.backend.stable_diffusion.diffusion.conditioning_data import IPAdapterConditioningInfo +from invokeai.backend.stable_diffusion.diffusion.custom_atttention import ( + CustomAttnProcessor2_0, + IPAdapterAttentionWeights, +) +from invokeai.backend.stable_diffusion.diffusion.regional_ip_data import RegionalIPData +from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, modifier +from invokeai.backend.util.mask import to_standard_float_mask + +if TYPE_CHECKING: + from invokeai.app.invocations.model import ModelIdentifierField + from invokeai.app.services.shared.invocation_context import InvocationContext + from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext + from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager + + +class IPAdapterExt(ExtensionBase): + def __init__( + self, + node_context: InvocationContext, + exit_stack: ExitStack, + model_id: ModelIdentifierField, + image_encoder_model_id: ModelIdentifierField, + images: List[Image], + mask: torch.Tensor, + target_blocks: List[str], + weight: Union[float, List[float]], + begin_step_percent: float, + end_step_percent: float, + priority: int, + ): + super().__init__(priority=priority) + self.node_context = node_context + self.exit_stack = exit_stack + self.model_id = model_id + self.image_encoder_model_id = image_encoder_model_id + self.images = images + self.mask = mask + self.target_blocks = target_blocks + self.weight = weight + self.begin_step_percent = begin_step_percent + self.end_step_percent = end_step_percent + + self.model: Optional[IPAdapter] = None + self.conditioning: Optional[IPAdapterConditioningInfo] = None + + @contextmanager + def patch_unet(self, state_dict: Dict[str, torch.Tensor], unet: UNet2DConditionModel): + try: + for idx, name in enumerate(unet.attn_processors.keys()): + if name.endswith("attn1.processor"): + continue + + ip_adapter_weights = self.model.attn_weights.get_attention_processor_weights(idx) + skip = True + for block in self.target_blocks: + if block in name: + skip = False + break + + assert isinstance(unet.attn_processors[name], CustomAttnProcessor2_0) + unet.attn_processors[name].add_ip_adapter( + IPAdapterAttentionWeights( + ip_adapter_weights=ip_adapter_weights, + skip=skip, + ) + ) + + yield None + + finally: + # nop, as it unpatched with attention processor + pass + + @modifier("pre_unet_load") + def 
preprocess_images(self, ctx: DenoiseContext, ext_manager: ExtensionsManager): + """Run the IPAdapter CLIPVisionModel, returning image prompt embeddings.""" + + # HACK: save a bit of memory by not loading ip attention weights on image processing + # and by loading only attention weight on denoising + if True: + with self.node_context.models.load(self.image_encoder_model_id) as image_encoder_model: + self.model = self.node_context.models.load(self.model_id).model + assert isinstance(self.model, IPAdapter) + assert isinstance(image_encoder_model, CLIPVisionModelWithProjection) + + st_device = self.model.device + st_dtype = self.model.dtype + self.model.device = image_encoder_model.device + self.model.dtype = image_encoder_model.dtype + + def _move_ip_adapter_to_storage_device(model): + model.device = st_device + model.dtype = st_dtype + model._image_proj_model.to(device=st_device, dtype=st_dtype) + model.attn_weights.to(device=st_device, dtype=st_dtype) + + # Get image embeddings from CLIP(image_encoder_model) and ImageProjModel(_image_proj_model). + try: + self.model._image_proj_model.to(device=image_encoder_model.device, dtype=image_encoder_model.dtype) + positive_img_prompt_embeds, negative_img_prompt_embeds = self.model.get_image_embeds( + self.images, image_encoder_model + ) + self.model._image_proj_model.to(device=st_device, dtype=st_dtype) + except: + _move_ip_adapter_to_storage_device(self.model) + raise + + # load attn weights to device + self.model.attn_weights.to(device=ctx.latents.device, dtype=ctx.latents.dtype) + # move back to storage device on __exit__ + self.exit_stack.callback(_move_ip_adapter_to_storage_device, self.model) + + else: + self.model = self.exit_stack.enter_context(self.node_context.models.load(self.model_id)) + with self.node_context.models.load(self.image_encoder_model_id) as image_encoder_model: + assert isinstance(self.model, IPAdapter) + assert isinstance(image_encoder_model, CLIPVisionModelWithProjection) + # Get image embeddings from CLIP and ImageProjModel. + positive_img_prompt_embeds, negative_img_prompt_embeds = self.model.get_image_embeds( + self.images, image_encoder_model + ) + + self.conditioning = IPAdapterConditioningInfo(positive_img_prompt_embeds, negative_img_prompt_embeds) + + _, _, latent_height, latent_width = ctx.latents.shape + self.mask = self._preprocess_regional_prompt_mask( + self.mask, latent_height, latent_width, dtype=ctx.latents.dtype + ) + + @staticmethod + def _preprocess_regional_prompt_mask( + mask: Optional[torch.Tensor], target_height: int, target_width: int, dtype: torch.dtype + ) -> torch.Tensor: + """Preprocess a regional prompt mask to match the target height and width. + If mask is None, returns a mask of all ones with the target height and width. + If mask is not None, resizes the mask to the target height and width using 'nearest' interpolation. + + Returns: + torch.Tensor: The processed mask. shape: (1, 1, target_height, target_width). + """ + + if mask is None: + return torch.ones((1, 1, target_height, target_width), dtype=dtype) + + mask = to_standard_float_mask(mask, out_dtype=dtype) + + tf = torchvision.transforms.Resize( + (target_height, target_width), interpolation=torchvision.transforms.InterpolationMode.NEAREST + ) + + # Add a batch dimension to the mask, because torchvision expects shape (batch, channels, h, w). 
+ mask = mask.unsqueeze(0) # Shape: (1, h, w) -> (1, 1, h, w) + resized_mask = tf(mask) + return resized_mask + + @modifier("pre_unet_forward") + def pre_unet_step(self, ctx: DenoiseContext): + # skip if model not active in current step + total_steps = len(ctx.timesteps) + first_step = math.floor(self.begin_step_percent * total_steps) + last_step = math.ceil(self.end_step_percent * total_steps) + if ctx.step_index < first_step or ctx.step_index > last_step: + return + + weight = self.weight + if isinstance(weight, List): + weight = weight[ctx.step_index] + + if ctx.conditioning_mode == "both": + embeds = torch.stack( + [self.conditioning.uncond_image_prompt_embeds, self.conditioning.cond_image_prompt_embeds] + ) + elif ctx.conditioning_mode == "negative": + embeds = torch.stack([self.conditioning.uncond_image_prompt_embeds]) + else: # elif ctx.conditioning_mode == "positive": + embeds = torch.stack([self.conditioning.cond_image_prompt_embeds]) + + if ctx.unet_kwargs.cross_attention_kwargs is None: + ctx.unet_kwargs.cross_attention_kwargs = {} + + regional_ip_data = ctx.unet_kwargs.cross_attention_kwargs.get("regional_ip_data", None) + if regional_ip_data is None: + regional_ip_data = RegionalIPData( + image_prompt_embeds=[], + scales=[], + masks=[], + dtype=ctx.latent_model_input.dtype, + device=ctx.latent_model_input.device, + ) + ctx.unet_kwargs.cross_attention_kwargs.update( + regional_ip_data=regional_ip_data, + ) + + mask = self.mask + tile_coords = ctx.extra.get("tile_coords", None) + if tile_coords is not None: + mask = mask[:, :, tile_coords.top : tile_coords.bottom, tile_coords.left : tile_coords.right] + + regional_ip_data.add( + embeds=embeds, + scale=weight, + mask=mask, + ) diff --git a/invokeai/backend/stable_diffusion/extensions/lora_patcher.py b/invokeai/backend/stable_diffusion/extensions/lora_patcher.py new file mode 100644 index 00000000000..e8d6cbe8ea0 --- /dev/null +++ b/invokeai/backend/stable_diffusion/extensions/lora_patcher.py @@ -0,0 +1,186 @@ +from __future__ import annotations + +from contextlib import contextmanager +from typing import TYPE_CHECKING, Dict, Iterator, List, Optional, Tuple + +import torch +from diffusers import UNet2DConditionModel + +from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase +from invokeai.backend.util.devices import TorchDevice + +if TYPE_CHECKING: + from invokeai.app.invocations.model import LoRAField + from invokeai.app.services.shared.invocation_context import InvocationContext + from invokeai.backend.lora import LoRAModelRaw # TODO: circular import + + +class LoRAPatcherExt(ExtensionBase): + def __init__( + self, + node_context: InvocationContext, + loras: List[LoRAField], + prefix: str, + priority: int, + ): + super().__init__(priority=priority) + self.loras = loras + self.prefix = prefix + self.node_context = node_context + + @contextmanager + def patch_unet(self, model_state_dict: Dict[str, torch.Tensor], model: UNet2DConditionModel): + def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]: + for lora in self.loras: + lora_info = self.node_context.models.load(lora.lora) + lora_model = lora_info.model + from invokeai.backend.lora import LoRAModelRaw + + assert isinstance(lora_model, LoRAModelRaw) + yield (lora_model, lora.weight) + del lora_info + return + + yield self._patch_model( + model=model, + prefix=self.prefix, + loras=_lora_loader(), + model_state_dict=model_state_dict, + ) + + @classmethod + @contextmanager + def static_patch_model( + cls, + model: torch.nn.Module, + prefix: str, + loras: 
Iterator[Tuple[LoRAModelRaw, float]],
+        model_state_dict: Optional[Dict[str, torch.Tensor]] = None,
+    ):
+        changed_keys = None
+        changed_unknown_keys = None
+        try:
+            changed_keys, changed_unknown_keys = cls._patch_model(
+                model=model,
+                prefix=prefix,
+                loras=loras,
+                model_state_dict=model_state_dict,
+            )
+
+            yield
+
+        finally:
+            assert hasattr(model, "get_submodule")  # mypy not picking up fact that torch.nn.Module has get_submodule()
+            with torch.no_grad():
+                if changed_keys:
+                    for module_key in changed_keys:
+                        weight = model_state_dict[module_key]
+                        model.get_submodule(module_key).weight.copy_(
+                            weight, non_blocking=TorchDevice.get_non_blocking(weight.device)
+                        )
+                if changed_unknown_keys:
+                    for module_key, weight in changed_unknown_keys.items():
+                        model.get_submodule(module_key).weight.copy_(
+                            weight, non_blocking=TorchDevice.get_non_blocking(weight.device)
+                        )
+
+    @classmethod
+    def _patch_model(
+        cls,
+        model: UNet2DConditionModel,
+        prefix: str,
+        loras: Iterator[Tuple[LoRAModelRaw, float]],
+        model_state_dict: Optional[Dict[str, torch.Tensor]] = None,
+    ):
+        """
+        Apply one or more LoRAs to a model.
+
+        :param model: The model to patch.
+        :param loras: An iterator that returns the LoRA to patch in and its patch weight.
+        :param prefix: A string prefix that precedes keys used in the LoRAs weight layers.
+        :param model_state_dict: Read-only CPU copy of the model's state dict, used for unpatching.
+        """
+        if model_state_dict is None:
+            model_state_dict = {}
+
+        changed_keys = set()
+        changed_unknown_keys = {}
+        with torch.no_grad():
+            for lora, lora_weight in loras:
+                # assert lora.device.type == "cpu"
+                for layer_key, layer in lora.layers.items():
+                    if not layer_key.startswith(prefix):
+                        continue
+
+                    # TODO(ryand): A non-negligible amount of time is currently spent resolving LoRA keys. This
+                    # should be improved in the following ways:
+                    # 1. The key mapping could be more-efficiently pre-computed. This would save time every time a
+                    #    LoRA model is applied.
+                    # 2. From an API perspective, there's no reason that the `ModelPatcher` should be aware of the
+                    #    intricacies of Stable Diffusion key resolution. It should just expect the input LoRA
+                    #    weights to have valid keys.
+                    assert isinstance(model, torch.nn.Module)
+                    module_key, module = cls._resolve_lora_key(model, layer_key, prefix)
+
+                    # All of the LoRA weight calculations will be done on the same device as the module weight.
+                    # (Performance will be best if this is a CUDA device.)
+                    device = module.weight.device
+                    dtype = module.weight.dtype
+
+                    if module_key not in changed_keys and module_key not in changed_unknown_keys:
+                        if module_key in model_state_dict:
+                            changed_keys.add(module_key)
+                        else:
+                            changed_unknown_keys[module_key] = module.weight.detach().to(device="cpu", copy=True)
+
+                    layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
+
+                    # We intentionally move to the target device first, then cast. Experimentally, this was found to
+                    # be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
+                    # same thing in a single call to '.to(...)'.
+                    layer.to(device=device, non_blocking=TorchDevice.get_non_blocking(device))
+                    layer.to(dtype=torch.float32, non_blocking=TorchDevice.get_non_blocking(device))
+                    # TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
+                    # devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
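+                    # Scale the layer delta by both the user-supplied LoRA weight and the layer's alpha/rank
+                    # ratio, then add it to the module weight in-place (reverted on unpatch).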
+ layer_weight = layer.get_weight(module.weight) * (lora_weight * layer_scale) + layer.to( + device=TorchDevice.CPU_DEVICE, + non_blocking=TorchDevice.get_non_blocking(TorchDevice.CPU_DEVICE), + ) + + assert isinstance(layer_weight, torch.Tensor) # mypy thinks layer_weight is a float|Any ??! + if module.weight.shape != layer_weight.shape: + # TODO: debug on lycoris + assert hasattr(layer_weight, "reshape") + layer_weight = layer_weight.reshape(module.weight.shape) + + assert isinstance(layer_weight, torch.Tensor) # mypy thinks layer_weight is a float|Any ??! + module.weight += layer_weight.to(dtype=dtype, non_blocking=TorchDevice.get_non_blocking(device)) + + return changed_keys, changed_unknown_keys + + @staticmethod + def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]: + assert "." not in lora_key + + if not lora_key.startswith(prefix): + raise Exception(f"lora_key with invalid prefix: {lora_key}, {prefix}") + + module = model + module_key = "" + key_parts = lora_key[len(prefix) :].split("_") + + submodule_name = key_parts.pop(0) + + while len(key_parts) > 0: + try: + module = module.get_submodule(submodule_name) + module_key += "." + submodule_name + submodule_name = key_parts.pop(0) + except Exception: + submodule_name += "_" + key_parts.pop(0) + + module = module.get_submodule(submodule_name) + module_key = (module_key + "." + submodule_name).lstrip(".") + + return (module_key, module) diff --git a/invokeai/backend/stable_diffusion/extensions/preview.py b/invokeai/backend/stable_diffusion/extensions/preview.py new file mode 100644 index 00000000000..e5d0bf7af6b --- /dev/null +++ b/invokeai/backend/stable_diffusion/extensions/preview.py @@ -0,0 +1,61 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING, Callable, Optional + +import torch + +from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, modifier + +if TYPE_CHECKING: + from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext + + +@dataclass +class PipelineIntermediateState: + step: int + order: int + total_steps: int + timestep: int + latents: torch.Tensor + predicted_original: Optional[torch.Tensor] = None + + +class PreviewExt(ExtensionBase): + def __init__(self, callback: Callable[[PipelineIntermediateState], None], priority: int): + super().__init__(priority=priority) + self.callback = callback + + # do last so that all other changes shown + @modifier("pre_denoise_loop", order="last") + def initial_preview(self, ctx: DenoiseContext): + self.callback( + PipelineIntermediateState( + step=-1, + order=ctx.scheduler.order, + total_steps=len(ctx.timesteps), + timestep=int(ctx.scheduler.config.num_train_timesteps), # TODO: is there any code which uses it? + latents=ctx.latents, + ) + ) + + # do last so that all other changes shown + @modifier("post_step", order="last") + def step_preview(self, ctx: DenoiseContext): + if hasattr(ctx.step_output, "denoised"): + predicted_original = ctx.step_output.denoised + elif hasattr(ctx.step_output, "pred_original_sample"): + predicted_original = ctx.step_output.pred_original_sample + else: + predicted_original = ctx.step_output.prev_sample + + self.callback( + PipelineIntermediateState( + step=ctx.step_index, + order=ctx.scheduler.order, + total_steps=len(ctx.timesteps), + timestep=int(ctx.timestep), # TODO: is there any code which uses it? 
+ latents=ctx.step_output.prev_sample, + predicted_original=predicted_original, # TODO: is there any reason for additional field? + ) + ) diff --git a/invokeai/backend/stable_diffusion/extensions/rescale.py b/invokeai/backend/stable_diffusion/extensions/rescale.py new file mode 100644 index 00000000000..9f9648261ef --- /dev/null +++ b/invokeai/backend/stable_diffusion/extensions/rescale.py @@ -0,0 +1,35 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import torch + +from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, modifier + +if TYPE_CHECKING: + from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext + + +class RescaleCFGExt(ExtensionBase): + def __init__(self, guidance_rescale_multiplier: float, priority: int): + super().__init__(priority=priority) + self.guidance_rescale_multiplier = guidance_rescale_multiplier + + @staticmethod + def _rescale_cfg(total_noise_pred: torch.Tensor, pos_noise_pred: torch.Tensor, multiplier: float = 0.7): + """Implementation of Algorithm 2 from https://arxiv.org/pdf/2305.08891.pdf.""" + ro_pos = torch.std(pos_noise_pred, dim=(1, 2, 3), keepdim=True) + ro_cfg = torch.std(total_noise_pred, dim=(1, 2, 3), keepdim=True) + + x_rescaled = total_noise_pred * (ro_pos / ro_cfg) + x_final = multiplier * x_rescaled + (1.0 - multiplier) * total_noise_pred + return x_final + + @modifier("modify_noise_prediction") + def rescale_noise_pred(self, ctx: DenoiseContext): + if self.guidance_rescale_multiplier > 0: + ctx.noise_pred = self._rescale_cfg( + ctx.noise_pred, + ctx.positive_noise_pred, + self.guidance_rescale_multiplier, + ) diff --git a/invokeai/backend/stable_diffusion/extensions/seamless.py b/invokeai/backend/stable_diffusion/extensions/seamless.py new file mode 100644 index 00000000000..b99d5386096 --- /dev/null +++ b/invokeai/backend/stable_diffusion/extensions/seamless.py @@ -0,0 +1,79 @@ +from __future__ import annotations + +from contextlib import contextmanager +from typing import Callable, Dict, List, Optional, Tuple + +import torch +import torch.nn as nn +from diffusers import UNet2DConditionModel +from diffusers.models.lora import LoRACompatibleConv + +from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase + + +class SeamlessExt(ExtensionBase): + def __init__( + self, + seamless_axes: List[str], + priority: int, + ): + super().__init__(priority=priority) + self.seamless_axes = seamless_axes + + @contextmanager + def patch_unet(self, state_dict: dict, unet: UNet2DConditionModel): + with self.static_patch_model( + model=unet, + model_state_dict=state_dict, + seamless_axes=self.seamless_axes, + ): + yield + + @classmethod + @contextmanager + def static_patch_model( + cls, + model: torch.nn.Module, + seamless_axes: List[str], + model_state_dict: Optional[Dict[str, torch.Tensor]] = None, + ): + if not seamless_axes: + yield + return + + # override conv_forward + # https://github.com/huggingface/diffusers/issues/556#issuecomment-1993287019 + def _conv_forward_asymmetric( + self, input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None + ): + self.paddingX = (self._reversed_padding_repeated_twice[0], self._reversed_padding_repeated_twice[1], 0, 0) + self.paddingY = (0, 0, self._reversed_padding_repeated_twice[2], self._reversed_padding_repeated_twice[3]) + working = torch.nn.functional.pad(input, self.paddingX, mode=x_mode) + working = torch.nn.functional.pad(working, self.paddingY, mode=y_mode) + return torch.nn.functional.conv2d( + working, 
weight, bias, self.stride, torch.nn.modules.utils._pair(0), self.dilation, self.groups + ) + + original_layers: List[Tuple[nn.Conv2d, Callable]] = [] + + try: + x_mode = "circular" if "x" in seamless_axes else "constant" + y_mode = "circular" if "y" in seamless_axes else "constant" + + conv_layers: List[torch.nn.Conv2d] = [] + + for module in model.modules(): + if isinstance(module, torch.nn.Conv2d): + conv_layers.append(module) + + for layer in conv_layers: + if isinstance(layer, LoRACompatibleConv) and layer.lora_layer is None: + layer.lora_layer = lambda *x: 0 + original_layers.append((layer, layer._conv_forward)) + layer._conv_forward = _conv_forward_asymmetric.__get__(layer, torch.nn.Conv2d) + + yield + + finally: + for layer, orig_conv_forward in original_layers: + layer._conv_forward = orig_conv_forward diff --git a/invokeai/backend/stable_diffusion/extensions/t2i_adapter.py b/invokeai/backend/stable_diffusion/extensions/t2i_adapter.py new file mode 100644 index 00000000000..bdd3cf143ce --- /dev/null +++ b/invokeai/backend/stable_diffusion/extensions/t2i_adapter.py @@ -0,0 +1,161 @@ +from __future__ import annotations + +import math +from contextlib import ExitStack +from typing import TYPE_CHECKING, List, Optional, Tuple, Union + +import torch +from diffusers import T2IAdapter +from PIL.Image import Image + +from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR +from invokeai.app.util.controlnet_utils import prepare_control_image +from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, modifier + +# from invokeai.backend.model_manager import BaseModelType # TODO: + +if TYPE_CHECKING: + from invokeai.app.invocations.model import ModelIdentifierField + from invokeai.app.services.shared.invocation_context import InvocationContext + from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext + from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager + + +class T2IAdapterExt(ExtensionBase): + def __init__( + self, + node_context: InvocationContext, + exit_stack: ExitStack, + model_id: ModelIdentifierField, + image: Image, + adapter_state: List[torch.Tensor], + weight: Union[float, List[float]], + begin_step_percent: float, + end_step_percent: float, + resize_mode: str, + priority: int, + ): + super().__init__(priority=priority) + self.node_context = node_context + self.exit_stack = exit_stack + self.model_id = model_id + self.image = image + self.weight = weight + self.resize_mode = resize_mode + self.begin_step_percent = begin_step_percent + self.end_step_percent = end_step_percent + + self.adapter_state: Optional[Tuple[torch.Tensor]] = None + + @staticmethod + def tile_coords_to_key(tile_coords): + return f"{tile_coords.top}:{tile_coords.bottom}:{tile_coords.left}:{tile_coords.right}" + + @modifier("pre_unet_load") + def run_model(self, ctx: DenoiseContext, ext_manager: ExtensionsManager): + t2i_model: T2IAdapter + with self.node_context.models.load(self.model_id) as t2i_model: + # used in tiled generation(maybe we should send more info in extra field instead) + self.latents_height = ctx.latents.shape[2] + self.latents_width = ctx.latents.shape[3] + + self.adapter_state = self._run_model( + ctx=ctx, + model=t2i_model, + image=self.image, + latents_height=self.latents_height, + latents_width=self.latents_width, + ) + + def _run_model( + self, + ctx: DenoiseContext, + model: T2IAdapter, + image: Image, + latents_height: int, + latents_width: int, + ): + model_config = 
self.node_context.models.get_config(self.model_id.key) + + # The max_unet_downscale is the maximum amount that the UNet model downscales the latent image internally. + from invokeai.backend.model_manager import BaseModelType + + if model_config.base == BaseModelType.StableDiffusion1: + max_unet_downscale = 8 + elif model_config.base == BaseModelType.StableDiffusionXL: + max_unet_downscale = 4 + else: + raise ValueError(f"Unexpected T2I-Adapter base model type: '{model_config.base}'.") + + input_height = latents_height // max_unet_downscale * model.total_downscale_factor + input_width = latents_width // max_unet_downscale * model.total_downscale_factor + + t2i_image = prepare_control_image( + image=image, + do_classifier_free_guidance=False, + width=input_width, + height=input_height, + num_channels=model.config["in_channels"], # mypy treats this as a FrozenDict + device=model.device, + dtype=model.dtype, + resize_mode=self.resize_mode, + ) + + adapter_state = model(t2i_image) + # if do_classifier_free_guidance: + for idx, value in enumerate(adapter_state): + adapter_state[idx] = torch.cat([value] * 2, dim=0) + + return adapter_state + + @modifier("pre_unet_forward") + def pre_unet_step(self, ctx: DenoiseContext): + # skip if model not active in current step + total_steps = len(ctx.timesteps) + first_step = math.floor(self.begin_step_percent * total_steps) + last_step = math.ceil(self.end_step_percent * total_steps) + if ctx.step_index < first_step or ctx.step_index > last_step: + return + + weight = self.weight + if isinstance(weight, list): + weight = weight[ctx.step_index] + + tile_coords = ctx.extra.get("tile_coords", None) + if tile_coords is not None: + if not isinstance(self.adapter_state, dict): + self.model = self.exit_stack.enter_context(self.node_context.models.load(self.model_id)) + self.adapter_state = {} + + tile_key = self.tile_coords_to_key(tile_coords) + if tile_key not in self.adapter_state: + tile_height = tile_coords.bottom - tile_coords.top + tile_width = tile_coords.right - tile_coords.left + + self.adapter_state[tile_key] = self._run_model( + ctx=ctx, + model=self.model, + latents_height=tile_height, + latents_width=tile_width, + image=self.image.resize( + (self.latents_width * LATENT_SCALE_FACTOR, self.latents_height * LATENT_SCALE_FACTOR) + ).crop( + ( + tile_coords.left * LATENT_SCALE_FACTOR, + tile_coords.top * LATENT_SCALE_FACTOR, + tile_coords.right * LATENT_SCALE_FACTOR, + tile_coords.bottom * LATENT_SCALE_FACTOR, + ) + ), + ) + + adapter_state = self.adapter_state[tile_key] + else: + adapter_state = self.adapter_state + + # TODO: conditioning_mode? 
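+        # Accumulate the weighted T2I-Adapter residuals; when several adapters are active, later ones
+        # add onto the residuals already stored in unet_kwargs.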
+ if ctx.unet_kwargs.down_intrablock_additional_residuals is None: + ctx.unet_kwargs.down_intrablock_additional_residuals = [v * weight for v in adapter_state] + else: + for i, value in enumerate(adapter_state): + ctx.unet_kwargs.down_intrablock_additional_residuals[i] += value * weight diff --git a/invokeai/backend/stable_diffusion/extensions/tiled_denoise.py b/invokeai/backend/stable_diffusion/extensions/tiled_denoise.py new file mode 100644 index 00000000000..640045b5c4f --- /dev/null +++ b/invokeai/backend/stable_diffusion/extensions/tiled_denoise.py @@ -0,0 +1,118 @@ +from __future__ import annotations + +import copy +from dataclasses import dataclass +from typing import TYPE_CHECKING, Callable, Optional + +import torch +from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput + +from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR +from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext +from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, modifier, override +from invokeai.backend.tiles.tiles import calc_tiles_min_overlap + +if TYPE_CHECKING: + from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager + + +class TiledDenoiseExt(ExtensionBase): + def __init__( + self, + tile_width: int, + tile_height: int, + tile_overlap: int, + priority: int, + ): + super().__init__(priority=priority) + self.tile_width = tile_width + self.tile_height = tile_height + self.tile_overlap = tile_overlap + + @dataclass + class FakeSchedulerOutput(SchedulerOutput): # BaseOutput + # prev_sample: torch.Tensor + pred_original_sample: Optional[torch.Tensor] = None + + @modifier("pre_denoise_loop") + def init_tiles(self, ctx: DenoiseContext): + _, _, latent_height, latent_width = ctx.latents.shape + latent_tile_height = self.tile_height // LATENT_SCALE_FACTOR + latent_tile_width = self.tile_width // LATENT_SCALE_FACTOR + latent_tile_overlap = self.tile_overlap // LATENT_SCALE_FACTOR + + self.tiles = calc_tiles_min_overlap( + image_height=latent_height, + image_width=latent_width, + tile_height=latent_tile_height, + tile_width=latent_tile_width, + min_overlap=latent_tile_overlap, + ) + + @override("step") + def tiled_step(self, orig_step: Callable, ctx: DenoiseContext, ext_manager: ExtensionsManager): + batch_size, _, latent_height, latent_width = ctx.latents.shape + region_batch_schedulers: list[SchedulerMixin] = [copy.deepcopy(ctx.scheduler) for _ in self.tiles] + + merged_latents = torch.zeros_like(ctx.latents) + merged_latents_weights = torch.zeros( + (1, 1, latent_height, latent_width), device=ctx.latents.device, dtype=ctx.latents.dtype + ) + merged_pred_original: torch.Tensor | None = None + for region_idx, tile_region in enumerate(self.tiles): + # Crop the inputs to the region. + region_latents = ctx.latents[ + :, + :, + tile_region.coords.top : tile_region.coords.bottom, + tile_region.coords.left : tile_region.coords.right, + ] + + region_ctx = DenoiseContext(**vars(ctx)) + region_ctx.latents = region_latents + region_ctx.scheduler = region_batch_schedulers[region_idx] + # region_ctx.conditioning_data = region_conditioning.text_conditioning_data + region_ctx.extra["tile_coords"] = tile_region.coords + + # Run the denoising step on the region. + step_output = orig_step(region_ctx, ext_manager) + + # Store the results from the region. + # If two tiles overlap by more than the target overlap amount, crop the left and top edges of the + # affected tiles to achieve the target overlap. 
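+            # Paste each tile's result into a full-size canvas and track a per-pixel weight map;
+            # overlapping regions are averaged after the loop.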
+ target_overlap = self.tile_overlap // LATENT_SCALE_FACTOR + top_adjustment = max(0, tile_region.overlap.top - target_overlap) + left_adjustment = max(0, tile_region.overlap.left - target_overlap) + region_height_slice = slice(tile_region.coords.top + top_adjustment, tile_region.coords.bottom) + region_width_slice = slice(tile_region.coords.left + left_adjustment, tile_region.coords.right) + merged_latents[:, :, region_height_slice, region_width_slice] += step_output.prev_sample[ + :, :, top_adjustment:, left_adjustment: + ] + # For now, we treat every region as having the same weight. + merged_latents_weights[:, :, region_height_slice, region_width_slice] += 1.0 + + # TODO: denoised + pred_orig_sample = getattr(step_output, "pred_original_sample", None) + if pred_orig_sample is not None: + # If one region has pred_original_sample, then we can assume that all regions will have it, because + # they all use the same scheduler. + if merged_pred_original is None: + merged_pred_original = torch.zeros_like(ctx.latents) + merged_pred_original[:, :, region_height_slice, region_width_slice] += pred_orig_sample[ + :, :, top_adjustment:, left_adjustment: + ] + + # Normalize the merged results. + latents = torch.where(merged_latents_weights > 0, merged_latents / merged_latents_weights, merged_latents) + # For debugging, uncomment this line to visualize the region seams: + # latents = torch.where(merged_latents_weights > 1, 0.0, latents) + predicted_original = None + if merged_pred_original is not None: + predicted_original = torch.where( + merged_latents_weights > 0, merged_pred_original / merged_latents_weights, merged_pred_original + ) + + return self.FakeSchedulerOutput( + prev_sample=latents, + pred_original_sample=predicted_original, + ) diff --git a/invokeai/backend/stable_diffusion/extensions_manager.py b/invokeai/backend/stable_diffusion/extensions_manager.py new file mode 100644 index 00000000000..260177219f4 --- /dev/null +++ b/invokeai/backend/stable_diffusion/extensions_manager.py @@ -0,0 +1,182 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from contextlib import ExitStack, contextmanager +from functools import partial +from typing import TYPE_CHECKING, Callable, Dict + +import torch +from diffusers import UNet2DConditionModel + +from invokeai.backend.util.devices import TorchDevice + +if TYPE_CHECKING: + from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext + from invokeai.backend.stable_diffusion.extensions import ExtensionBase + + +class ExtModifiersApi(ABC): + @abstractmethod + def pre_denoise_loop(self, ctx: DenoiseContext): + pass + + @abstractmethod + def post_denoise_loop(self, ctx: DenoiseContext): + pass + + @abstractmethod + def pre_step(self, ctx: DenoiseContext): + pass + + @abstractmethod + def post_step(self, ctx: DenoiseContext): + pass + + @abstractmethod + def modify_noise_prediction(self, ctx: DenoiseContext): + pass + + @abstractmethod + def pre_unet_forward(self, ctx: DenoiseContext): + pass + + @abstractmethod + def pre_unet_load(self, ctx: DenoiseContext, ext_manager: ExtensionsManager): + pass + + +class ExtOverridesApi(ABC): + @abstractmethod + def step(self, orig_func: Callable, ctx: DenoiseContext, ext_manager: ExtensionsManager): + pass + + @abstractmethod + def combine_noise(self, orig_func: Callable, ctx: DenoiseContext): + pass + + +class ProxyCallsClass: + def __init__(self, handler): + self._handler = handler + + def __getattr__(self, item): + return partial(self._handler, item) + + +class 
ModifierInjectionPoint: + def __init__(self): + self.first = [] + self.any = [] + self.last = [] + + def add(self, func: Callable, order: str): + if order == "first": + self.first.append(func) + elif order == "last": + self.last.append(func) + else: # elif order == "any": + self.any.append(func) + + def __call__(self, *args, **kwargs): + for func in self.first: + func(*args, **kwargs) + for func in self.any: + func(*args, **kwargs) + for func in reversed(self.last): + func(*args, **kwargs) + + +class ExtensionsManager: + def __init__(self): + self.extensions = [] + + self._overrides = {} + self._modifiers = {} + + self.modifiers: ExtModifiersApi = ProxyCallsClass(self.call_modifier) + self.overrides: ExtOverridesApi = ProxyCallsClass(self.call_override) + + def add_extension(self, ext: ExtensionBase): + self.extensions.append(ext) + ordered_extensions = sorted(self.extensions, reverse=True, key=lambda ext: ext.priority) + + self._overrides.clear() + self._modifiers.clear() + + for ext in ordered_extensions: + for inj_info in ext.injections: + if inj_info.type == "modifier": + if inj_info.name not in self._modifiers: + self._modifiers[inj_info.name] = ModifierInjectionPoint() + self._modifiers[inj_info.name].add(inj_info.function, inj_info.order) + + else: + if inj_info.name in self._overrides: + raise Exception(f"Already overloaded - {inj_info.name}") + self._overrides[inj_info.name] = inj_info.function + + def call_modifier(self, name: str, *args, **kwargs): + if name in self._modifiers: + self._modifiers[name](*args, **kwargs) + + def call_override(self, name: str, orig_func: Callable, *args, **kwargs): + if name in self._overrides: + return self._overrides[name](orig_func, *args, **kwargs) + else: + return orig_func(*args, **kwargs) + + @contextmanager + def patch_attention_processor(self, unet: UNet2DConditionModel, attn_processor_cls: object): + unet_orig_processors = unet.attn_processors + exit_stack = ExitStack() + try: + # just to be sure that attentions have not same processor instance + attn_procs = {} + for name in unet.attn_processors.keys(): + attn_procs[name] = attn_processor_cls() + unet.set_attn_processor(attn_procs) + + for ext in self.extensions: + exit_stack.enter_context(ext.patch_attention_processor(attn_processor_cls)) + + yield None + + finally: + unet.set_attn_processor(unet_orig_processors) + exit_stack.close() + + @contextmanager + def patch_unet(self, state_dict: Dict[str, torch.Tensor], unet: UNet2DConditionModel): + exit_stack = ExitStack() + try: + changed_keys = set() + changed_unknown_keys = {} + + ordered_extensions = sorted(self.extensions, reverse=True, key=lambda ext: ext.priority) + for ext in ordered_extensions: + patch_result = exit_stack.enter_context(ext.patch_unet(state_dict, unet)) + if patch_result is None: + continue + new_keys, new_unk_keys = patch_result + changed_keys.update(new_keys) + # skip already seen keys, as new weight might be changed + for k, v in new_unk_keys.items(): + if k in changed_unknown_keys: + continue + changed_unknown_keys[k] = v + + yield None + + finally: + exit_stack.close() + assert hasattr(unet, "get_submodule") # mypy not picking up fact that torch.nn.Module has get_submodule() + with torch.no_grad(): + for module_key in changed_keys: + weight = state_dict[module_key] + unet.get_submodule(module_key).weight.copy_( + weight, non_blocking=TorchDevice.get_non_blocking(weight.device) + ) + for module_key, weight in changed_unknown_keys.items(): + unet.get_submodule(module_key).weight.copy_( + weight, 
non_blocking=TorchDevice.get_non_blocking(weight.device) + ) diff --git a/invokeai/backend/stable_diffusion/multi_diffusion_pipeline.py b/invokeai/backend/stable_diffusion/multi_diffusion_pipeline.py deleted file mode 100644 index 0ddcfdd3801..00000000000 --- a/invokeai/backend/stable_diffusion/multi_diffusion_pipeline.py +++ /dev/null @@ -1,170 +0,0 @@ -from __future__ import annotations - -import copy -from dataclasses import dataclass -from typing import Any, Callable, Optional - -import torch -from diffusers.schedulers.scheduling_utils import SchedulerMixin - -from invokeai.backend.stable_diffusion.diffusers_pipeline import ( - ControlNetData, - PipelineIntermediateState, - StableDiffusionGeneratorPipeline, -) -from invokeai.backend.stable_diffusion.diffusion.conditioning_data import TextConditioningData -from invokeai.backend.tiles.utils import Tile - - -@dataclass -class MultiDiffusionRegionConditioning: - # Region coords in latent space. - region: Tile - text_conditioning_data: TextConditioningData - control_data: list[ControlNetData] - - -class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline): - """A Stable Diffusion pipeline that uses Multi-Diffusion (https://arxiv.org/pdf/2302.08113) for denoising.""" - - def _check_regional_prompting(self, multi_diffusion_conditioning: list[MultiDiffusionRegionConditioning]): - """Validate that regional conditioning is not used.""" - for region_conditioning in multi_diffusion_conditioning: - if ( - region_conditioning.text_conditioning_data.cond_regions is not None - or region_conditioning.text_conditioning_data.uncond_regions is not None - ): - raise NotImplementedError("Regional prompting is not yet supported in Multi-Diffusion.") - - def multi_diffusion_denoise( - self, - multi_diffusion_conditioning: list[MultiDiffusionRegionConditioning], - target_overlap: int, - latents: torch.Tensor, - scheduler_step_kwargs: dict[str, Any], - noise: Optional[torch.Tensor], - timesteps: torch.Tensor, - init_timestep: torch.Tensor, - callback: Callable[[PipelineIntermediateState], None], - ) -> torch.Tensor: - self._check_regional_prompting(multi_diffusion_conditioning) - - if init_timestep.shape[0] == 0: - return latents - - batch_size, _, latent_height, latent_width = latents.shape - batched_init_timestep = init_timestep.expand(batch_size) - - # noise can be None if the latents have already been noised (e.g. when running the SDXL refiner). - if noise is not None: - # TODO(ryand): I'm pretty sure we should be applying init_noise_sigma in cases where we are starting with - # full noise. Investigate the history of why this got commented out. - # latents = noise * self.scheduler.init_noise_sigma # it's like in t2l according to diffusers - latents = self.scheduler.add_noise(latents, noise, batched_init_timestep) - - # TODO(ryand): Look into the implications of passing in latents here that are larger than they will be after - # cropping into regions. - self._adjust_memory_efficient_attention(latents) - - # Many of the diffusers schedulers are stateful (i.e. they update internal state in each call to step()). Since - # we are calling step() multiple times at the same timestep (once for each region batch), we must maintain a - # separate scheduler state for each region batch. - # TODO(ryand): This solution allows all schedulers to **run**, but does not fully solve the issue of scheduler - # statefulness. Some schedulers store previous model outputs in their state, but these values become incorrect - # as Multi-Diffusion blending is applied (e.g. 
the PNDMScheduler). This can result in a blurring effect when - # multiple MultiDiffusion regions overlap. Solving this properly would require a case-by-case review of each - # scheduler to determine how it's state needs to be updated for compatibilty with Multi-Diffusion. - region_batch_schedulers: list[SchedulerMixin] = [ - copy.deepcopy(self.scheduler) for _ in multi_diffusion_conditioning - ] - - callback( - PipelineIntermediateState( - step=-1, - order=self.scheduler.order, - total_steps=len(timesteps), - timestep=self.scheduler.config.num_train_timesteps, - latents=latents, - ) - ) - - for i, t in enumerate(self.progress_bar(timesteps)): - batched_t = t.expand(batch_size) - - merged_latents = torch.zeros_like(latents) - merged_latents_weights = torch.zeros( - (1, 1, latent_height, latent_width), device=latents.device, dtype=latents.dtype - ) - merged_pred_original: torch.Tensor | None = None - for region_idx, region_conditioning in enumerate(multi_diffusion_conditioning): - # Switch to the scheduler for the region batch. - self.scheduler = region_batch_schedulers[region_idx] - - # Crop the inputs to the region. - region_latents = latents[ - :, - :, - region_conditioning.region.coords.top : region_conditioning.region.coords.bottom, - region_conditioning.region.coords.left : region_conditioning.region.coords.right, - ] - - # Run the denoising step on the region. - step_output = self.step( - t=batched_t, - latents=region_latents, - conditioning_data=region_conditioning.text_conditioning_data, - step_index=i, - total_step_count=len(timesteps), - scheduler_step_kwargs=scheduler_step_kwargs, - mask_guidance=None, - mask=None, - masked_latents=None, - control_data=region_conditioning.control_data, - ) - - # Store the results from the region. - # If two tiles overlap by more than the target overlap amount, crop the left and top edges of the - # affected tiles to achieve the target overlap. - region = region_conditioning.region - top_adjustment = max(0, region.overlap.top - target_overlap) - left_adjustment = max(0, region.overlap.left - target_overlap) - region_height_slice = slice(region.coords.top + top_adjustment, region.coords.bottom) - region_width_slice = slice(region.coords.left + left_adjustment, region.coords.right) - merged_latents[:, :, region_height_slice, region_width_slice] += step_output.prev_sample[ - :, :, top_adjustment:, left_adjustment: - ] - # For now, we treat every region as having the same weight. - merged_latents_weights[:, :, region_height_slice, region_width_slice] += 1.0 - - pred_orig_sample = getattr(step_output, "pred_original_sample", None) - if pred_orig_sample is not None: - # If one region has pred_original_sample, then we can assume that all regions will have it, because - # they all use the same scheduler. - if merged_pred_original is None: - merged_pred_original = torch.zeros_like(latents) - merged_pred_original[:, :, region_height_slice, region_width_slice] += pred_orig_sample[ - :, :, top_adjustment:, left_adjustment: - ] - - # Normalize the merged results. 
- latents = torch.where(merged_latents_weights > 0, merged_latents / merged_latents_weights, merged_latents) - # For debugging, uncomment this line to visualize the region seams: - # latents = torch.where(merged_latents_weights > 1, 0.0, latents) - predicted_original = None - if merged_pred_original is not None: - predicted_original = torch.where( - merged_latents_weights > 0, merged_pred_original / merged_latents_weights, merged_pred_original - ) - - callback( - PipelineIntermediateState( - step=i, - order=self.scheduler.order, - total_steps=len(timesteps), - timestep=int(t), - latents=latents, - predicted_original=predicted_original, - ) - ) - - return latents diff --git a/invokeai/backend/stable_diffusion/seamless.py b/invokeai/backend/stable_diffusion/seamless.py deleted file mode 100644 index 23ed978c6d0..00000000000 --- a/invokeai/backend/stable_diffusion/seamless.py +++ /dev/null @@ -1,51 +0,0 @@ -from contextlib import contextmanager -from typing import Callable, List, Optional, Tuple, Union - -import torch -import torch.nn as nn -from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL -from diffusers.models.autoencoders.autoencoder_tiny import AutoencoderTiny -from diffusers.models.lora import LoRACompatibleConv -from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel - - -@contextmanager -def set_seamless(model: Union[UNet2DConditionModel, AutoencoderKL, AutoencoderTiny], seamless_axes: List[str]): - if not seamless_axes: - yield - return - - # override conv_forward - # https://github.com/huggingface/diffusers/issues/556#issuecomment-1993287019 - def _conv_forward_asymmetric(self, input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None): - self.paddingX = (self._reversed_padding_repeated_twice[0], self._reversed_padding_repeated_twice[1], 0, 0) - self.paddingY = (0, 0, self._reversed_padding_repeated_twice[2], self._reversed_padding_repeated_twice[3]) - working = torch.nn.functional.pad(input, self.paddingX, mode=x_mode) - working = torch.nn.functional.pad(working, self.paddingY, mode=y_mode) - return torch.nn.functional.conv2d( - working, weight, bias, self.stride, torch.nn.modules.utils._pair(0), self.dilation, self.groups - ) - - original_layers: List[Tuple[nn.Conv2d, Callable]] = [] - - try: - x_mode = "circular" if "x" in seamless_axes else "constant" - y_mode = "circular" if "y" in seamless_axes else "constant" - - conv_layers: List[torch.nn.Conv2d] = [] - - for module in model.modules(): - if isinstance(module, torch.nn.Conv2d): - conv_layers.append(module) - - for layer in conv_layers: - if isinstance(layer, LoRACompatibleConv) and layer.lora_layer is None: - layer.lora_layer = lambda *x: 0 - original_layers.append((layer, layer._conv_forward)) - layer._conv_forward = _conv_forward_asymmetric.__get__(layer, torch.nn.Conv2d) - - yield - - finally: - for layer, orig_conv_forward in original_layers: - layer._conv_forward = orig_conv_forward
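
For reference, a minimal sketch (not part of the diff above) of how a custom extension could hook into the new ExtensionsManager API introduced here. `MyLoggingExt` and the `ctx` argument are illustrative placeholders; the decorator and manager interfaces follow the code added in this patch.

    from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, modifier
    from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager


    class MyLoggingExt(ExtensionBase):
        """Illustrative extension that logs every denoising step."""

        def __init__(self, priority: int = 100):
            super().__init__(priority=priority)

        @modifier("pre_step", order="first")
        def log_step(self, ctx):
            # `ctx` is the DenoiseContext the backend passes to each injection point.
            print(f"step {ctx.step_index}, timestep {int(ctx.timestep)}")


    ext_manager = ExtensionsManager()
    ext_manager.add_extension(MyLoggingExt())

    # Inside the denoise loop, the backend dispatches injection points roughly like this:
    #   ext_manager.modifiers.pre_step(ctx)
    #   step_output = ext_manager.overrides.step(orig_step, ctx, ext_manager)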