From e046e60e1cfb3552e27d97bed0acfd6345adc0fe Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Sun, 21 Jul 2024 18:31:10 +0300 Subject: [PATCH 1/4] Add FreeU support to denoise --- invokeai/app/invocations/denoise_latents.py | 9 +++- .../stable_diffusion/extensions/freeu.py | 42 +++++++++++++++++++ .../stable_diffusion/extensions_manager.py | 10 +++-- 3 files changed, 56 insertions(+), 5 deletions(-) create mode 100644 invokeai/backend/stable_diffusion/extensions/freeu.py diff --git a/invokeai/app/invocations/denoise_latents.py b/invokeai/app/invocations/denoise_latents.py index ccacc3303cf..e043e884f90 100644 --- a/invokeai/app/invocations/denoise_latents.py +++ b/invokeai/app/invocations/denoise_latents.py @@ -58,6 +58,7 @@ from invokeai.backend.stable_diffusion.diffusion.custom_atttention import CustomAttnProcessor2_0 from invokeai.backend.stable_diffusion.diffusion_backend import StableDiffusionBackend from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType +from invokeai.backend.stable_diffusion.extensions.freeu import FreeUExt from invokeai.backend.stable_diffusion.extensions.preview import PreviewExt from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP @@ -790,18 +791,22 @@ def step_callback(state: PipelineIntermediateState) -> None: ext_manager.add_extension(PreviewExt(step_callback)) + ### freeu + if self.unet.freeu_config: + ext_manager.add_extension(FreeUExt(self.unet.freeu_config)) + # ext: t2i/ip adapter ext_manager.run_callback(ExtensionCallbackType.SETUP, denoise_ctx) unet_info = context.models.load(self.unet.unet) assert isinstance(unet_info.model, UNet2DConditionModel) with ( - unet_info.model_on_device() as (model_state_dict, unet), + unet_info.model_on_device() as (cached_weights, unet), ModelPatcher.patch_unet_attention_processor(unet, denoise_ctx.inputs.attention_processor_cls), # ext: controlnet ext_manager.patch_extensions(unet), # ext: freeu, seamless, ip adapter, lora - ext_manager.patch_unet(model_state_dict, unet), + ext_manager.patch_unet(unet, cached_weights), ): sd_backend = StableDiffusionBackend(unet, scheduler) denoise_ctx.unet = unet diff --git a/invokeai/backend/stable_diffusion/extensions/freeu.py b/invokeai/backend/stable_diffusion/extensions/freeu.py new file mode 100644 index 00000000000..c723aaee0b1 --- /dev/null +++ b/invokeai/backend/stable_diffusion/extensions/freeu.py @@ -0,0 +1,42 @@ +from __future__ import annotations + +from contextlib import contextmanager +from typing import TYPE_CHECKING, Dict, Optional + +from diffusers import UNet2DConditionModel + +from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType +from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase + +if TYPE_CHECKING: + from invokeai.app.shared.models import FreeUConfig + + +class FreeUExt(ExtensionBase): + def __init__( + self, + freeu_config: Optional[FreeUConfig], + ): + super().__init__() + self.freeu_config = freeu_config + + @contextmanager + def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None): + did_apply_freeu = False + try: + assert hasattr(unet, "enable_freeu") # mypy doesn't pick up this attribute? + if self.freeu_config is not None: + unet.enable_freeu( + b1=self.freeu_config.b1, + b2=self.freeu_config.b2, + s1=self.freeu_config.s1, + s2=self.freeu_config.s2, + ) + did_apply_freeu = True + + yield + + finally: + assert hasattr(unet, "disable_freeu") # mypy doesn't pick up this attribute? + if did_apply_freeu: + unet.disable_freeu() diff --git a/invokeai/backend/stable_diffusion/extensions_manager.py b/invokeai/backend/stable_diffusion/extensions_manager.py index 1cae2e42190..9c4347a56c3 100644 --- a/invokeai/backend/stable_diffusion/extensions_manager.py +++ b/invokeai/backend/stable_diffusion/extensions_manager.py @@ -63,9 +63,13 @@ def patch_extensions(self, context: DenoiseContext): yield None @contextmanager - def patch_unet(self, state_dict: Dict[str, torch.Tensor], unet: UNet2DConditionModel): + def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None): if self._is_canceled and self._is_canceled(): raise CanceledException - # TODO: create logic in PR with extension which uses it - yield None + # TODO: create weight patch logic in PR with extension which uses it + with ExitStack() as exit_stack: + for ext in self._extensions: + exit_stack.enter_context(ext.patch_unet(unet, cached_weights)) + + yield None From 5772965f092645f89768fa9be0a7d174168bfc24 Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Sun, 21 Jul 2024 18:31:30 +0300 Subject: [PATCH 2/4] Fix slightly different output with old backend --- invokeai/backend/stable_diffusion/diffusion_backend.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/invokeai/backend/stable_diffusion/diffusion_backend.py b/invokeai/backend/stable_diffusion/diffusion_backend.py index 806deb5e03b..60a21bdc02d 100644 --- a/invokeai/backend/stable_diffusion/diffusion_backend.py +++ b/invokeai/backend/stable_diffusion/diffusion_backend.py @@ -100,8 +100,10 @@ def apply_cfg(ctx: DenoiseContext) -> torch.Tensor: if isinstance(guidance_scale, list): guidance_scale = guidance_scale[ctx.step_index] - return torch.lerp(ctx.negative_noise_pred, ctx.positive_noise_pred, guidance_scale) - # return ctx.negative_noise_pred + guidance_scale * (ctx.positive_noise_pred - ctx.negative_noise_pred) + # Note: Although logically it same, it seams that precision errors differs. + # This sometimes results in slightly different output. + # return torch.lerp(ctx.negative_noise_pred, ctx.positive_noise_pred, guidance_scale) + return ctx.negative_noise_pred + guidance_scale * (ctx.positive_noise_pred - ctx.negative_noise_pred) def run_unet(self, ctx: DenoiseContext, ext_manager: ExtensionsManager, conditioning_mode: ConditioningMode): sample = ctx.latent_model_input From 1748848b7b13f93f58400af36a9ed4e280ce639c Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Sun, 21 Jul 2024 18:37:20 +0300 Subject: [PATCH 3/4] Ruff fixes --- invokeai/backend/stable_diffusion/extensions/freeu.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/invokeai/backend/stable_diffusion/extensions/freeu.py b/invokeai/backend/stable_diffusion/extensions/freeu.py index c723aaee0b1..0f6c47a7733 100644 --- a/invokeai/backend/stable_diffusion/extensions/freeu.py +++ b/invokeai/backend/stable_diffusion/extensions/freeu.py @@ -3,9 +3,9 @@ from contextlib import contextmanager from typing import TYPE_CHECKING, Dict, Optional +import torch from diffusers import UNet2DConditionModel -from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase if TYPE_CHECKING: From 5f0fe3c8a986cc32ec20a40c3df56145bb0222ab Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Mon, 22 Jul 2024 23:09:11 +0300 Subject: [PATCH 4/4] Suggested changes Co-Authored-By: Ryan Dick <14897797+RyanJDick@users.noreply.github.com> --- .../stable_diffusion/diffusion_backend.py | 4 +-- .../stable_diffusion/extensions/base.py | 4 +-- .../stable_diffusion/extensions/freeu.py | 27 +++++++------------ 3 files changed, 14 insertions(+), 21 deletions(-) diff --git a/invokeai/backend/stable_diffusion/diffusion_backend.py b/invokeai/backend/stable_diffusion/diffusion_backend.py index 60a21bdc02d..5d0a68513fc 100644 --- a/invokeai/backend/stable_diffusion/diffusion_backend.py +++ b/invokeai/backend/stable_diffusion/diffusion_backend.py @@ -100,8 +100,8 @@ def apply_cfg(ctx: DenoiseContext) -> torch.Tensor: if isinstance(guidance_scale, list): guidance_scale = guidance_scale[ctx.step_index] - # Note: Although logically it same, it seams that precision errors differs. - # This sometimes results in slightly different output. + # Note: Although this `torch.lerp(...)` line is logically equivalent to the current CFG line, it seems to result + # in slightly different outputs. It is suspected that this is caused by small precision differences. # return torch.lerp(ctx.negative_noise_pred, ctx.positive_noise_pred, guidance_scale) return ctx.negative_noise_pred + guidance_scale * (ctx.positive_noise_pred - ctx.negative_noise_pred) diff --git a/invokeai/backend/stable_diffusion/extensions/base.py b/invokeai/backend/stable_diffusion/extensions/base.py index 802af86e6df..6a85a2e4413 100644 --- a/invokeai/backend/stable_diffusion/extensions/base.py +++ b/invokeai/backend/stable_diffusion/extensions/base.py @@ -2,7 +2,7 @@ from contextlib import contextmanager from dataclasses import dataclass -from typing import TYPE_CHECKING, Callable, Dict, List +from typing import TYPE_CHECKING, Callable, Dict, List, Optional import torch from diffusers import UNet2DConditionModel @@ -56,5 +56,5 @@ def patch_extension(self, context: DenoiseContext): yield None @contextmanager - def patch_unet(self, state_dict: Dict[str, torch.Tensor], unet: UNet2DConditionModel): + def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None): yield None diff --git a/invokeai/backend/stable_diffusion/extensions/freeu.py b/invokeai/backend/stable_diffusion/extensions/freeu.py index 0f6c47a7733..6ec4fea3fa6 100644 --- a/invokeai/backend/stable_diffusion/extensions/freeu.py +++ b/invokeai/backend/stable_diffusion/extensions/freeu.py @@ -15,28 +15,21 @@ class FreeUExt(ExtensionBase): def __init__( self, - freeu_config: Optional[FreeUConfig], + freeu_config: FreeUConfig, ): super().__init__() - self.freeu_config = freeu_config + self._freeu_config = freeu_config @contextmanager def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None): - did_apply_freeu = False - try: - assert hasattr(unet, "enable_freeu") # mypy doesn't pick up this attribute? - if self.freeu_config is not None: - unet.enable_freeu( - b1=self.freeu_config.b1, - b2=self.freeu_config.b2, - s1=self.freeu_config.s1, - s2=self.freeu_config.s2, - ) - did_apply_freeu = True + unet.enable_freeu( + b1=self._freeu_config.b1, + b2=self._freeu_config.b2, + s1=self._freeu_config.s1, + s2=self._freeu_config.s2, + ) + try: yield - finally: - assert hasattr(unet, "disable_freeu") # mypy doesn't pick up this attribute? - if did_apply_freeu: - unet.disable_freeu() + unet.disable_freeu()