From 9a1420280e3e5dc3ba636c13c6c6ef5094a56cfa Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Sun, 21 Jul 2024 17:33:43 +0300 Subject: [PATCH 1/2] Add rescale cfg support to denoise --- invokeai/app/invocations/denoise_latents.py | 5 +++ .../stable_diffusion/diffusion_backend.py | 8 ++--- .../extension_callback_type.py | 2 +- .../extensions/rescale_cfg.py | 36 +++++++++++++++++++ 4 files changed, 46 insertions(+), 5 deletions(-) create mode 100644 invokeai/backend/stable_diffusion/extensions/rescale_cfg.py diff --git a/invokeai/app/invocations/denoise_latents.py b/invokeai/app/invocations/denoise_latents.py index ccacc3303cf..68deb68445c 100644 --- a/invokeai/app/invocations/denoise_latents.py +++ b/invokeai/app/invocations/denoise_latents.py @@ -59,6 +59,7 @@ from invokeai.backend.stable_diffusion.diffusion_backend import StableDiffusionBackend from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType from invokeai.backend.stable_diffusion.extensions.preview import PreviewExt +from invokeai.backend.stable_diffusion.extensions.rescale_cfg import RescaleCFGExt from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager from invokeai.backend.stable_diffusion.schedulers import SCHEDULER_MAP from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES @@ -790,6 +791,10 @@ def step_callback(state: PipelineIntermediateState) -> None: ext_manager.add_extension(PreviewExt(step_callback)) + ### cfg rescale + if self.cfg_rescale_multiplier > 0: + ext_manager.add_extension(RescaleCFGExt(self.cfg_rescale_multiplier)) + # ext: t2i/ip adapter ext_manager.run_callback(ExtensionCallbackType.SETUP, denoise_ctx) diff --git a/invokeai/backend/stable_diffusion/diffusion_backend.py b/invokeai/backend/stable_diffusion/diffusion_backend.py index 806deb5e03b..d41aa63c606 100644 --- a/invokeai/backend/stable_diffusion/diffusion_backend.py +++ b/invokeai/backend/stable_diffusion/diffusion_backend.py @@ -76,12 +76,12 @@ def step(self, ctx: DenoiseContext, ext_manager: ExtensionsManager) -> Scheduler both_noise_pred = self.run_unet(ctx, ext_manager, ConditioningMode.Both) ctx.negative_noise_pred, ctx.positive_noise_pred = both_noise_pred.chunk(2) - # ext: override apply_cfg - ctx.noise_pred = self.apply_cfg(ctx) + # ext: override combine_noise_preds + ctx.noise_pred = self.combine_noise_preds(ctx) # ext: cfg_rescale [modify_noise_prediction] # TODO: rename - ext_manager.run_callback(ExtensionCallbackType.POST_APPLY_CFG, ctx) + ext_manager.run_callback(ExtensionCallbackType.POST_COMBINE_NOISE_PREDS, ctx) # compute the previous noisy sample x_t -> x_t-1 step_output = ctx.scheduler.step(ctx.noise_pred, ctx.timestep, ctx.latents, **ctx.inputs.scheduler_step_kwargs) @@ -95,7 +95,7 @@ def step(self, ctx: DenoiseContext, ext_manager: ExtensionsManager) -> Scheduler return step_output @staticmethod - def apply_cfg(ctx: DenoiseContext) -> torch.Tensor: + def combine_noise_preds(ctx: DenoiseContext) -> torch.Tensor: guidance_scale = ctx.inputs.conditioning_data.guidance_scale if isinstance(guidance_scale, list): guidance_scale = guidance_scale[ctx.step_index] diff --git a/invokeai/backend/stable_diffusion/extension_callback_type.py b/invokeai/backend/stable_diffusion/extension_callback_type.py index aaefbd7ed04..e4c365007ba 100644 --- a/invokeai/backend/stable_diffusion/extension_callback_type.py +++ b/invokeai/backend/stable_diffusion/extension_callback_type.py @@ -9,4 +9,4 @@ class ExtensionCallbackType(Enum): POST_STEP = "post_step" PRE_UNET = "pre_unet" POST_UNET = "post_unet" - POST_APPLY_CFG = "post_apply_cfg" + POST_COMBINE_NOISE_PREDS = "post_combine_noise_preds" diff --git a/invokeai/backend/stable_diffusion/extensions/rescale_cfg.py b/invokeai/backend/stable_diffusion/extensions/rescale_cfg.py new file mode 100644 index 00000000000..51fad975e78 --- /dev/null +++ b/invokeai/backend/stable_diffusion/extensions/rescale_cfg.py @@ -0,0 +1,36 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import torch + +from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType +from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback + +if TYPE_CHECKING: + from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext + + +class RescaleCFGExt(ExtensionBase): + def __init__(self, rescale_multiplier: float): + super().__init__() + self.rescale_multiplier = rescale_multiplier + + @staticmethod + def _rescale_cfg(total_noise_pred: torch.Tensor, pos_noise_pred: torch.Tensor, multiplier: float = 0.7): + """Implementation of Algorithm 2 from https://arxiv.org/pdf/2305.08891.pdf.""" + ro_pos = torch.std(pos_noise_pred, dim=(1, 2, 3), keepdim=True) + ro_cfg = torch.std(total_noise_pred, dim=(1, 2, 3), keepdim=True) + + x_rescaled = total_noise_pred * (ro_pos / ro_cfg) + x_final = multiplier * x_rescaled + (1.0 - multiplier) * total_noise_pred + return x_final + + @callback(ExtensionCallbackType.POST_COMBINE_NOISE_PREDS) + def rescale_noise_pred(self, ctx: DenoiseContext): + if self.rescale_multiplier > 0: + ctx.noise_pred = self._rescale_cfg( + ctx.noise_pred, + ctx.positive_noise_pred, + self.rescale_multiplier, + ) From 1b359b55cb57c7ad9de3a5a026f393f3f82fc7e2 Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Mon, 22 Jul 2024 22:17:29 +0300 Subject: [PATCH 2/2] Suggested changes Co-Authored-By: Ryan Dick <14897797+RyanJDick@users.noreply.github.com> --- .../stable_diffusion/denoise_context.py | 20 +++++++++---------- .../extensions/rescale_cfg.py | 6 +++--- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/invokeai/backend/stable_diffusion/denoise_context.py b/invokeai/backend/stable_diffusion/denoise_context.py index 2b43d3fb0f2..9060d549776 100644 --- a/invokeai/backend/stable_diffusion/denoise_context.py +++ b/invokeai/backend/stable_diffusion/denoise_context.py @@ -83,47 +83,47 @@ class DenoiseContext: unet: Optional[UNet2DConditionModel] = None # Current state of latent-space image in denoising process. - # None until `pre_denoise_loop` callback. + # None until `PRE_DENOISE_LOOP` callback. # Shape: [batch, channels, latent_height, latent_width] latents: Optional[torch.Tensor] = None # Current denoising step index. - # None until `pre_step` callback. + # None until `PRE_STEP` callback. step_index: Optional[int] = None # Current denoising step timestep. - # None until `pre_step` callback. + # None until `PRE_STEP` callback. timestep: Optional[torch.Tensor] = None # Arguments which will be passed to UNet model. - # Available in `pre_unet`/`post_unet` callbacks, otherwise will be None. + # Available in `PRE_UNET`/`POST_UNET` callbacks, otherwise will be None. unet_kwargs: Optional[UNetKwargs] = None # SchedulerOutput class returned from step function(normally, generated by scheduler). - # Supposed to be used only in `post_step` callback, otherwise can be None. + # Supposed to be used only in `POST_STEP` callback, otherwise can be None. step_output: Optional[SchedulerOutput] = None # Scaled version of `latents`, which will be passed to unet_kwargs initialization. - # Available in events inside step(between `pre_step` and `post_stop`). + # Available in events inside step(between `PRE_STEP` and `POST_STEP`). # Shape: [batch, channels, latent_height, latent_width] latent_model_input: Optional[torch.Tensor] = None # [TMP] Defines on which conditionings current unet call will be runned. - # Available in `pre_unet`/`post_unet` callbacks, otherwise will be None. + # Available in `PRE_UNET`/`POST_UNET` callbacks, otherwise will be None. conditioning_mode: Optional[ConditioningMode] = None # [TMP] Noise predictions from negative conditioning. - # Available in `apply_cfg` and `post_apply_cfg` callbacks, otherwise will be None. + # Available in `POST_COMBINE_NOISE_PREDS` callback, otherwise will be None. # Shape: [batch, channels, latent_height, latent_width] negative_noise_pred: Optional[torch.Tensor] = None # [TMP] Noise predictions from positive conditioning. - # Available in `apply_cfg` and `post_apply_cfg` callbacks, otherwise will be None. + # Available in `POST_COMBINE_NOISE_PREDS` callback, otherwise will be None. # Shape: [batch, channels, latent_height, latent_width] positive_noise_pred: Optional[torch.Tensor] = None # Combined noise prediction from passed conditionings. - # Available in `apply_cfg` and `post_apply_cfg` callbacks, otherwise will be None. + # Available in `POST_COMBINE_NOISE_PREDS` callback, otherwise will be None. # Shape: [batch, channels, latent_height, latent_width] noise_pred: Optional[torch.Tensor] = None diff --git a/invokeai/backend/stable_diffusion/extensions/rescale_cfg.py b/invokeai/backend/stable_diffusion/extensions/rescale_cfg.py index 51fad975e78..7cccbb8a2bc 100644 --- a/invokeai/backend/stable_diffusion/extensions/rescale_cfg.py +++ b/invokeai/backend/stable_diffusion/extensions/rescale_cfg.py @@ -14,7 +14,7 @@ class RescaleCFGExt(ExtensionBase): def __init__(self, rescale_multiplier: float): super().__init__() - self.rescale_multiplier = rescale_multiplier + self._rescale_multiplier = rescale_multiplier @staticmethod def _rescale_cfg(total_noise_pred: torch.Tensor, pos_noise_pred: torch.Tensor, multiplier: float = 0.7): @@ -28,9 +28,9 @@ def _rescale_cfg(total_noise_pred: torch.Tensor, pos_noise_pred: torch.Tensor, m @callback(ExtensionCallbackType.POST_COMBINE_NOISE_PREDS) def rescale_noise_pred(self, ctx: DenoiseContext): - if self.rescale_multiplier > 0: + if self._rescale_multiplier > 0: ctx.noise_pred = self._rescale_cfg( ctx.noise_pred, ctx.positive_noise_pred, - self.rescale_multiplier, + self._rescale_multiplier, )