# Partial Loading PR4: Enable partial loading (behind config flag) (#7505)
## Summary

This PR adds support for partial loading of models onto the GPU. This
enables models to run with much lower peak VRAM requirements (e.g. full
FLUX dev with 8GB of VRAM).

The partial loading feature is enabled behind a new config flag:
`enable_partial_loading=True`. This flag defaults to `False`.

**Note about performance:**
The `ram` and `vram` config limits are still applied when
`enable_partial_loading=True` is set. This can result in significant
slowdowns compared to the 'old' behaviour. Consider the case where the
VRAM limit is set to `vram=0.75` (GB) and we are trying to run an 8GB
model. When `enable_partial_loading=False`, we attempt to load the
entire model into VRAM, and if it fits (no OOM error) then it will run
at full speed. When `enable_partial_loading=True`, since we have the
option to partially load the model we will only load 0.75 GB into VRAM
and leave the remaining 7.25 GB in RAM. This will cause inference to be
much slower than before. To work around this, it is important that your
`ram` and `vram` configs are carefully tuned. In a future PR, we will
add the ability to dynamically set the RAM/VRAM limits based on the
available memory / VRAM.
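
For illustration, here is a minimal sketch of tuning these settings programmatically. The GB values are hypothetical and must be adjusted to your hardware; most users would set the equivalent keys in their `invokeai.yaml` instead of constructing the config in code:

```python
# Hypothetical tuning example -- the GB values below are illustrative only.
from invokeai.app.services.config.config_default import InvokeAIAppConfig

config = InvokeAIAppConfig(
    enable_partial_loading=True,  # new flag added in this PR; defaults to False
    ram=12.0,   # RAM model cache limit (GB) -- tune to your system memory
    vram=6.0,   # VRAM model cache limit (GB) -- tune to your GPU
)
print(config.enable_partial_loading, config.ram, config.vram)
```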

## Related Issues / Discussions

- #7492 
- #7494 
- #7500

## QA Instructions

Tests with `enable_partial_loading=True`, `vram=2`, on a CUDA device:
For all tests, we expect model memory to stay below 2 GB. Peak working
memory will be higher.
- [x] SD1 inference
- [x] SDXL inference
- [x] FLUX non-quantized inference
- [x] FLUX GGML-quantized inference
- [x] FLUX BnB quantized inference
- [x] Variety of ControlNet / IP-Adapter / LoRA smoke tests

Tests with `enable_partial_loading=True`, and a hack that forces all models
to load only 10% of their weights onto the GPU, on a CUDA device:
- [x] SD1 inference
- [x] SDXL inference
- [x] FLUX non-quantized inference
- [x] FLUX GGML-quantized inference
- [x] FLUX BnB quantized inference
- [x] Variety of ControlNet / IP-Adapter / LoRA smoke tests

Tests with `enable_partial_loading=False`, `vram=30`:
We expect no change in behaviour when `enable_partial_loading=False`.
- [x] SD1 inference
- [x] SDXL inference
- [x] FLUX non-quantized inference
- [x] FLUX GGML-quantized inference
- [x] FLUX BnB quantized inference
- [x] Variety of ControlNet / IP-Adapter / LoRA smoke tests

Other platforms:
- [x] No change in behavior on MPS, even if
`enable_partial_loading=True`.
- [x] No change in behavior on CPU-only systems, even if
`enable_partial_loading=True`.

## Merge Plan

- [x] Merge #7500 first, and change the target branch to main

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a
changelog_
- [x] _Tests added / updated (if applicable)_
- [x] _Documentation added / updated (if applicable)_
- [ ] _Updated `What's New` copy (if doing a release after this PR)_
RyanJDick authored Jan 7, 2025
commit 87fdcb7 (2 parents: 782ee7a + 6a9de1f)
Showing 23 changed files with 396 additions and 292 deletions.
2 changes: 2 additions & 0 deletions invokeai/app/invocations/compel.py
@@ -105,6 +105,7 @@ def _lora_loader() -> Iterator[Tuple[ModelPatchRaw, float]]:
textual_inversion_manager=ti_manager,
dtype_for_device_getter=TorchDevice.choose_torch_dtype,
truncate_long_prompts=False,
device=TorchDevice.choose_torch_device(),
)

conjunction = Compel.parse_prompt_string(self.prompt)
@@ -207,6 +208,7 @@ def _lora_loader() -> Iterator[Tuple[ModelPatchRaw, float]]:
truncate_long_prompts=False, # TODO:
returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED, # TODO: clip skip
requires_pooled=get_pooled,
device=TorchDevice.choose_torch_device(),
)

conjunction = Compel.parse_prompt_string(prompt)
2 changes: 1 addition & 1 deletion invokeai/app/invocations/latents_to_image.py
@@ -62,7 +62,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput:
with SeamlessExt.static_patch_model(vae_info.model, self.vae.seamless_axes), vae_info as vae:
context.util.signal_progress("Running VAE decoder")
assert isinstance(vae, (AutoencoderKL, AutoencoderTiny))
latents = latents.to(vae.device)
latents = latents.to(TorchDevice.choose_torch_device())
if self.fp32:
vae.to(dtype=torch.float32)

2 changes: 1 addition & 1 deletion invokeai/app/invocations/sd3_latents_to_image.py
@@ -49,7 +49,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput:
with SeamlessExt.static_patch_model(vae_info.model, self.vae.seamless_axes), vae_info as vae:
context.util.signal_progress("Running VAE")
assert isinstance(vae, (AutoencoderKL))
latents = latents.to(vae.device)
latents = latents.to(TorchDevice.choose_torch_device())

vae.disable_tiling()

5 changes: 3 additions & 2 deletions invokeai/app/invocations/sd3_text_encoder.py
@@ -21,6 +21,7 @@
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_CLIP_PREFIX
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, SD3ConditioningInfo
from invokeai.backend.util.devices import TorchDevice

# The SD3 T5 Max Sequence Length set based on the default in diffusers.
SD3_T5_MAX_SEQ_LEN = 256
@@ -120,7 +121,7 @@ def _t5_encode(self, context: InvocationContext, max_seq_len: int) -> torch.Tens
f" {max_seq_len} tokens: {removed_text}"
)

prompt_embeds = t5_text_encoder(text_input_ids.to(t5_text_encoder.device))[0]
prompt_embeds = t5_text_encoder(text_input_ids.to(TorchDevice.choose_torch_device()))[0]

assert isinstance(prompt_embeds, torch.Tensor)
return prompt_embeds
@@ -185,7 +186,7 @@ def _clip_encode(
f" {tokenizer_max_length} tokens: {removed_text}"
)
prompt_embeds = clip_text_encoder(
input_ids=text_input_ids.to(clip_text_encoder.device), output_hidden_states=True
input_ids=text_input_ids.to(TorchDevice.choose_torch_device()), output_hidden_states=True
)
pooled_prompt_embeds = prompt_embeds[0]
prompt_embeds = prompt_embeds.hidden_states[-2]
2 changes: 2 additions & 0 deletions invokeai/app/services/config/config_default.py
@@ -107,6 +107,7 @@ class InvokeAIAppConfig(BaseSettings):
vram: Amount of VRAM reserved for model storage (GB).
lazy_offload: Keep models in VRAM until their space is needed.
log_memory_usage: If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.
enable_partial_loading: Enable partial loading of models. This enables models to run with reduced VRAM requirements (at the cost of slower speed) by streaming the model from RAM to VRAM as it's used. Partial loading can cause models to run more slowly if they were previously being fully loaded into VRAM. If enabling this setting, make sure that your ram and vram cache limits are properly tuned.
device: Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.<br>Valid values: `auto`, `cpu`, `cuda`, `cuda:1`, `mps`
precision: Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.<br>Valid values: `auto`, `float16`, `bfloat16`, `float32`
sequential_guidance: Whether to calculate guidance in serial instead of in parallel, lowering memory requirements.
@@ -178,6 +179,7 @@ class InvokeAIAppConfig(BaseSettings):
vram: float = Field(default=DEFAULT_VRAM_CACHE, ge=0, description="Amount of VRAM reserved for model storage (GB).")
lazy_offload: bool = Field(default=True, description="Keep models in VRAM until their space is needed.")
log_memory_usage: bool = Field(default=False, description="If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.")
enable_partial_loading: bool = Field(default=False, description="Enable partial loading of models. This enables models to run with reduced VRAM requirements (at the cost of slower speed) by streaming the model from RAM to VRAM as it's used. Partial loading can cause models to run more slowly if they were previously being fully loaded into VRAM. If enabling this setting, make sure that your ram and vram cache limits are properly tuned.")

# DEVICE
device: DEVICE = Field(default="auto", description="Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.")
6 changes: 3 additions & 3 deletions invokeai/app/services/model_manager/model_manager_default.py
@@ -82,9 +82,9 @@ def build_model_manager(
logger.setLevel(app_config.log_level.upper())

ram_cache = ModelCache(
max_cache_size=app_config.ram,
max_vram_cache_size=app_config.vram,
lazy_offloading=app_config.lazy_offload,
max_ram_cache_size_gb=app_config.ram,
max_vram_cache_size_gb=app_config.vram,
enable_partial_loading=app_config.enable_partial_loading,
logger=logger,
execution_device=execution_device or TorchDevice.choose_torch_device(),
)
4 changes: 3 additions & 1 deletion invokeai/backend/flux/modules/conditioner.py
@@ -3,6 +3,8 @@
from torch import Tensor, nn
from transformers import PreTrainedModel, PreTrainedTokenizer

from invokeai.backend.util.devices import TorchDevice


class HFEncoder(nn.Module):
def __init__(self, encoder: PreTrainedModel, tokenizer: PreTrainedTokenizer, is_clip: bool, max_length: int):
@@ -26,7 +28,7 @@ def forward(self, text: list[str]) -> Tensor:
)

outputs = self.hf_module(
input_ids=batch_encoding["input_ids"].to(self.hf_module.device),
input_ids=batch_encoding["input_ids"].to(TorchDevice.choose_torch_device()),
attention_mask=None,
output_hidden_states=False,
)
5 changes: 3 additions & 2 deletions invokeai/backend/image_util/hed.py
@@ -18,6 +18,7 @@
resize_image_to_resolution,
safe_step,
)
from invokeai.backend.model_manager.load.model_cache.utils import get_effective_device


class DoubleConvBlock(torch.nn.Module):
@@ -109,7 +110,7 @@ def run(
Returns:
The detected edges.
"""
device = next(iter(self.network.parameters())).device
device = get_effective_device(self.network)
np_image = pil_to_np(input_image)
np_image = normalize_image_channel_count(np_image)
np_image = resize_image_to_resolution(np_image, detect_resolution)
@@ -183,7 +184,7 @@ def run(self, image: Image.Image, safe: bool = False, scribble: bool = False) ->
The detected edges.
"""

device = next(iter(self.model.parameters())).device
device = get_effective_device(self.model)

np_image = pil_to_np(image)

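This and the following image-util diffs replace `next(iter(model.parameters())).device` with `get_effective_device(model)`. With partial loading, a model's weights may be split across devices, so the first parameter's device is no longer a reliable choice. The real helper lives in `invokeai.backend.model_manager.load.model_cache.utils`; the version below is only a plausible sketch of the idea, not the actual implementation:

```python
# Plausible sketch only -- not the actual InvokeAI implementation.
# Idea: with partial loading, parameters can live on more than one device, so
# prefer any non-CPU parameter's device and fall back to CPU.
import torch


def get_effective_device_sketch(model: torch.nn.Module) -> torch.device:
    for p in model.parameters():
        if p.device.type != "cpu":
            return p.device
    return torch.device("cpu")
```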
3 changes: 2 additions & 1 deletion invokeai/backend/image_util/infill_methods/lama.py
@@ -7,6 +7,7 @@

import invokeai.backend.util.logging as logger
from invokeai.backend.model_manager.config import AnyModel
from invokeai.backend.model_manager.load.model_cache.utils import get_effective_device


def norm_img(np_img):
@@ -31,7 +32,7 @@ def __call__(self, input_image: Image.Image, *args: Any, **kwds: Any) -> Any:
mask = norm_img(mask)
mask = (mask > 0) * 1

device = next(self._model.buffers()).device
device = get_effective_device(self._model)
image = torch.from_numpy(image).unsqueeze(0).to(device)
mask = torch.from_numpy(mask).unsqueeze(0).to(device)

5 changes: 3 additions & 2 deletions invokeai/backend/image_util/lineart.py
@@ -17,6 +17,7 @@
pil_to_np,
resize_image_to_resolution,
)
from invokeai.backend.model_manager.load.model_cache.utils import get_effective_device


class ResidualBlock(nn.Module):
@@ -130,7 +131,7 @@ def run(
Returns:
The detected lineart.
"""
device = next(iter(self.model.parameters())).device
device = get_effective_device(self.model)

np_image = pil_to_np(input_image)
np_image = normalize_image_channel_count(np_image)
@@ -201,7 +202,7 @@ def run(self, image: Image.Image) -> Image.Image:
Returns:
The detected edges.
"""
device = next(iter(self.model.parameters())).device
device = get_effective_device(self.model)

np_image = pil_to_np(image)

5 changes: 3 additions & 2 deletions invokeai/backend/image_util/lineart_anime.py
@@ -19,6 +19,7 @@
pil_to_np,
resize_image_to_resolution,
)
from invokeai.backend.model_manager.load.model_cache.utils import get_effective_device


class UnetGenerator(nn.Module):
@@ -171,7 +172,7 @@ def run(self, input_image: Image.Image, detect_resolution: int = 512, image_reso
Returns:
The detected lineart.
"""
device = next(iter(self.model.parameters())).device
device = get_effective_device(self.model)
np_image = pil_to_np(input_image)

np_image = normalize_image_channel_count(np_image)
@@ -239,7 +240,7 @@ def to(self, device: torch.device):

def run(self, image: Image.Image) -> Image.Image:
"""Processes an image and returns the detected edges."""
device = next(iter(self.model.parameters())).device
device = get_effective_device(self.model)

np_image = pil_to_np(image)

6 changes: 4 additions & 2 deletions invokeai/backend/image_util/mlsd/utils.py
@@ -14,6 +14,8 @@
import torch
from torch.nn import functional as F

from invokeai.backend.model_manager.load.model_cache.utils import get_effective_device


def deccode_output_score_and_ptss(tpMap, topk_n = 200, ksize = 5):
'''
@@ -49,7 +51,7 @@ def pred_lines(image, model,
dist_thr=20.0):
h, w, _ = image.shape

device = next(iter(model.parameters())).device
device = get_effective_device(model)
h_ratio, w_ratio = [h / input_shape[0], w / input_shape[1]]

resized_image = np.concatenate([cv2.resize(image, (input_shape[1], input_shape[0]), interpolation=cv2.INTER_AREA),
@@ -108,7 +110,7 @@ def pred_squares(image,
'''
h, w, _ = image.shape
original_shape = [h, w]
device = next(iter(model.parameters())).device
device = get_effective_device(model)

resized_image = np.concatenate([cv2.resize(image, (input_shape[0], input_shape[1]), interpolation=cv2.INTER_AREA),
np.ones([input_shape[0], input_shape[1], 1])], axis=-1)
3 changes: 2 additions & 1 deletion invokeai/backend/image_util/normal_bae/__init__.py
@@ -13,6 +13,7 @@

from invokeai.backend.image_util.normal_bae.nets.NNET import NNET
from invokeai.backend.image_util.util import np_to_pil, pil_to_np, resize_to_multiple
from invokeai.backend.model_manager.load.model_cache.utils import get_effective_device


class NormalMapDetector:
@@ -64,7 +65,7 @@ def to(self, device: torch.device):
def run(self, image: Image.Image):
"""Processes an image and returns the detected normal map."""

device = next(iter(self.model.parameters())).device
device = get_effective_device(self.model)
np_image = pil_to_np(image)

height, width, _channels = np_image.shape
3 changes: 2 additions & 1 deletion invokeai/backend/image_util/pidi/__init__.py
@@ -11,6 +11,7 @@

from invokeai.backend.image_util.pidi.model import PiDiNet, pidinet
from invokeai.backend.image_util.util import nms, normalize_image_channel_count, np_to_pil, pil_to_np, safe_step
from invokeai.backend.model_manager.load.model_cache.utils import get_effective_device


class PIDINetDetector:
@@ -45,7 +46,7 @@ def run(
) -> Image.Image:
"""Processes an image and returns the detected edges."""

device = next(iter(self.model.parameters())).device
device = get_effective_device(self.model)

np_img = pil_to_np(image)
np_img = normalize_image_channel_count(np_img)
4 changes: 2 additions & 2 deletions invokeai/backend/model_manager/load/load_base.py
@@ -68,14 +68,14 @@ def model_on_device(self) -> Generator[Tuple[Optional[Dict[str, torch.Tensor]],
"""Return a tuple consisting of the model's state dict (if it exists) and the locked model on execution device."""
self._cache.lock(self._cache_record)
try:
yield (self._cache_record.state_dict, self._cache_record.model)
yield (self._cache_record.cached_model.get_cpu_state_dict(), self._cache_record.cached_model.model)
finally:
self._cache.unlock(self._cache_record)

@property
def model(self) -> AnyModel:
"""Return the model without locking it."""
return self._cache_record.model
return self._cache_record.cached_model.model


class LoadedModel(LoadedModelWithoutConfig):
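With this change, `model_on_device()` yields the read-only CPU state-dict copy from the cached model rather than from the cache record directly; as the `CacheRecord` docstring removed in the next file explains, callers should patch the on-device copy inside this context rather than the RAM template. A hypothetical usage sketch (names are placeholders, not real InvokeAI objects):

```python
# Hypothetical usage sketch of the model_on_device() context manager.
import torch


def run_locked(loaded_model, inputs: torch.Tensor) -> torch.Tensor:
    # model_on_device() locks the model in the cache and yields the read-only
    # CPU state-dict copy (may be None) plus the model on the execution device.
    with loaded_model.model_on_device() as (cpu_state_dict, model):
        return model(inputs)
```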
39 changes: 11 additions & 28 deletions invokeai/backend/model_manager/load/model_cache/cache_record.py
@@ -1,38 +1,21 @@
from dataclasses import dataclass
from typing import Any, Dict, Optional

import torch
from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_only_full_load import (
CachedModelOnlyFullLoad,
)
from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_with_partial_load import (
CachedModelWithPartialLoad,
)


@dataclass
class CacheRecord:
"""
Elements of the cache:
key: Unique key for each model, same as used in the models database.
model: Model in memory.
state_dict: A read-only copy of the model's state dict in RAM. It will be
used as a template for creating a copy in the VRAM.
size: Size of the model
loaded: True if the model's state dict is currently in VRAM
Before a model is executed, the state_dict template is copied into VRAM,
and then injected into the model. When the model is finished, the VRAM
copy of the state dict is deleted, and the RAM version is reinjected
into the model.
The state_dict should be treated as a read-only attribute. Do not attempt
to patch or otherwise modify it. Instead, patch the copy of the state_dict
after it is loaded into the execution device (e.g. CUDA) using the `LoadedModel`
context manager call `model_on_device()`.
"""
"""A class that represents a model in the model cache."""

# Cache key.
key: str
model: Any
device: torch.device
state_dict: Optional[Dict[str, torch.Tensor]]
size: int
loaded: bool = False
# Model in memory.
cached_model: CachedModelWithPartialLoad | CachedModelOnlyFullLoad
_locks: int = 0

def lock(self) -> None:
@@ -45,6 +28,6 @@ def unlock(self) -> None:
assert self._locks >= 0

@property
def locked(self) -> bool:
def is_locked(self) -> bool:
"""Return true if record is locked."""
return self._locks > 0
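
The lock bookkeeping in `CacheRecord` is a simple reference count. The `lock()` body is collapsed in the diff above, so the sketch below is an assumption based on the visible `unlock()` assertion and `is_locked` property, not the exact committed code:

```python
# Assumed shape of the reference-counted locking in CacheRecord; the real
# method bodies are partially collapsed in the diff above.
from dataclasses import dataclass


@dataclass
class _CacheRecordSketch:
    key: str
    _locks: int = 0

    def lock(self) -> None:
        self._locks += 1

    def unlock(self) -> None:
        self._locks -= 1
        assert self._locks >= 0

    @property
    def is_locked(self) -> bool:
        return self._locks > 0
```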