
Commit

We should not trust the value of model.device since the model could be partially-loaded.
RyanJDick committed Jan 7, 2025
1 parent 6b18f27 commit 5d36c1c
Showing 9 changed files with 36 additions and 24 deletions.
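The change applies one pattern across all of the touched files: instead of reading the target device from a model attribute (unet.device, vae.device, control_model.device, and so on), which may report the wrong device while the model is only partially loaded, the code chooses the execution device explicitly with TorchDevice.choose_torch_device() (or accepts it as a device parameter) and moves input tensors there. A minimal sketch of the idea, assuming a model with parameters; the helper function below is illustrative and not part of the commit:

    import torch

    from invokeai.backend.util.devices import TorchDevice


    def move_inputs_to_execution_device(model: torch.nn.Module, latents: torch.Tensor) -> torch.Tensor:
        """Illustrative helper (not part of the commit) showing the device-selection pattern."""
        # Old pattern: trust the model to report where it lives, e.g.
        #     latents = latents.to(device=model.device, dtype=model.dtype)
        # If the model has only been partially loaded onto the GPU, the device it
        # reports can be stale or wrong, and the inputs land on the wrong device.

        # New pattern: pick the execution device once, independent of the model's
        # current (possibly partial) placement, and move inputs there explicitly.
        device = TorchDevice.choose_torch_device()
        dtype = next(model.parameters()).dtype
        return latents.to(device=device, dtype=dtype)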
21 changes: 13 additions & 8 deletions invokeai/app/invocations/denoise_latents.py
@@ -411,6 +411,7 @@ def prep_control_data(
context: InvocationContext,
control_input: ControlField | list[ControlField] | None,
latents_shape: List[int],
+ device: torch.device,
exit_stack: ExitStack,
do_classifier_free_guidance: bool = True,
) -> list[ControlNetData] | None:
@@ -452,7 +453,7 @@ def prep_control_data(
height=control_height_resize,
# batch_size=batch_size * num_images_per_prompt,
# num_images_per_prompt=num_images_per_prompt,
- device=control_model.device,
+ device=device,
dtype=control_model.dtype,
control_mode=control_info.control_mode,
resize_mode=control_info.resize_mode,
@@ -605,6 +606,7 @@ def run_t2i_adapters(
context: InvocationContext,
t2i_adapter: Optional[Union[T2IAdapterField, list[T2IAdapterField]]],
latents_shape: list[int],
+ device: torch.device,
do_classifier_free_guidance: bool,
) -> Optional[list[T2IAdapterData]]:
if t2i_adapter is None:
@@ -655,7 +657,7 @@
width=control_width_resize,
height=control_height_resize,
num_channels=t2i_adapter_model.config["in_channels"], # mypy treats this as a FrozenDict
- device=t2i_adapter_model.device,
+ device=device,
dtype=t2i_adapter_model.dtype,
resize_mode=t2i_adapter_field.resize_mode,
)
@@ -946,6 +948,7 @@ def step_callback(state: PipelineIntermediateState) -> None:
@torch.no_grad()
@SilenceWarnings() # This quenches the NSFW nag from diffusers.
def _old_invoke(self, context: InvocationContext) -> LatentsOutput:
+ device = TorchDevice.choose_torch_device()
seed, noise, latents = self.prepare_noise_and_latents(context, self.noise, self.latents)

mask, masked_latents, gradient_mask = self.prep_inpaint_mask(context, latents)
@@ -960,6 +963,7 @@ def _old_invoke(self, context: InvocationContext) -> LatentsOutput:
context,
self.t2i_adapter,
latents.shape,
+ device=device,
do_classifier_free_guidance=True,
)

@@ -1006,13 +1010,13 @@ def _lora_loader() -> Iterator[Tuple[ModelPatchRaw, float]]:
),
):
assert isinstance(unet, UNet2DConditionModel)
- latents = latents.to(device=unet.device, dtype=unet.dtype)
+ latents = latents.to(device=device, dtype=unet.dtype)
if noise is not None:
- noise = noise.to(device=unet.device, dtype=unet.dtype)
+ noise = noise.to(device=device, dtype=unet.dtype)
if mask is not None:
- mask = mask.to(device=unet.device, dtype=unet.dtype)
+ mask = mask.to(device=device, dtype=unet.dtype)
if masked_latents is not None:
- masked_latents = masked_latents.to(device=unet.device, dtype=unet.dtype)
+ masked_latents = masked_latents.to(device=device, dtype=unet.dtype)

scheduler = get_scheduler(
context=context,
@@ -1028,7 +1032,7 @@ def _lora_loader() -> Iterator[Tuple[ModelPatchRaw, float]]:
context=context,
positive_conditioning_field=self.positive_conditioning,
negative_conditioning_field=self.negative_conditioning,
- device=unet.device,
+ device=device,
dtype=unet.dtype,
latent_height=latent_height,
latent_width=latent_width,
@@ -1041,6 +1045,7 @@ def _lora_loader() -> Iterator[Tuple[ModelPatchRaw, float]]:
context=context,
control_input=self.control,
latents_shape=latents.shape,
+ device=device,
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
do_classifier_free_guidance=True,
exit_stack=exit_stack,
@@ -1058,7 +1063,7 @@ def _lora_loader() -> Iterator[Tuple[ModelPatchRaw, float]]:

timesteps, init_timestep, scheduler_step_kwargs = self.init_scheduler(
scheduler,
- device=unet.device,
+ device=device,
steps=self.steps,
denoising_start=self.denoising_start,
denoising_end=self.denoising_end,
7 changes: 4 additions & 3 deletions invokeai/app/invocations/flux_denoise.py
@@ -276,7 +276,7 @@ def _run_diffusion(
# TODO(ryand): We should really do this in a separate invocation to benefit from caching.
ip_adapter_fields = self._normalize_ip_adapter_fields()
pos_image_prompt_clip_embeds, neg_image_prompt_clip_embeds = self._prep_ip_adapter_image_prompt_clip_embeds(
- ip_adapter_fields, context
+ ip_adapter_fields, context, device=x.device
)

cfg_scale = self.prep_cfg_scale(
@@ -626,6 +626,7 @@ def _prep_ip_adapter_image_prompt_clip_embeds(
self,
ip_adapter_fields: list[IPAdapterField],
context: InvocationContext,
+ device: torch.device,
) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
"""Run the IPAdapter CLIPVisionModel, returning image prompt embeddings."""
clip_image_processor = CLIPImageProcessor()
@@ -665,11 +666,11 @@ def _prep_ip_adapter_image_prompt_clip_embeds(
assert isinstance(image_encoder_model, CLIPVisionModelWithProjection)

clip_image: torch.Tensor = clip_image_processor(images=pos_images, return_tensors="pt").pixel_values
- clip_image = clip_image.to(device=image_encoder_model.device, dtype=image_encoder_model.dtype)
+ clip_image = clip_image.to(device=device, dtype=image_encoder_model.dtype)
pos_clip_image_embeds = image_encoder_model(clip_image).image_embeds

clip_image = clip_image_processor(images=neg_images, return_tensors="pt").pixel_values
- clip_image = clip_image.to(device=image_encoder_model.device, dtype=image_encoder_model.dtype)
+ clip_image = clip_image.to(device=device, dtype=image_encoder_model.dtype)
neg_clip_image_embeds = image_encoder_model(clip_image).image_embeds

pos_image_prompt_clip_embeds.append(pos_clip_image_embeds)
3 changes: 2 additions & 1 deletion invokeai/app/invocations/image_to_latents.py
@@ -26,6 +26,7 @@
from invokeai.backend.model_manager import LoadedModel
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
from invokeai.backend.stable_diffusion.vae_tiling import patch_vae_tiling_params
+ from invokeai.backend.util.devices import TorchDevice


@invocation(
@@ -98,7 +99,7 @@ def vae_encode(
)

# non_noised_latents_from_image
- image_tensor = image_tensor.to(device=vae.device, dtype=vae.dtype)
+ image_tensor = image_tensor.to(device=TorchDevice.choose_torch_device(), dtype=vae.dtype)
with torch.inference_mode(), tiling_context:
latents = ImageToLatentsInvocation._encode_to_tensor(vae, image_tensor)

3 changes: 2 additions & 1 deletion invokeai/app/invocations/sd3_image_to_latents.py
@@ -16,6 +16,7 @@
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.load.load_base import LoadedModel
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
+ from invokeai.backend.util.devices import TorchDevice


@invocation(
@@ -39,7 +40,7 @@ def vae_encode(vae_info: LoadedModel, image_tensor: torch.Tensor) -> torch.Tensor

vae.disable_tiling()

- image_tensor = image_tensor.to(device=vae.device, dtype=vae.dtype)
+ image_tensor = image_tensor.to(device=TorchDevice.choose_torch_device(), dtype=vae.dtype)
with torch.inference_mode():
image_tensor_dist = vae.encode(image_tensor).latent_dist
# TODO: Use seed to make sampling reproducible.
7 changes: 3 additions & 4 deletions invokeai/app/invocations/spandrel_image_to_image.py
@@ -22,6 +22,7 @@
from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
from invokeai.backend.tiles.tiles import calc_tiles_min_overlap
from invokeai.backend.tiles.utils import TBLR, Tile
+ from invokeai.backend.util.devices import TorchDevice


@invocation("spandrel_image_to_image", title="Image-to-Image", tags=["upscale"], category="upscale", version="1.3.0")
@@ -102,7 +103,7 @@ def upscale_image(
(height * scale, width * scale, channels), dtype=torch.uint8, device=torch.device("cpu")
)

- image_tensor = image_tensor.to(device=spandrel_model.device, dtype=spandrel_model.dtype)
+ image_tensor = image_tensor.to(device=TorchDevice.choose_torch_device(), dtype=spandrel_model.dtype)

# Run the model on each tile.
pbar = tqdm(list(zip(tiles, scaled_tiles, strict=True)), desc="Upscaling Tiles")
@@ -116,9 +117,7 @@
raise CanceledException

# Extract the current tile from the input tensor.
- input_tile = image_tensor[
-     :, :, tile.coords.top : tile.coords.bottom, tile.coords.left : tile.coords.right
- ].to(device=spandrel_model.device, dtype=spandrel_model.dtype)
+ input_tile = image_tensor[:, :, tile.coords.top : tile.coords.bottom, tile.coords.left : tile.coords.right]

# Run the model on the tile.
output_tile = spandrel_model.run(input_tile)
Changes to an additional file (name not captured in this extract):
@@ -201,6 +201,7 @@ def _lora_loader() -> Iterator[Tuple[ModelPatchRaw, float]]:
yield (lora_info.model, lora.weight)
del lora_info

+ device = TorchDevice.choose_torch_device()
with (
ExitStack() as exit_stack,
context.models.load(self.unet.unet) as unet,
@@ -209,9 +210,9 @@
),
):
assert isinstance(unet, UNet2DConditionModel)
- latents = latents.to(device=unet.device, dtype=unet.dtype)
+ latents = latents.to(device=device, dtype=unet.dtype)
if noise is not None:
- noise = noise.to(device=unet.device, dtype=unet.dtype)
+ noise = noise.to(device=device, dtype=unet.dtype)
scheduler = get_scheduler(
context=context,
scheduler_info=self.unet.scheduler,
@@ -225,7 +226,7 @@ def _lora_loader() -> Iterator[Tuple[ModelPatchRaw, float]]:
context=context,
positive_conditioning_field=self.positive_conditioning,
negative_conditioning_field=self.negative_conditioning,
- device=unet.device,
+ device=device,
dtype=unet.dtype,
latent_height=latent_tile_height,
latent_width=latent_tile_width,
@@ -238,6 +239,7 @@ def _lora_loader() -> Iterator[Tuple[ModelPatchRaw, float]]:
context=context,
control_input=self.control,
latents_shape=list(latents.shape),
+ device=device,
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
do_classifier_free_guidance=True,
exit_stack=exit_stack,
@@ -263,7 +265,7 @@ def _lora_loader() -> Iterator[Tuple[ModelPatchRaw, float]]:

timesteps, init_timestep, scheduler_step_kwargs = DenoiseLatentsInvocation.init_scheduler(
scheduler,
- device=unet.device,
+ device=device,
steps=self.steps,
denoising_start=self.denoising_start,
denoising_end=self.denoising_end,
Changes to an additional file (name not captured in this extract):
@@ -8,6 +8,7 @@

from invokeai.backend.flux.ip_adapter.xlabs_ip_adapter_flux import XlabsIpAdapterFlux
from invokeai.backend.flux.modules.layers import DoubleStreamBlock
+ from invokeai.backend.util.devices import TorchDevice


class XLabsIPAdapterExtension:
@@ -45,7 +46,7 @@ def run_clip_image_encoder(
) -> torch.Tensor:
clip_image_processor = CLIPImageProcessor()
clip_image: torch.Tensor = clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
- clip_image = clip_image.to(device=image_encoder.device, dtype=image_encoder.dtype)
+ clip_image = clip_image.to(device=TorchDevice.choose_torch_device(), dtype=image_encoder.dtype)
clip_image_embeds = image_encoder(clip_image).image_embeds
return clip_image_embeds

3 changes: 2 additions & 1 deletion invokeai/backend/model_patcher.py
@@ -14,6 +14,7 @@
from invokeai.app.shared.models import FreeUConfig
from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
from invokeai.backend.textual_inversion import TextualInversionManager, TextualInversionModelRaw
+ from invokeai.backend.util.devices import TorchDevice


class ModelPatcher:
@@ -122,7 +123,7 @@ def _get_ti_embedding(model_embeddings: torch.nn.Module, ti: TextualInversionMod
)

model_embeddings.weight.data[token_id] = embedding.to(
- device=text_encoder.device, dtype=text_encoder.dtype
+ device=TorchDevice.choose_torch_device(), dtype=text_encoder.dtype
)
ti_tokens.append(token_id)

3 changes: 2 additions & 1 deletion invokeai/backend/stable_diffusion/extensions/t2i_adapter.py
@@ -12,6 +12,7 @@
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningMode
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase, callback
+ from invokeai.backend.util.devices import TorchDevice

if TYPE_CHECKING:
from invokeai.app.invocations.model import ModelIdentifierField
@@ -89,7 +90,7 @@ def _run_model(
width=input_width,
height=input_height,
num_channels=model.config["in_channels"],
- device=model.device,
+ device=TorchDevice.choose_torch_device(),
dtype=model.dtype,
resize_mode=self._resize_mode,
)
