diff --git a/invokeai/app/invocations/denoise_latents.py b/invokeai/app/invocations/denoise_latents.py
index 68deb68445c..c79e1ed6975 100644
--- a/invokeai/app/invocations/denoise_latents.py
+++ b/invokeai/app/invocations/denoise_latents.py
@@ -58,6 +58,7 @@
 from invokeai.backend.stable_diffusion.diffusion.custom_atttention import CustomAttnProcessor2_0
 from invokeai.backend.stable_diffusion.diffusion_backend import StableDiffusionBackend
 from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
+from invokeai.backend.stable_diffusion.extensions.freeu import FreeUExt
 from invokeai.backend.stable_diffusion.extensions.preview import PreviewExt
 from invokeai.backend.stable_diffusion.extensions.rescale_cfg import RescaleCFGExt
 from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
@@ -795,18 +796,22 @@ def step_callback(state: PipelineIntermediateState) -> None:
         if self.cfg_rescale_multiplier > 0:
             ext_manager.add_extension(RescaleCFGExt(self.cfg_rescale_multiplier))
 
+        ### freeu
+        if self.unet.freeu_config:
+            ext_manager.add_extension(FreeUExt(self.unet.freeu_config))
+
         # ext: t2i/ip adapter
         ext_manager.run_callback(ExtensionCallbackType.SETUP, denoise_ctx)
 
         unet_info = context.models.load(self.unet.unet)
         assert isinstance(unet_info.model, UNet2DConditionModel)
         with (
-            unet_info.model_on_device() as (model_state_dict, unet),
+            unet_info.model_on_device() as (cached_weights, unet),
             ModelPatcher.patch_unet_attention_processor(unet, denoise_ctx.inputs.attention_processor_cls),
             # ext: controlnet
             ext_manager.patch_extensions(unet),
             # ext: freeu, seamless, ip adapter, lora
-            ext_manager.patch_unet(model_state_dict, unet),
+            ext_manager.patch_unet(unet, cached_weights),
         ):
             sd_backend = StableDiffusionBackend(unet, scheduler)
             denoise_ctx.unet = unet
diff --git a/invokeai/backend/stable_diffusion/diffusion_backend.py b/invokeai/backend/stable_diffusion/diffusion_backend.py
index d41aa63c606..4191db734f9 100644
--- a/invokeai/backend/stable_diffusion/diffusion_backend.py
+++ b/invokeai/backend/stable_diffusion/diffusion_backend.py
@@ -100,8 +100,10 @@ def combine_noise_preds(ctx: DenoiseContext) -> torch.Tensor:
         if isinstance(guidance_scale, list):
             guidance_scale = guidance_scale[ctx.step_index]
 
-        return torch.lerp(ctx.negative_noise_pred, ctx.positive_noise_pred, guidance_scale)
-        # return ctx.negative_noise_pred + guidance_scale * (ctx.positive_noise_pred - ctx.negative_noise_pred)
+        # Note: Although this `torch.lerp(...)` line is logically equivalent to the current CFG line, it seems to result
+        # in slightly different outputs. It is suspected that this is caused by small precision differences.
+        # return torch.lerp(ctx.negative_noise_pred, ctx.positive_noise_pred, guidance_scale)
+        return ctx.negative_noise_pred + guidance_scale * (ctx.positive_noise_pred - ctx.negative_noise_pred)
 
     def run_unet(self, ctx: DenoiseContext, ext_manager: ExtensionsManager, conditioning_mode: ConditioningMode):
         sample = ctx.latent_model_input
diff --git a/invokeai/backend/stable_diffusion/extensions/base.py b/invokeai/backend/stable_diffusion/extensions/base.py
index 802af86e6df..6a85a2e4413 100644
--- a/invokeai/backend/stable_diffusion/extensions/base.py
+++ b/invokeai/backend/stable_diffusion/extensions/base.py
@@ -2,7 +2,7 @@
 
 from contextlib import contextmanager
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Callable, Dict, List
+from typing import TYPE_CHECKING, Callable, Dict, List, Optional
 
 import torch
 from diffusers import UNet2DConditionModel
@@ -56,5 +56,5 @@ def patch_extension(self, context: DenoiseContext):
         yield None
 
     @contextmanager
-    def patch_unet(self, state_dict: Dict[str, torch.Tensor], unet: UNet2DConditionModel):
+    def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None):
         yield None
diff --git a/invokeai/backend/stable_diffusion/extensions/freeu.py b/invokeai/backend/stable_diffusion/extensions/freeu.py
new file mode 100644
index 00000000000..6ec4fea3fa6
--- /dev/null
+++ b/invokeai/backend/stable_diffusion/extensions/freeu.py
@@ -0,0 +1,35 @@
+from __future__ import annotations
+
+from contextlib import contextmanager
+from typing import TYPE_CHECKING, Dict, Optional
+
+import torch
+from diffusers import UNet2DConditionModel
+
+from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase
+
+if TYPE_CHECKING:
+    from invokeai.app.shared.models import FreeUConfig
+
+
+class FreeUExt(ExtensionBase):
+    def __init__(
+        self,
+        freeu_config: FreeUConfig,
+    ):
+        super().__init__()
+        self._freeu_config = freeu_config
+
+    @contextmanager
+    def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None):
+        unet.enable_freeu(
+            b1=self._freeu_config.b1,
+            b2=self._freeu_config.b2,
+            s1=self._freeu_config.s1,
+            s2=self._freeu_config.s2,
+        )
+
+        try:
+            yield
+        finally:
+            unet.disable_freeu()
diff --git a/invokeai/backend/stable_diffusion/extensions_manager.py b/invokeai/backend/stable_diffusion/extensions_manager.py
index 1cae2e42190..9c4347a56c3 100644
--- a/invokeai/backend/stable_diffusion/extensions_manager.py
+++ b/invokeai/backend/stable_diffusion/extensions_manager.py
@@ -63,9 +63,13 @@ def patch_extensions(self, context: DenoiseContext):
         yield None
 
     @contextmanager
-    def patch_unet(self, state_dict: Dict[str, torch.Tensor], unet: UNet2DConditionModel):
+    def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None):
         if self._is_canceled and self._is_canceled():
             raise CanceledException
 
-        # TODO: create logic in PR with extension which uses it
-        yield None
+        # TODO: create weight patch logic in PR with extension which uses it
+        with ExitStack() as exit_stack:
+            for ext in self._extensions:
+                exit_stack.enter_context(ext.patch_unet(unet, cached_weights))
+
+            yield None