From 535e45cedfc1884bdac0310dd29f170fcd4ce4f8 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Sun, 29 Dec 2024 18:43:00 +0000 Subject: [PATCH 01/10] First pass at adding partial loading support to the ModelCache. --- .../model_manager/model_manager_default.py | 5 +- .../backend/model_manager/load/load_base.py | 4 +- .../load/model_cache/cache_record.py | 39 +- .../load/model_cache/model_cache.py | 499 ++++++++++-------- .../model_manager/model_manager_fixtures.py | 4 +- 5 files changed, 297 insertions(+), 254 deletions(-) diff --git a/invokeai/app/services/model_manager/model_manager_default.py b/invokeai/app/services/model_manager/model_manager_default.py index a05456eb8a2..bdd1f5da437 100644 --- a/invokeai/app/services/model_manager/model_manager_default.py +++ b/invokeai/app/services/model_manager/model_manager_default.py @@ -82,9 +82,8 @@ def build_model_manager( logger.setLevel(app_config.log_level.upper()) ram_cache = ModelCache( - max_cache_size=app_config.ram, - max_vram_cache_size=app_config.vram, - lazy_offloading=app_config.lazy_offload, + max_ram_cache_size_gb=app_config.ram, + max_vram_cache_size_gb=app_config.vram, logger=logger, execution_device=execution_device or TorchDevice.choose_torch_device(), ) diff --git a/invokeai/backend/model_manager/load/load_base.py b/invokeai/backend/model_manager/load/load_base.py index d62db363a6d..1bf24edeed9 100644 --- a/invokeai/backend/model_manager/load/load_base.py +++ b/invokeai/backend/model_manager/load/load_base.py @@ -68,14 +68,14 @@ def model_on_device(self) -> Generator[Tuple[Optional[Dict[str, torch.Tensor]], """Return a tuple consisting of the model's state dict (if it exists) and the locked model on execution device.""" self._cache.lock(self._cache_record) try: - yield (self._cache_record.state_dict, self._cache_record.model) + yield (self._cache_record.cached_model.get_cpu_state_dict(), self._cache_record.cached_model.model) finally: self._cache.unlock(self._cache_record) @property def model(self) -> AnyModel: """Return the model without locking it.""" - return self._cache_record.model + return self._cache_record.cached_model.model class LoadedModel(LoadedModelWithoutConfig): diff --git a/invokeai/backend/model_manager/load/model_cache/cache_record.py b/invokeai/backend/model_manager/load/model_cache/cache_record.py index dfa8aeb3f2e..c48435d0ef0 100644 --- a/invokeai/backend/model_manager/load/model_cache/cache_record.py +++ b/invokeai/backend/model_manager/load/model_cache/cache_record.py @@ -1,38 +1,21 @@ from dataclasses import dataclass -from typing import Any, Dict, Optional -import torch +from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_only_full_load import ( + CachedModelOnlyFullLoad, +) +from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_with_partial_load import ( + CachedModelWithPartialLoad, +) @dataclass class CacheRecord: - """ - Elements of the cache: - - key: Unique key for each model, same as used in the models database. - model: Model in memory. - state_dict: A read-only copy of the model's state dict in RAM. It will be - used as a template for creating a copy in the VRAM. - size: Size of the model - loaded: True if the model's state dict is currently in VRAM - - Before a model is executed, the state_dict template is copied into VRAM, - and then injected into the model. When the model is finished, the VRAM - copy of the state dict is deleted, and the RAM version is reinjected - into the model. - - The state_dict should be treated as a read-only attribute. Do not attempt - to patch or otherwise modify it. Instead, patch the copy of the state_dict - after it is loaded into the execution device (e.g. CUDA) using the `LoadedModel` - context manager call `model_on_device()`. - """ + """A class that represents a model in the model cache.""" + # Cache key. key: str - model: Any - device: torch.device - state_dict: Optional[Dict[str, torch.Tensor]] - size: int - loaded: bool = False + # Model in memory. + cached_model: CachedModelWithPartialLoad | CachedModelOnlyFullLoad _locks: int = 0 def lock(self) -> None: @@ -45,6 +28,6 @@ def unlock(self) -> None: assert self._locks >= 0 @property - def locked(self) -> bool: + def is_locked(self) -> bool: """Return true if record is locked.""" return self._locks > 0 diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache.py b/invokeai/backend/model_manager/load/model_cache/model_cache.py index cd296aa7bd7..ecf3ffa6234 100644 --- a/invokeai/backend/model_manager/load/model_cache/model_cache.py +++ b/invokeai/backend/model_manager/load/model_cache/model_cache.py @@ -1,8 +1,5 @@ -# Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Development team -# TODO: Add Stalker's proper name to copyright - import gc -import math +import logging import time from logging import Logger from typing import Dict, List, Optional @@ -10,9 +7,15 @@ import torch from invokeai.backend.model_manager import AnyModel, SubModelType -from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff +from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot from invokeai.backend.model_manager.load.model_cache.cache_record import CacheRecord from invokeai.backend.model_manager.load.model_cache.cache_stats import CacheStats +from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_only_full_load import ( + CachedModelOnlyFullLoad, +) +from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_with_partial_load import ( + CachedModelWithPartialLoad, +) from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.torch_module_autocast import ( apply_custom_layers_to_model, ) @@ -29,6 +32,7 @@ # TODO(ryand): Where should this go? The ModelCache shouldn't be concerned with submodels. def get_model_cache_key(model_key: str, submodel_type: Optional[SubModelType] = None) -> str: + """Get the cache key for a model based on the optional submodel type.""" if submodel_type: return f"{model_key}:{submodel_type.value}" else: @@ -70,34 +74,35 @@ class ModelCache: def __init__( self, - max_cache_size: float, - max_vram_cache_size: float, - execution_device: torch.device = torch.device("cuda"), - storage_device: torch.device = torch.device("cpu"), - lazy_offloading: bool = True, + max_ram_cache_size_gb: float, + max_vram_cache_size_gb: float, + execution_device: torch.device | str = "cuda", + storage_device: torch.device | str = "cpu", log_memory_usage: bool = False, logger: Optional[Logger] = None, ): - """ - Initialize the model RAM cache. - - :param max_cache_size: Maximum size of the storage_device cache in GBs. - :param max_vram_cache_size: Maximum size of the execution_device cache in GBs. + """Initialize the model RAM cache. + + :param max_ram_cache_size_gb: The maximum amount of CPU RAM to use for model caching in GB. This parameter is + kept to maintain compatibility with previous versions of the model cache, but should be deprecated in the + future. If set, this parameter overrides the default cache size logic. + :param max_vram_cache_size_gb: The amount of VRAM to use for model caching in GB. This parameter is kept to + maintain compatibility with previous versions of the model cache, but should be deprecated in the future. + If set, this parameter overrides the default cache size logic. :param execution_device: Torch device to load active model into [torch.device('cuda')] :param storage_device: Torch device to save inactive model in [torch.device('cpu')] - :param lazy_offloading: Keep model in VRAM until another model needs to be loaded :param log_memory_usage: If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to disable this feature unless you are actively inspecting the model cache's behaviour. :param logger: InvokeAILogger to use (otherwise creates one) """ - # allow lazy offloading only when vram cache enabled - self._lazy_offloading = lazy_offloading and max_vram_cache_size > 0 - self._max_cache_size: float = max_cache_size - self._max_vram_cache_size: float = max_vram_cache_size - self._execution_device: torch.device = execution_device - self._storage_device: torch.device = storage_device + self._execution_device: torch.device = torch.device(execution_device) + self._storage_device: torch.device = torch.device(storage_device) + + self._max_ram_cache_size_gb = max_ram_cache_size_gb + self._max_vram_cache_size_gb = max_vram_cache_size_gb + self._logger = logger or InvokeAILogger.get_logger(self.__class__.__name__) self._log_memory_usage = log_memory_usage self._stats: Optional[CacheStats] = None @@ -105,26 +110,6 @@ def __init__( self._cached_models: Dict[str, CacheRecord] = {} self._cache_stack: List[str] = [] - @property - def max_cache_size(self) -> float: - """Return the cap on cache size.""" - return self._max_cache_size - - @max_cache_size.setter - def max_cache_size(self, value: float) -> None: - """Set the cap on cache size.""" - self._max_cache_size = value - - @property - def max_vram_cache_size(self) -> float: - """Return the cap on vram cache size.""" - return self._max_vram_cache_size - - @max_vram_cache_size.setter - def max_vram_cache_size(self, value: float) -> None: - """Set the cap on vram cache size.""" - self._max_vram_cache_size = value - @property def stats(self) -> Optional[CacheStats]: """Return collected CacheStats object.""" @@ -132,17 +117,17 @@ def stats(self) -> Optional[CacheStats]: @stats.setter def stats(self, stats: CacheStats) -> None: - """Set the CacheStats object for collectin cache statistics.""" + """Set the CacheStats object for collecting cache statistics.""" self._stats = stats - def put( - self, - key: str, - model: AnyModel, - ) -> None: - """Insert model into the cache.""" + def put(self, key: str, model: AnyModel) -> None: + """Add a model to the cache.""" if key in self._cached_models: + self._logger.debug( + f"Attempted to add model {key} ({model.__class__.__name__}), but it already exists in the cache. No action necessary." + ) return + size = calc_model_size_by_data(self._logger, model) self.make_room(size) @@ -150,17 +135,26 @@ def put( if isinstance(model, torch.nn.Module): apply_custom_layers_to_model(model) - running_on_cpu = self._execution_device == torch.device("cpu") - state_dict = model.state_dict() if isinstance(model, torch.nn.Module) and not running_on_cpu else None - cache_record = CacheRecord(key=key, model=model, device=self._storage_device, state_dict=state_dict, size=size) + # Partial loading only makes sense on CUDA. + # - When running on CPU, there is no 'loading' to do. + # - When running on MPS, memory is shared with the CPU, so the default OS memory management already handles this + # well. + running_with_cuda = self._execution_device.type == "cuda" + + # Wrap model. + if isinstance(model, torch.nn.Module) and running_with_cuda: + wrapped_model = CachedModelWithPartialLoad(model, self._execution_device) + else: + wrapped_model = CachedModelOnlyFullLoad(model, self._execution_device, size) + + cache_record = CacheRecord(key=key, cached_model=wrapped_model) self._cached_models[key] = cache_record self._cache_stack.append(key) + self._logger.debug( + f"Added model {key} (Type: {model.__class__.__name__}, Wrap mode: {wrapped_model.__class__.__name__}, Model size: {size/MB:.2f}MB)" + ) - def get( - self, - key: str, - stats_name: Optional[str] = None, - ) -> CacheRecord: + def get(self, key: str, stats_name: Optional[str] = None) -> CacheRecord: """Retrieve a model from the cache. :param key: Model key @@ -174,6 +168,7 @@ def get( else: if self.stats: self.stats.misses += 1 + self._logger.debug(f"Cache miss: {key}") raise IndexError(f"The model with key {key} is not in the cache.") cache_entry = self._cached_models[key] @@ -181,37 +176,44 @@ def get( # more stats if self.stats: stats_name = stats_name or key - self.stats.cache_size = int(self._max_cache_size * GB) - self.stats.high_watermark = max(self.stats.high_watermark, self._get_cache_size()) + self.stats.high_watermark = max(self.stats.high_watermark, self._get_ram_in_use()) self.stats.in_cache = len(self._cached_models) self.stats.loaded_model_sizes[stats_name] = max( - self.stats.loaded_model_sizes.get(stats_name, 0), cache_entry.size + self.stats.loaded_model_sizes.get(stats_name, 0), cache_entry.cached_model.total_bytes() ) - # this moves the entry to the top (right end) of the stack + # This moves the entry to the top (right end) of the stack. self._cache_stack = [k for k in self._cache_stack if k != key] self._cache_stack.append(key) + self._logger.debug(f"Cache hit: {key} (Type: {cache_entry.cached_model.model.__class__.__name__})") return cache_entry def lock(self, cache_entry: CacheRecord) -> None: """Lock a model for use and move it into VRAM.""" if cache_entry.key not in self._cached_models: self._logger.info( - f"Locking model cache entry {cache_entry.key} ({cache_entry.model.__class__.__name__}), but it has " - "already been dropped from the RAM cache. This is a sign that the model loading order is non-optimal " - "in the invocation code (See https://github.com/invoke-ai/InvokeAI/issues/7513)." + f"Locking model cache entry {cache_entry.key} " + f"(Type: {cache_entry.cached_model.model.__class__.__name__}), but it has already been dropped from " + "the RAM cache. This is a sign that the model loading order is non-optimal in the invocation code " + "(See https://github.com/invoke-ai/InvokeAI/issues/7513)." ) # cache_entry = self._cached_models[key] cache_entry.lock() + self._logger.debug( + f"Locking model {cache_entry.key} (Type: {cache_entry.cached_model.model.__class__.__name__})" + ) + + if self._execution_device.type == "cpu": + # Models don't need to be loaded into VRAM if we're running on CPU. + return + try: - if self._lazy_offloading: - self._offload_unlocked_models(cache_entry.size) - self._move_model_to_device(cache_entry, self._execution_device) - cache_entry.loaded = True - self._logger.debug(f"Locking {cache_entry.key} in {self._execution_device}") - self._print_cuda_stats() + self._load_locked_model(cache_entry) + self._logger.debug( + f"Finished locking model {cache_entry.key} (Type: {cache_entry.cached_model.model.__class__.__name__})" + ) except torch.cuda.OutOfMemoryError: self._logger.warning("Insufficient GPU memory to load model. Aborting") cache_entry.unlock() @@ -220,201 +222,258 @@ def lock(self, cache_entry: CacheRecord) -> None: cache_entry.unlock() raise + self._log_cache_state() + def unlock(self, cache_entry: CacheRecord) -> None: """Unlock a model.""" if cache_entry.key not in self._cached_models: self._logger.info( - f"Unlocking model cache entry {cache_entry.key} ({cache_entry.model.__class__.__name__}), but it has " - "already been dropped from the RAM cache. This is a sign that the model loading order is non-optimal " - "in the invocation code (See https://github.com/invoke-ai/InvokeAI/issues/7513)." + f"Unlocking model cache entry {cache_entry.key} " + f"(Type: {cache_entry.cached_model.model.__class__.__name__}), but it has already been dropped from " + "the RAM cache. This is a sign that the model loading order is non-optimal in the invocation code " + "(See https://github.com/invoke-ai/InvokeAI/issues/7513)." ) # cache_entry = self._cached_models[key] cache_entry.unlock() - if not self._lazy_offloading: - self._offload_unlocked_models(0) - self._print_cuda_stats() + self._logger.debug( + f"Unlocked model {cache_entry.key} (Type: {cache_entry.cached_model.model.__class__.__name__})" + ) + + def _load_locked_model(self, cache_entry: CacheRecord) -> None: + """Helper function for self.lock(). Loads a locked model into VRAM.""" + start_time = time.time() + vram_available = self._get_vram_available() + + # Calculate model_vram_needed, the amount of additional VRAM that will be used if we fully load the model into + # VRAM. + model_cur_vram_bytes = cache_entry.cached_model.cur_vram_bytes() + model_total_bytes = cache_entry.cached_model.total_bytes() + model_vram_needed = model_total_bytes - model_cur_vram_bytes + + # The amount of VRAM that must be freed to make room for model_vram_needed. + vram_bytes_to_free = max(0, model_vram_needed - vram_available) + + self._logger.debug( + f"Before unloading: {self._get_vram_state_str(model_cur_vram_bytes, model_total_bytes, vram_available)}" + ) + + # Make room for the model in VRAM. + # 1. If the model can fit entirely in VRAM, then make enough room for it to be loaded fully. + # 2. If the model can't fit fully into VRAM, then unload all other models and load as much of the model as + # possible. + vram_bytes_freed = self._offload_unlocked_models(vram_bytes_to_free) + self._logger.debug(f"Unloaded models (if necessary): vram_bytes_freed={(vram_bytes_freed/MB):.2f}MB") + + # Check the updated vram_available after offloading. + vram_available = self._get_vram_available() + self._logger.debug( + f"After unloading: {self._get_vram_state_str(model_cur_vram_bytes, model_total_bytes, vram_available)}" + ) + + # Move as much of the model as possible into VRAM. + # For testing, only allow 10% of the model to be loaded into VRAM. + # vram_available = int(model_vram_needed * 0.1) + model_bytes_loaded = self._move_model_to_vram(cache_entry, vram_available) + + model_cur_vram_bytes = cache_entry.cached_model.cur_vram_bytes() + vram_available = self._get_vram_available() + loaded_percent = model_cur_vram_bytes / model_total_bytes if model_total_bytes > 0 else 0 + self._logger.info( + f"Loaded model '{cache_entry.key}' ({cache_entry.cached_model.model.__class__.__name__}) onto " + f"{self._execution_device.type} device in {(time.time() - start_time):.2f}s. " + f"Total model size: {model_total_bytes/MB:.2f}MB, " + f"VRAM: {model_cur_vram_bytes/MB:.2f}MB ({loaded_percent:.1%})" + ) + self._logger.debug(f"Loaded model onto execution device: model_bytes_loaded={(model_bytes_loaded/MB):.2f}MB, ") + self._logger.debug( + f"After loading: {self._get_vram_state_str(model_cur_vram_bytes, model_total_bytes, vram_available)}" + ) + + def _move_model_to_vram(self, cache_entry: CacheRecord, vram_available: int) -> int: + try: + if isinstance(cache_entry.cached_model, CachedModelWithPartialLoad): + return cache_entry.cached_model.partial_load_to_vram(vram_available) + elif isinstance(cache_entry.cached_model, CachedModelOnlyFullLoad): # type: ignore + # Partial load is not supported, so we have no choice but to try and fit it all into VRAM. + return cache_entry.cached_model.full_load_to_vram() + else: + raise ValueError(f"Unsupported cached model type: {type(cache_entry.cached_model)}") + except Exception as e: + if isinstance(e, torch.cuda.OutOfMemoryError): + self._logger.warning("Insufficient GPU memory to load model. Aborting") + # If an exception occurs, the model could be left in a bad state, so we delete it from the cache entirely. + self._delete_cache_entry(cache_entry) + raise - def _get_cache_size(self) -> int: - """Get the total size of the models currently cached.""" - total = 0 - for cache_record in self._cached_models.values(): - total += cache_record.size - return total + def _move_model_to_ram(self, cache_entry: CacheRecord, vram_bytes_to_free: int) -> int: + try: + if isinstance(cache_entry.cached_model, CachedModelWithPartialLoad): + return cache_entry.cached_model.partial_unload_from_vram(vram_bytes_to_free) + elif isinstance(cache_entry.cached_model, CachedModelOnlyFullLoad): # type: ignore + return cache_entry.cached_model.full_unload_from_vram() + else: + raise ValueError(f"Unsupported cached model type: {type(cache_entry.cached_model)}") + except Exception: + # If an exception occurs, the model could be left in a bad state, so we delete it from the cache entirely. + self._delete_cache_entry(cache_entry) + raise + + def _get_vram_available(self) -> int: + """Calculate the amount of additional VRAM available for the cache to use.""" + vram_total_available_to_cache = int(self._max_vram_cache_size_gb * GB) + return vram_total_available_to_cache - self._get_vram_in_use() + + def _get_vram_in_use(self) -> int: + """Get the amount of VRAM currently in use by the cache.""" + return sum(ce.cached_model.cur_vram_bytes() for ce in self._cached_models.values()) + + def _get_ram_available(self) -> int: + """Get the amount of RAM available for the cache to use, while keeping memory pressure under control.""" + + ram_total_available_to_cache = int(self._max_ram_cache_size_gb * GB) + return ram_total_available_to_cache - self._get_ram_in_use() + + def _get_ram_in_use(self) -> int: + """Get the amount of RAM currently in use.""" + return sum(ce.cached_model.total_bytes() for ce in self._cached_models.values()) def _capture_memory_snapshot(self) -> Optional[MemorySnapshot]: if self._log_memory_usage: return MemorySnapshot.capture() return None - def _make_cache_key(self, model_key: str, submodel_type: Optional[SubModelType] = None) -> str: - if submodel_type: - return f"{model_key}:{submodel_type.value}" - else: - return model_key + def _get_vram_state_str(self, model_cur_vram_bytes: int, model_total_bytes: int, vram_available: int) -> str: + """Helper function for preparing a VRAM state log string.""" + model_cur_vram_bytes_percent = model_cur_vram_bytes / model_total_bytes if model_total_bytes > 0 else 0 + return ( + f"model_total={model_total_bytes/MB:.0f} MB, " + + f"model_vram={model_cur_vram_bytes/MB:.0f} MB ({model_cur_vram_bytes_percent:.1%} %), " + # + f"vram_total={int(self._max_vram_cache_size * GB)/MB:.0f} MB, " + + f"vram_available={(vram_available/MB):.0f} MB, " + ) - def _offload_unlocked_models(self, size_required: int) -> None: - """Offload models from the execution_device to make room for size_required. + def _offload_unlocked_models(self, vram_bytes_to_free: int) -> int: + """Offload models from the execution_device until vram_bytes_to_free bytes are freed, or all models are + offloaded. Of course, locked models are not offloaded. - :param size_required: The amount of space to clear in the execution_device cache, in bytes. + Returns: + int: The number of bytes freed. """ - reserved = self._max_vram_cache_size * GB - vram_in_use = torch.cuda.memory_allocated() + size_required - self._logger.debug(f"{(vram_in_use/GB):.2f}GB VRAM needed for models; max allowed={(reserved/GB):.2f}GB") - for _, cache_entry in sorted(self._cached_models.items(), key=lambda x: x[1].size): - if vram_in_use <= reserved: + self._logger.debug(f"Offloading unlocked models with goal of freeing {vram_bytes_to_free/MB:.2f}MB of VRAM.") + vram_bytes_freed = 0 + # TODO(ryand): Give more thought to the offloading policy used here. + cache_entries_increasing_size = sorted(self._cached_models.values(), key=lambda x: x.cached_model.total_bytes()) + for cache_entry in cache_entries_increasing_size: + if vram_bytes_freed >= vram_bytes_to_free: break - if not cache_entry.loaded: + if cache_entry.is_locked: continue - if not cache_entry.locked: - self._move_model_to_device(cache_entry, self._storage_device) - cache_entry.loaded = False - vram_in_use = torch.cuda.memory_allocated() + size_required + + cache_entry_bytes_freed = self._move_model_to_ram(cache_entry, vram_bytes_to_free - vram_bytes_freed) + if cache_entry_bytes_freed > 0: self._logger.debug( - f"Removing {cache_entry.key} from VRAM to free {(cache_entry.size/GB):.2f}GB; vram free = {(torch.cuda.memory_allocated()/GB):.2f}GB" + f"Unloaded {cache_entry.key} from VRAM to free {(cache_entry_bytes_freed/MB):.0f} MB." ) + vram_bytes_freed += cache_entry_bytes_freed TorchDevice.empty_cache() + return vram_bytes_freed - def _move_model_to_device(self, cache_entry: CacheRecord, target_device: torch.device) -> None: - """Move model into the indicated device. - - :param cache_entry: The CacheRecord for the model - :param target_device: The torch.device to move the model into - - May raise a torch.cuda.OutOfMemoryError - """ - self._logger.debug(f"Called to move {cache_entry.key} to {target_device}") - source_device = cache_entry.device - - # Note: We compare device types only so that 'cuda' == 'cuda:0'. - # This would need to be revised to support multi-GPU. - if torch.device(source_device).type == torch.device(target_device).type: - return - - # Some models don't have a `to` method, in which case they run in RAM/CPU. - if not hasattr(cache_entry.model, "to"): + def _log_cache_state(self, title: str = "Model cache state:", include_entry_details: bool = True): + if self._logger.getEffectiveLevel() > logging.DEBUG: + # Short circuit if the logger is not set to debug. Some of the data lookups could take a non-negligible + # amount of time. return - # This roundabout method for moving the model around is done to avoid - # the cost of moving the model from RAM to VRAM and then back from VRAM to RAM. - # When moving to VRAM, we copy (not move) each element of the state dict from - # RAM to a new state dict in VRAM, and then inject it into the model. - # This operation is slightly faster than running `to()` on the whole model. - # - # When the model needs to be removed from VRAM we simply delete the copy - # of the state dict in VRAM, and reinject the state dict that is cached - # in RAM into the model. So this operation is very fast. - start_model_to_time = time.time() - snapshot_before = self._capture_memory_snapshot() - - try: - if cache_entry.state_dict is not None: - assert hasattr(cache_entry.model, "load_state_dict") - if target_device == self._storage_device: - cache_entry.model.load_state_dict(cache_entry.state_dict, assign=True) - else: - new_dict: Dict[str, torch.Tensor] = {} - for k, v in cache_entry.state_dict.items(): - new_dict[k] = v.to(target_device, copy=True) - cache_entry.model.load_state_dict(new_dict, assign=True) - cache_entry.model.to(target_device) - cache_entry.device = target_device - except Exception as e: # blow away cache entry - self._delete_cache_entry(cache_entry) - raise e - - snapshot_after = self._capture_memory_snapshot() - end_model_to_time = time.time() - self._logger.debug( - f"Moved model '{cache_entry.key}' from {source_device} to" - f" {target_device} in {(end_model_to_time-start_model_to_time):.2f}s." - f"Estimated model size: {(cache_entry.size/GB):.3f} GB." - f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}" + log = f"{title}\n" + + log_format = " {:<30} Limit: {:>7.1f} MB, Used: {:>7.1f} MB ({:>5.1%}), Available: {:>7.1f} MB ({:>5.1%})\n" + + ram_in_use_bytes = self._get_ram_in_use() + ram_available_bytes = self._get_ram_available() + ram_size_bytes = ram_in_use_bytes + ram_available_bytes + ram_in_use_bytes_percent = ram_in_use_bytes / ram_size_bytes if ram_size_bytes > 0 else 0 + ram_available_bytes_percent = ram_available_bytes / ram_size_bytes if ram_size_bytes > 0 else 0 + log += log_format.format( + f"Storage Device ({self._storage_device.type})", + ram_size_bytes / MB, + ram_in_use_bytes / MB, + ram_in_use_bytes_percent, + ram_available_bytes / MB, + ram_available_bytes_percent, ) - if ( - snapshot_before is not None - and snapshot_after is not None - and snapshot_before.vram is not None - and snapshot_after.vram is not None - ): - vram_change = abs(snapshot_before.vram - snapshot_after.vram) - - # If the estimated model size does not match the change in VRAM, log a warning. - if not math.isclose( - vram_change, - cache_entry.size, - rel_tol=0.1, - abs_tol=10 * MB, - ): - self._logger.debug( - f"Moving model '{cache_entry.key}' from {source_device} to" - f" {target_device} caused an unexpected change in VRAM usage. The model's" - " estimated size may be incorrect. Estimated model size:" - f" {(cache_entry.size/GB):.3f} GB.\n" - f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}" - ) + if self._execution_device.type != "cpu": + vram_in_use_bytes = self._get_vram_in_use() + vram_available_bytes = self._get_vram_available() + vram_size_bytes = vram_in_use_bytes + vram_available_bytes + vram_in_use_bytes_percent = vram_in_use_bytes / vram_size_bytes if vram_size_bytes > 0 else 0 + vram_available_bytes_percent = vram_available_bytes / vram_size_bytes if vram_size_bytes > 0 else 0 + log += log_format.format( + f"Compute Device ({self._execution_device.type})", + vram_size_bytes / MB, + vram_in_use_bytes / MB, + vram_in_use_bytes_percent, + vram_available_bytes / MB, + vram_available_bytes_percent, + ) - def _print_cuda_stats(self) -> None: - """Log CUDA diagnostics.""" - vram = "%4.2fG" % (torch.cuda.memory_allocated() / GB) - ram = "%4.2fG" % (self._get_cache_size() / GB) - - in_ram_models = 0 - in_vram_models = 0 - locked_in_vram_models = 0 - for cache_record in self._cached_models.values(): - if hasattr(cache_record.model, "device"): - if cache_record.model.device == self._storage_device: - in_ram_models += 1 - else: - in_vram_models += 1 - if cache_record.locked: - locked_in_vram_models += 1 + if torch.cuda.is_available(): + log += " {:<30} {:.1f} MB\n".format("CUDA Memory Allocated:", torch.cuda.memory_allocated() / MB) + log += " {:<30} {}\n".format("Total models:", len(self._cached_models)) - self._logger.debug( - f"Current VRAM/RAM usage: {vram}/{ram}; models_in_ram/models_in_vram(locked) =" - f" {in_ram_models}/{in_vram_models}({locked_in_vram_models})" + if include_entry_details and len(self._cached_models) > 0: + log += " Models:\n" + log_format = ( + " {:<80} total={:>7.1f} MB, vram={:>7.1f} MB ({:>5.1%}), ram={:>7.1f} MB ({:>5.1%}), locked={}\n" + ) + for cache_record in self._cached_models.values(): + total_bytes = cache_record.cached_model.total_bytes() + cur_vram_bytes = cache_record.cached_model.cur_vram_bytes() + cur_vram_bytes_percent = cur_vram_bytes / total_bytes if total_bytes > 0 else 0 + cur_ram_bytes = total_bytes - cur_vram_bytes + cur_ram_bytes_percent = cur_ram_bytes / total_bytes if total_bytes > 0 else 0 + + log += log_format.format( + f"{cache_record.key} ({cache_record.cached_model.model.__class__.__name__}):", + total_bytes / MB, + cur_vram_bytes / MB, + cur_vram_bytes_percent, + cur_ram_bytes / MB, + cur_ram_bytes_percent, + cache_record.is_locked, ) - def make_room(self, size: int) -> None: + self._logger.debug(log) + + def make_room(self, bytes_needed: int) -> None: """Make enough room in the cache to accommodate a new model of indicated size. Note: This function deletes all of the cache's internal references to a model in order to free it. If there are external references to the model, there's nothing that the cache can do about it, and those models will not be garbage-collected. """ - bytes_needed = size - maximum_size = self._max_cache_size * GB # stored in GB, convert to bytes - current_size = self._get_cache_size() - - if current_size + bytes_needed > maximum_size: - self._logger.debug( - f"Max cache size exceeded: {(current_size/GB):.2f}/{self.max_cache_size:.2f} GB, need an additional" - f" {(bytes_needed/GB):.2f} GB" - ) + self._logger.debug(f"Making room for {bytes_needed/MB:.2f}MB of RAM.") + self._log_cache_state(title="Before dropping models:") - self._logger.debug(f"Before making_room: cached_models={len(self._cached_models)}") + ram_bytes_available = self._get_ram_available() + ram_bytes_to_free = max(0, bytes_needed - ram_bytes_available) + ram_bytes_freed = 0 pos = 0 models_cleared = 0 - while current_size + bytes_needed > maximum_size and pos < len(self._cache_stack): + while ram_bytes_freed < ram_bytes_to_free and pos < len(self._cache_stack): model_key = self._cache_stack[pos] cache_entry = self._cached_models[model_key] - device = cache_entry.model.device if hasattr(cache_entry.model, "device") else None - self._logger.debug( - f"Model: {model_key}, locks: {cache_entry._locks}, device: {device}, loaded: {cache_entry.loaded}" - ) - if not cache_entry.locked: + if not cache_entry.is_locked: + ram_bytes_freed += cache_entry.cached_model.total_bytes() self._logger.debug( - f"Removing {model_key} from RAM cache to free at least {(size/GB):.2f} GB (-{(cache_entry.size/GB):.2f} GB)" + f"Dropping {model_key} from RAM cache to free {(cache_entry.cached_model.total_bytes()/MB):.2f}MB." ) - current_size -= cache_entry.size - models_cleared += 1 self._delete_cache_entry(cache_entry) del cache_entry - + models_cleared += 1 else: pos += 1 @@ -435,8 +494,10 @@ def make_room(self, size: int) -> None: gc.collect() TorchDevice.empty_cache() - self._logger.debug(f"After making room: cached_models={len(self._cached_models)}") + self._logger.debug(f"Dropped {models_cleared} models to free {ram_bytes_freed/MB:.2f}MB of RAM.") + self._log_cache_state(title="After dropping models:") def _delete_cache_entry(self, cache_entry: CacheRecord) -> None: - self._cache_stack.remove(cache_entry.key) - del self._cached_models[cache_entry.key] + """Delete cache_entry from the cache if it exists. No exception is thrown if it doesn't exist.""" + self._cache_stack = [key for key in self._cache_stack if key != cache_entry.key] + self._cached_models.pop(cache_entry.key, None) diff --git a/tests/backend/model_manager/model_manager_fixtures.py b/tests/backend/model_manager/model_manager_fixtures.py index 61d77dac129..f396a93d2db 100644 --- a/tests/backend/model_manager/model_manager_fixtures.py +++ b/tests/backend/model_manager/model_manager_fixtures.py @@ -92,8 +92,8 @@ def mm2_download_queue(mm2_session: Session) -> DownloadQueueServiceBase: def mm2_loader(mm2_app_config: InvokeAIAppConfig) -> ModelLoadServiceBase: ram_cache = ModelCache( logger=InvokeAILogger.get_logger(), - max_cache_size=mm2_app_config.ram, - max_vram_cache_size=mm2_app_config.vram, + max_ram_cache_size_gb=mm2_app_config.ram, + max_vram_cache_size_gb=mm2_app_config.vram, ) return ModelLoadService( app_config=mm2_app_config, From d0bfa019be04219136291cd4f397c9457a4f0caf Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Mon, 30 Dec 2024 19:52:10 +0000 Subject: [PATCH 02/10] Add 'enable_partial_loading' config flag. --- invokeai/app/services/config/config_default.py | 2 ++ invokeai/app/services/model_manager/model_manager_default.py | 1 + .../backend/model_manager/load/model_cache/model_cache.py | 4 +++- tests/backend/model_manager/model_manager_fixtures.py | 1 + 4 files changed, 7 insertions(+), 1 deletion(-) diff --git a/invokeai/app/services/config/config_default.py b/invokeai/app/services/config/config_default.py index 4c0333a2605..52653de0f4c 100644 --- a/invokeai/app/services/config/config_default.py +++ b/invokeai/app/services/config/config_default.py @@ -107,6 +107,7 @@ class InvokeAIAppConfig(BaseSettings): vram: Amount of VRAM reserved for model storage (GB). lazy_offload: Keep models in VRAM until their space is needed. log_memory_usage: If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour. + enable_partial_loading: Enable partial loading of models. This enables models to run with reduced VRAM requirements (at the cost of slower speed) by streaming the model from RAM to VRAM as its used. Partial loading can cause models to run more slowly if they were previously being fully loaded into VRAM. If enabling this setting, make sure that your ram and vram cache limits are properly tuned. device: Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.
Valid values: `auto`, `cpu`, `cuda`, `cuda:1`, `mps` precision: Floating point precision. `float16` will consume half the memory of `float32` but produce slightly lower-quality images. The `auto` setting will guess the proper precision based on your video card and operating system.
Valid values: `auto`, `float16`, `bfloat16`, `float32` sequential_guidance: Whether to calculate guidance in serial instead of in parallel, lowering memory requirements. @@ -178,6 +179,7 @@ class InvokeAIAppConfig(BaseSettings): vram: float = Field(default=DEFAULT_VRAM_CACHE, ge=0, description="Amount of VRAM reserved for model storage (GB).") lazy_offload: bool = Field(default=True, description="Keep models in VRAM until their space is needed.") log_memory_usage: bool = Field(default=False, description="If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.") + enable_partial_loading: bool = Field(default=False, description="Enable partial loading of models. This enables models to run with reduced VRAM requirements (at the cost of slower speed) by streaming the model from RAM to VRAM as its used. Partial loading can cause models to run more slowly if they were previously being fully loaded into VRAM. If enabling this setting, make sure that your ram and vram cache limits are properly tuned.") # DEVICE device: DEVICE = Field(default="auto", description="Preferred execution device. `auto` will choose the device depending on the hardware platform and the installed torch capabilities.") diff --git a/invokeai/app/services/model_manager/model_manager_default.py b/invokeai/app/services/model_manager/model_manager_default.py index bdd1f5da437..c7bcd43d7a7 100644 --- a/invokeai/app/services/model_manager/model_manager_default.py +++ b/invokeai/app/services/model_manager/model_manager_default.py @@ -84,6 +84,7 @@ def build_model_manager( ram_cache = ModelCache( max_ram_cache_size_gb=app_config.ram, max_vram_cache_size_gb=app_config.vram, + enable_partial_loading=app_config.enable_partial_loading, logger=logger, execution_device=execution_device or TorchDevice.choose_torch_device(), ) diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache.py b/invokeai/backend/model_manager/load/model_cache/model_cache.py index ecf3ffa6234..377f4910b4c 100644 --- a/invokeai/backend/model_manager/load/model_cache/model_cache.py +++ b/invokeai/backend/model_manager/load/model_cache/model_cache.py @@ -76,6 +76,7 @@ def __init__( self, max_ram_cache_size_gb: float, max_vram_cache_size_gb: float, + enable_partial_loading: bool, execution_device: torch.device | str = "cuda", storage_device: torch.device | str = "cpu", log_memory_usage: bool = False, @@ -102,6 +103,7 @@ def __init__( self._max_ram_cache_size_gb = max_ram_cache_size_gb self._max_vram_cache_size_gb = max_vram_cache_size_gb + self._enable_partial_loading = enable_partial_loading self._logger = logger or InvokeAILogger.get_logger(self.__class__.__name__) self._log_memory_usage = log_memory_usage @@ -142,7 +144,7 @@ def put(self, key: str, model: AnyModel) -> None: running_with_cuda = self._execution_device.type == "cuda" # Wrap model. - if isinstance(model, torch.nn.Module) and running_with_cuda: + if isinstance(model, torch.nn.Module) and running_with_cuda and self._enable_partial_loading: wrapped_model = CachedModelWithPartialLoad(model, self._execution_device) else: wrapped_model = CachedModelOnlyFullLoad(model, self._execution_device, size) diff --git a/tests/backend/model_manager/model_manager_fixtures.py b/tests/backend/model_manager/model_manager_fixtures.py index f396a93d2db..4449bbaf62f 100644 --- a/tests/backend/model_manager/model_manager_fixtures.py +++ b/tests/backend/model_manager/model_manager_fixtures.py @@ -94,6 +94,7 @@ def mm2_loader(mm2_app_config: InvokeAIAppConfig) -> ModelLoadServiceBase: logger=InvokeAILogger.get_logger(), max_ram_cache_size_gb=mm2_app_config.ram, max_vram_cache_size_gb=mm2_app_config.vram, + enable_partial_loading=mm2_app_config.enable_partial_loading, ) return ModelLoadService( app_config=mm2_app_config, From ceb2498a67816c866826a730bfbda15588c811de Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Mon, 30 Dec 2024 19:54:51 +0000 Subject: [PATCH 03/10] Add log prefix to model cache logs. --- .../model_manager/load/model_cache/model_cache.py | 5 ++++- invokeai/backend/util/prefix_logger_adapter.py | 12 ++++++++++++ 2 files changed, 16 insertions(+), 1 deletion(-) create mode 100644 invokeai/backend/util/prefix_logger_adapter.py diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache.py b/invokeai/backend/model_manager/load/model_cache/model_cache.py index 377f4910b4c..f1d3f8cf9ef 100644 --- a/invokeai/backend/model_manager/load/model_cache/model_cache.py +++ b/invokeai/backend/model_manager/load/model_cache/model_cache.py @@ -22,6 +22,7 @@ from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data from invokeai.backend.util.devices import TorchDevice from invokeai.backend.util.logging import InvokeAILogger +from invokeai.backend.util.prefix_logger_adapter import PrefixedLoggerAdapter # Size of a GB in bytes. GB = 2**30 @@ -105,7 +106,9 @@ def __init__( self._max_vram_cache_size_gb = max_vram_cache_size_gb self._enable_partial_loading = enable_partial_loading - self._logger = logger or InvokeAILogger.get_logger(self.__class__.__name__) + self._logger = PrefixedLoggerAdapter( + logger or InvokeAILogger.get_logger(self.__class__.__name__), "MODEL CACHE" + ) self._log_memory_usage = log_memory_usage self._stats: Optional[CacheStats] = None diff --git a/invokeai/backend/util/prefix_logger_adapter.py b/invokeai/backend/util/prefix_logger_adapter.py new file mode 100644 index 00000000000..94f0478c95d --- /dev/null +++ b/invokeai/backend/util/prefix_logger_adapter.py @@ -0,0 +1,12 @@ +import logging +from typing import Any, MutableMapping + + +# Issue with type hints related to LoggerAdapter: https://github.com/python/typeshed/issues/7855 +class PrefixedLoggerAdapter(logging.LoggerAdapter): # type: ignore + def __init__(self, logger: logging.Logger, prefix: str): + super().__init__(logger, {}) + self.prefix = prefix + + def process(self, msg: str, kwargs: MutableMapping[str, Any]) -> tuple[str, MutableMapping[str, Any]]: + return f"[{self.prefix}] {msg}", kwargs From 7127040c3ae95e494075b7b3d1cd4a403c73adad Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Mon, 30 Dec 2024 20:26:49 +0000 Subject: [PATCH 04/10] Remove unused function set_nested_attr(...). --- .../cached_model/cached_model_with_partial_load.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/invokeai/backend/model_manager/load/model_cache/cached_model/cached_model_with_partial_load.py b/invokeai/backend/model_manager/load/model_cache/cached_model/cached_model_with_partial_load.py index a5e1e3d5398..543a739475a 100644 --- a/invokeai/backend/model_manager/load/model_cache/cached_model/cached_model_with_partial_load.py +++ b/invokeai/backend/model_manager/load/model_cache/cached_model/cached_model_with_partial_load.py @@ -7,18 +7,6 @@ from invokeai.backend.util.logging import InvokeAILogger -def set_nested_attr(obj: object, attr: str, value: object): - """A helper function that extends setattr() to support nested attributes. - - Example: - set_nested_attr(model, "module.encoder.conv1.weight", new_conv1_weight) - """ - attrs = attr.split(".") - for attr in attrs[:-1]: - obj = getattr(obj, attr) - setattr(obj, attrs[-1], value) - - class CachedModelWithPartialLoad: """A wrapper around a PyTorch model to handle partial loads and unloads between the CPU and the compute device. From 402dd840a1eb866d57fdf37abe2ee9130c705aef Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Mon, 30 Dec 2024 17:53:14 -0500 Subject: [PATCH 05/10] Add seed to flaky unit test. --- .../custom_modules/test_all_custom_modules.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_all_custom_modules.py b/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_all_custom_modules.py index 97062772341..875b95da071 100644 --- a/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_all_custom_modules.py +++ b/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_all_custom_modules.py @@ -358,6 +358,8 @@ def patch_under_test(request: pytest.FixtureRequest) -> PatchUnderTest: def test_linear_sidecar_patches(device: str, patch_under_test: PatchUnderTest): patches, input = patch_under_test + torch.manual_seed(0) + # Build the base layer under test. layer = torch.nn.Linear(32, 64) From 1b7bb70bde84232f0fe92e675f5e7e529de03c46 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Mon, 30 Dec 2024 17:57:04 -0500 Subject: [PATCH 06/10] Improve handling of cases when application code modifies the size of a model after registering it with the model cache. --- .../cached_model_with_partial_load.py | 22 ++++++++++++------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/invokeai/backend/model_manager/load/model_cache/cached_model/cached_model_with_partial_load.py b/invokeai/backend/model_manager/load/model_cache/cached_model/cached_model_with_partial_load.py index 543a739475a..cecf7fb20d9 100644 --- a/invokeai/backend/model_manager/load/model_cache/cached_model/cached_model_with_partial_load.py +++ b/invokeai/backend/model_manager/load/model_cache/cached_model/cached_model_with_partial_load.py @@ -21,9 +21,14 @@ def __init__(self, model: torch.nn.Module, compute_device: torch.device): # A CPU read-only copy of the model's state dict. self._cpu_state_dict: dict[str, torch.Tensor] = model.state_dict() - # TODO(ryand): Handle the case where the model sizes changes after initial load (e.g. due to dtype casting). - # Consider how we should handle this for both self._total_bytes and self._cur_vram_bytes. - self._total_bytes = sum(calc_tensor_size(p) for p in self._cpu_state_dict.values()) + # A dictionary of the size of each tensor in the state dict. + # HACK(ryand): We use this dictionary any time we are doing byte tracking calculations. We do this for + # consistency in case the application code has modified the model's size (e.g. by casting to a different + # precision). Of course, this means that we are making model cache load/unload decisions based on model size + # data that may not be fully accurate. + self._state_dict_bytes = {k: calc_tensor_size(v) for k, v in self._cpu_state_dict.items()} + + self._total_bytes = sum(self._state_dict_bytes.values()) self._cur_vram_bytes: int | None = None self._modules_that_support_autocast = self._find_modules_that_support_autocast() @@ -79,7 +84,9 @@ def cur_vram_bytes(self) -> int: if self._cur_vram_bytes is None: cur_state_dict = self._model.state_dict() self._cur_vram_bytes = sum( - calc_tensor_size(p) for p in cur_state_dict.values() if p.device.type == self._compute_device.type + self._state_dict_bytes[k] + for k, v in cur_state_dict.items() + if v.device.type == self._compute_device.type ) return self._cur_vram_bytes @@ -111,7 +118,7 @@ def partial_load_to_vram(self, vram_bytes_to_load: int) -> int: if param.device.type == self._compute_device.type: continue - param_size = calc_tensor_size(param) + param_size = self._state_dict_bytes[key] cur_state_dict[key] = param.to(self._compute_device, copy=True) vram_bytes_loaded += param_size @@ -128,7 +135,7 @@ def partial_load_to_vram(self, vram_bytes_to_load: int) -> int: if param.device.type == self._compute_device.type: continue - param_size = calc_tensor_size(param) + param_size = self._state_dict_bytes[key] if vram_bytes_loaded + param_size > vram_bytes_to_load: # TODO(ryand): Should we just break here? If we couldn't fit this parameter into VRAM, is it really # worth continuing to search for a smaller parameter that would fit? @@ -149,7 +156,6 @@ def partial_load_to_vram(self, vram_bytes_to_load: int) -> int: if fully_loaded: self._set_autocast_enabled_in_all_modules(False) - # TODO(ryand): Warn if the self.cur_vram_bytes() and self.total_bytes() are out of sync. else: self._set_autocast_enabled_in_all_modules(True) @@ -178,7 +184,7 @@ def partial_unload_from_vram(self, vram_bytes_to_free: int) -> int: continue cur_state_dict[key] = self._cpu_state_dict[key] - vram_bytes_freed += calc_tensor_size(param) + vram_bytes_freed += self._state_dict_bytes[key] if vram_bytes_freed > 0: self._model.load_state_dict(cur_state_dict, assign=True) From bcd29c5d740df097e970c3dc9469ba4f5aea0610 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Mon, 30 Dec 2024 19:17:56 -0500 Subject: [PATCH 07/10] Remove all cases where we check the 'model.device'. This is no longer trustworthy now that partial loading is permitted. --- invokeai/app/invocations/compel.py | 2 ++ invokeai/app/invocations/latents_to_image.py | 2 +- invokeai/app/invocations/sd3_latents_to_image.py | 2 +- invokeai/app/invocations/sd3_text_encoder.py | 5 +++-- invokeai/backend/flux/modules/conditioner.py | 4 +++- 5 files changed, 10 insertions(+), 5 deletions(-) diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py index b535254cfd4..f3686ae6488 100644 --- a/invokeai/app/invocations/compel.py +++ b/invokeai/app/invocations/compel.py @@ -105,6 +105,7 @@ def _lora_loader() -> Iterator[Tuple[ModelPatchRaw, float]]: textual_inversion_manager=ti_manager, dtype_for_device_getter=TorchDevice.choose_torch_dtype, truncate_long_prompts=False, + device=TorchDevice.choose_torch_device(), ) conjunction = Compel.parse_prompt_string(self.prompt) @@ -207,6 +208,7 @@ def _lora_loader() -> Iterator[Tuple[ModelPatchRaw, float]]: truncate_long_prompts=False, # TODO: returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED, # TODO: clip skip requires_pooled=get_pooled, + device=TorchDevice.choose_torch_device(), ) conjunction = Compel.parse_prompt_string(prompt) diff --git a/invokeai/app/invocations/latents_to_image.py b/invokeai/app/invocations/latents_to_image.py index 41cc6bfbd17..45e06a3f2ad 100644 --- a/invokeai/app/invocations/latents_to_image.py +++ b/invokeai/app/invocations/latents_to_image.py @@ -62,7 +62,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: with SeamlessExt.static_patch_model(vae_info.model, self.vae.seamless_axes), vae_info as vae: context.util.signal_progress("Running VAE decoder") assert isinstance(vae, (AutoencoderKL, AutoencoderTiny)) - latents = latents.to(vae.device) + latents = latents.to(TorchDevice.choose_torch_device()) if self.fp32: vae.to(dtype=torch.float32) diff --git a/invokeai/app/invocations/sd3_latents_to_image.py b/invokeai/app/invocations/sd3_latents_to_image.py index 184759b2f02..55cbddcc51e 100644 --- a/invokeai/app/invocations/sd3_latents_to_image.py +++ b/invokeai/app/invocations/sd3_latents_to_image.py @@ -49,7 +49,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput: with SeamlessExt.static_patch_model(vae_info.model, self.vae.seamless_axes), vae_info as vae: context.util.signal_progress("Running VAE") assert isinstance(vae, (AutoencoderKL)) - latents = latents.to(vae.device) + latents = latents.to(TorchDevice.choose_torch_device()) vae.disable_tiling() diff --git a/invokeai/app/invocations/sd3_text_encoder.py b/invokeai/app/invocations/sd3_text_encoder.py index 6569fa0a762..0f83ca32188 100644 --- a/invokeai/app/invocations/sd3_text_encoder.py +++ b/invokeai/app/invocations/sd3_text_encoder.py @@ -21,6 +21,7 @@ from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_CLIP_PREFIX from invokeai.backend.patches.model_patch_raw import ModelPatchRaw from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, SD3ConditioningInfo +from invokeai.backend.util.devices import TorchDevice # The SD3 T5 Max Sequence Length set based on the default in diffusers. SD3_T5_MAX_SEQ_LEN = 256 @@ -120,7 +121,7 @@ def _t5_encode(self, context: InvocationContext, max_seq_len: int) -> torch.Tens f" {max_seq_len} tokens: {removed_text}" ) - prompt_embeds = t5_text_encoder(text_input_ids.to(t5_text_encoder.device))[0] + prompt_embeds = t5_text_encoder(text_input_ids.to(TorchDevice.choose_torch_device()))[0] assert isinstance(prompt_embeds, torch.Tensor) return prompt_embeds @@ -185,7 +186,7 @@ def _clip_encode( f" {tokenizer_max_length} tokens: {removed_text}" ) prompt_embeds = clip_text_encoder( - input_ids=text_input_ids.to(clip_text_encoder.device), output_hidden_states=True + input_ids=text_input_ids.to(TorchDevice.choose_torch_device()), output_hidden_states=True ) pooled_prompt_embeds = prompt_embeds[0] prompt_embeds = prompt_embeds.hidden_states[-2] diff --git a/invokeai/backend/flux/modules/conditioner.py b/invokeai/backend/flux/modules/conditioner.py index de6d8256c4f..c03e877e2db 100644 --- a/invokeai/backend/flux/modules/conditioner.py +++ b/invokeai/backend/flux/modules/conditioner.py @@ -3,6 +3,8 @@ from torch import Tensor, nn from transformers import PreTrainedModel, PreTrainedTokenizer +from invokeai.backend.util.devices import TorchDevice + class HFEncoder(nn.Module): def __init__(self, encoder: PreTrainedModel, tokenizer: PreTrainedTokenizer, is_clip: bool, max_length: int): @@ -26,7 +28,7 @@ def forward(self, text: list[str]) -> Tensor: ) outputs = self.hf_module( - input_ids=batch_encoding["input_ids"].to(self.hf_module.device), + input_ids=batch_encoding["input_ids"].to(TorchDevice.choose_torch_device()), attention_mask=None, output_hidden_states=False, ) From 2619ef53cae52b969cb6390c9f3c38fe0ba845d6 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Tue, 31 Dec 2024 17:08:45 +0000 Subject: [PATCH 08/10] Handle device casting in ia2_layer.py. --- invokeai/backend/patches/layers/ia3_layer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/invokeai/backend/patches/layers/ia3_layer.py b/invokeai/backend/patches/layers/ia3_layer.py index fa5f0c1ca45..21c84669836 100644 --- a/invokeai/backend/patches/layers/ia3_layer.py +++ b/invokeai/backend/patches/layers/ia3_layer.py @@ -2,6 +2,7 @@ import torch +from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device from invokeai.backend.patches.layers.lora_layer_base import LoRALayerBase @@ -50,7 +51,7 @@ def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor: weight = self.weight if not self.on_input: weight = weight.reshape(-1, 1) - return orig_weight * weight + return cast_to_device(orig_weight, weight.device) * weight def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None): super().to(device, dtype) From e5180c4e6b54ac5be96b1498ff81218f382b7d58 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Tue, 31 Dec 2024 18:55:27 +0000 Subject: [PATCH 09/10] Add get_effective_device(...) utility to aid in determining the effective device of models that are partially loaded. --- invokeai/backend/image_util/hed.py | 5 +++-- .../backend/image_util/infill_methods/lama.py | 3 ++- invokeai/backend/image_util/lineart.py | 5 +++-- invokeai/backend/image_util/lineart_anime.py | 5 +++-- invokeai/backend/image_util/mlsd/utils.py | 6 ++++-- .../backend/image_util/normal_bae/__init__.py | 3 ++- invokeai/backend/image_util/pidi/__init__.py | 3 ++- .../model_manager/load/model_cache/utils.py | 20 +++++++++++++++++++ 8 files changed, 39 insertions(+), 11 deletions(-) create mode 100644 invokeai/backend/model_manager/load/model_cache/utils.py diff --git a/invokeai/backend/image_util/hed.py b/invokeai/backend/image_util/hed.py index ec12c26b2e3..a2d3449f650 100644 --- a/invokeai/backend/image_util/hed.py +++ b/invokeai/backend/image_util/hed.py @@ -18,6 +18,7 @@ resize_image_to_resolution, safe_step, ) +from invokeai.backend.model_manager.load.model_cache.utils import get_effective_device class DoubleConvBlock(torch.nn.Module): @@ -109,7 +110,7 @@ def run( Returns: The detected edges. """ - device = next(iter(self.network.parameters())).device + device = get_effective_device(self.network) np_image = pil_to_np(input_image) np_image = normalize_image_channel_count(np_image) np_image = resize_image_to_resolution(np_image, detect_resolution) @@ -183,7 +184,7 @@ def run(self, image: Image.Image, safe: bool = False, scribble: bool = False) -> The detected edges. """ - device = next(iter(self.model.parameters())).device + device = get_effective_device(self.model) np_image = pil_to_np(image) diff --git a/invokeai/backend/image_util/infill_methods/lama.py b/invokeai/backend/image_util/infill_methods/lama.py index cd5838d1f2b..faf25e44a49 100644 --- a/invokeai/backend/image_util/infill_methods/lama.py +++ b/invokeai/backend/image_util/infill_methods/lama.py @@ -7,6 +7,7 @@ import invokeai.backend.util.logging as logger from invokeai.backend.model_manager.config import AnyModel +from invokeai.backend.model_manager.load.model_cache.utils import get_effective_device def norm_img(np_img): @@ -31,7 +32,7 @@ def __call__(self, input_image: Image.Image, *args: Any, **kwds: Any) -> Any: mask = norm_img(mask) mask = (mask > 0) * 1 - device = next(self._model.buffers()).device + device = get_effective_device(self._model) image = torch.from_numpy(image).unsqueeze(0).to(device) mask = torch.from_numpy(mask).unsqueeze(0).to(device) diff --git a/invokeai/backend/image_util/lineart.py b/invokeai/backend/image_util/lineart.py index 8fcca24b0e0..bfef6f6da08 100644 --- a/invokeai/backend/image_util/lineart.py +++ b/invokeai/backend/image_util/lineart.py @@ -17,6 +17,7 @@ pil_to_np, resize_image_to_resolution, ) +from invokeai.backend.model_manager.load.model_cache.utils import get_effective_device class ResidualBlock(nn.Module): @@ -130,7 +131,7 @@ def run( Returns: The detected lineart. """ - device = next(iter(self.model.parameters())).device + device = get_effective_device(self.model) np_image = pil_to_np(input_image) np_image = normalize_image_channel_count(np_image) @@ -201,7 +202,7 @@ def run(self, image: Image.Image) -> Image.Image: Returns: The detected edges. """ - device = next(iter(self.model.parameters())).device + device = get_effective_device(self.model) np_image = pil_to_np(image) diff --git a/invokeai/backend/image_util/lineart_anime.py b/invokeai/backend/image_util/lineart_anime.py index 09dcb6655e3..fa406cf1d4b 100644 --- a/invokeai/backend/image_util/lineart_anime.py +++ b/invokeai/backend/image_util/lineart_anime.py @@ -19,6 +19,7 @@ pil_to_np, resize_image_to_resolution, ) +from invokeai.backend.model_manager.load.model_cache.utils import get_effective_device class UnetGenerator(nn.Module): @@ -171,7 +172,7 @@ def run(self, input_image: Image.Image, detect_resolution: int = 512, image_reso Returns: The detected lineart. """ - device = next(iter(self.model.parameters())).device + device = get_effective_device(self.model) np_image = pil_to_np(input_image) np_image = normalize_image_channel_count(np_image) @@ -239,7 +240,7 @@ def to(self, device: torch.device): def run(self, image: Image.Image) -> Image.Image: """Processes an image and returns the detected edges.""" - device = next(iter(self.model.parameters())).device + device = get_effective_device(self.model) np_image = pil_to_np(image) diff --git a/invokeai/backend/image_util/mlsd/utils.py b/invokeai/backend/image_util/mlsd/utils.py index dbe9a98d09e..dbadce01a4f 100644 --- a/invokeai/backend/image_util/mlsd/utils.py +++ b/invokeai/backend/image_util/mlsd/utils.py @@ -14,6 +14,8 @@ import torch from torch.nn import functional as F +from invokeai.backend.model_manager.load.model_cache.utils import get_effective_device + def deccode_output_score_and_ptss(tpMap, topk_n = 200, ksize = 5): ''' @@ -49,7 +51,7 @@ def pred_lines(image, model, dist_thr=20.0): h, w, _ = image.shape - device = next(iter(model.parameters())).device + device = get_effective_device(model) h_ratio, w_ratio = [h / input_shape[0], w / input_shape[1]] resized_image = np.concatenate([cv2.resize(image, (input_shape[1], input_shape[0]), interpolation=cv2.INTER_AREA), @@ -108,7 +110,7 @@ def pred_squares(image, ''' h, w, _ = image.shape original_shape = [h, w] - device = next(iter(model.parameters())).device + device = get_effective_device(model) resized_image = np.concatenate([cv2.resize(image, (input_shape[0], input_shape[1]), interpolation=cv2.INTER_AREA), np.ones([input_shape[0], input_shape[1], 1])], axis=-1) diff --git a/invokeai/backend/image_util/normal_bae/__init__.py b/invokeai/backend/image_util/normal_bae/__init__.py index d0b1339113e..5ad221ecd4a 100644 --- a/invokeai/backend/image_util/normal_bae/__init__.py +++ b/invokeai/backend/image_util/normal_bae/__init__.py @@ -13,6 +13,7 @@ from invokeai.backend.image_util.normal_bae.nets.NNET import NNET from invokeai.backend.image_util.util import np_to_pil, pil_to_np, resize_to_multiple +from invokeai.backend.model_manager.load.model_cache.utils import get_effective_device class NormalMapDetector: @@ -64,7 +65,7 @@ def to(self, device: torch.device): def run(self, image: Image.Image): """Processes an image and returns the detected normal map.""" - device = next(iter(self.model.parameters())).device + device = get_effective_device(self.model) np_image = pil_to_np(image) height, width, _channels = np_image.shape diff --git a/invokeai/backend/image_util/pidi/__init__.py b/invokeai/backend/image_util/pidi/__init__.py index 8673b219140..63df7b6058e 100644 --- a/invokeai/backend/image_util/pidi/__init__.py +++ b/invokeai/backend/image_util/pidi/__init__.py @@ -11,6 +11,7 @@ from invokeai.backend.image_util.pidi.model import PiDiNet, pidinet from invokeai.backend.image_util.util import nms, normalize_image_channel_count, np_to_pil, pil_to_np, safe_step +from invokeai.backend.model_manager.load.model_cache.utils import get_effective_device class PIDINetDetector: @@ -45,7 +46,7 @@ def run( ) -> Image.Image: """Processes an image and returns the detected edges.""" - device = next(iter(self.model.parameters())).device + device = get_effective_device(self.model) np_img = pil_to_np(image) np_img = normalize_image_channel_count(np_img) diff --git a/invokeai/backend/model_manager/load/model_cache/utils.py b/invokeai/backend/model_manager/load/model_cache/utils.py new file mode 100644 index 00000000000..2b581990c69 --- /dev/null +++ b/invokeai/backend/model_manager/load/model_cache/utils.py @@ -0,0 +1,20 @@ +import itertools + +import torch + + +def get_effective_device(model: torch.nn.Module) -> torch.device: + """A utility to infer the 'effective' device of a model. + + This utility handles the case where a model is partially loaded onto the GPU, so is safer than just calling: + `next(iter(model.parameters())).device`. + + In the worst case, this utility has to check all model parameters, so if you already know the intended model device, + then it is better to avoid calling this function. + """ + # If all parameters are on the CPU, return the CPU device. Otherwise, return the first non-CPU device. + for p in itertools.chain(model.parameters(), model.buffers()): + if p.device.type != "cpu": + return p.device + + return torch.device("cpu") From 6a9de1fcf3e1c50b5dbdc537934708b75994e2c8 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Mon, 6 Jan 2025 20:38:17 +0000 Subject: [PATCH 10/10] Change definition of VRAM in use for the ModelCache from sum of model weights to the total torch.cuda.memory_allocated(). --- .../load/model_cache/model_cache.py | 38 ++++++++++++------- 1 file changed, 24 insertions(+), 14 deletions(-) diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache.py b/invokeai/backend/model_manager/load/model_cache/model_cache.py index f1d3f8cf9ef..98462a54c53 100644 --- a/invokeai/backend/model_manager/load/model_cache/model_cache.py +++ b/invokeai/backend/model_manager/load/model_cache/model_cache.py @@ -247,7 +247,6 @@ def unlock(self, cache_entry: CacheRecord) -> None: def _load_locked_model(self, cache_entry: CacheRecord) -> None: """Helper function for self.lock(). Loads a locked model into VRAM.""" start_time = time.time() - vram_available = self._get_vram_available() # Calculate model_vram_needed, the amount of additional VRAM that will be used if we fully load the model into # VRAM. @@ -255,9 +254,7 @@ def _load_locked_model(self, cache_entry: CacheRecord) -> None: model_total_bytes = cache_entry.cached_model.total_bytes() model_vram_needed = model_total_bytes - model_cur_vram_bytes - # The amount of VRAM that must be freed to make room for model_vram_needed. - vram_bytes_to_free = max(0, model_vram_needed - vram_available) - + vram_available = self._get_vram_available() self._logger.debug( f"Before unloading: {self._get_vram_state_str(model_cur_vram_bytes, model_total_bytes, vram_available)}" ) @@ -266,7 +263,7 @@ def _load_locked_model(self, cache_entry: CacheRecord) -> None: # 1. If the model can fit entirely in VRAM, then make enough room for it to be loaded fully. # 2. If the model can't fit fully into VRAM, then unload all other models and load as much of the model as # possible. - vram_bytes_freed = self._offload_unlocked_models(vram_bytes_to_free) + vram_bytes_freed = self._offload_unlocked_models(model_vram_needed) self._logger.debug(f"Unloaded models (if necessary): vram_bytes_freed={(vram_bytes_freed/MB):.2f}MB") # Check the updated vram_available after offloading. @@ -278,7 +275,9 @@ def _load_locked_model(self, cache_entry: CacheRecord) -> None: # Move as much of the model as possible into VRAM. # For testing, only allow 10% of the model to be loaded into VRAM. # vram_available = int(model_vram_needed * 0.1) - model_bytes_loaded = self._move_model_to_vram(cache_entry, vram_available) + # We add 1 MB to the available VRAM to account for small errors in memory tracking (e.g. off-by-one). A fully + # loaded model is much faster than a 95% loaded model. + model_bytes_loaded = self._move_model_to_vram(cache_entry, vram_available + MB) model_cur_vram_bytes = cache_entry.cached_model.cur_vram_bytes() vram_available = self._get_vram_available() @@ -330,7 +329,14 @@ def _get_vram_available(self) -> int: def _get_vram_in_use(self) -> int: """Get the amount of VRAM currently in use by the cache.""" - return sum(ce.cached_model.cur_vram_bytes() for ce in self._cached_models.values()) + if self._execution_device.type == "cuda": + return torch.cuda.memory_allocated() + elif self._execution_device.type == "mps": + return torch.mps.current_allocated_memory() + else: + raise ValueError(f"Unsupported execution device type: {self._execution_device.type}") + # Alternative definition of VRAM in use: + # return sum(ce.cached_model.cur_vram_bytes() for ce in self._cached_models.values()) def _get_ram_available(self) -> int: """Get the amount of RAM available for the cache to use, while keeping memory pressure under control.""" @@ -357,24 +363,28 @@ def _get_vram_state_str(self, model_cur_vram_bytes: int, model_total_bytes: int, + f"vram_available={(vram_available/MB):.0f} MB, " ) - def _offload_unlocked_models(self, vram_bytes_to_free: int) -> int: - """Offload models from the execution_device until vram_bytes_to_free bytes are freed, or all models are + def _offload_unlocked_models(self, vram_bytes_required: int) -> int: + """Offload models from the execution_device until vram_bytes_required bytes are available, or all models are offloaded. Of course, locked models are not offloaded. Returns: - int: The number of bytes freed. + int: The number of bytes freed based on believed model sizes. The actual change in VRAM may be different. """ - self._logger.debug(f"Offloading unlocked models with goal of freeing {vram_bytes_to_free/MB:.2f}MB of VRAM.") + self._logger.debug( + f"Offloading unlocked models with goal of making room for {vram_bytes_required/MB:.2f}MB of VRAM." + ) vram_bytes_freed = 0 # TODO(ryand): Give more thought to the offloading policy used here. cache_entries_increasing_size = sorted(self._cached_models.values(), key=lambda x: x.cached_model.total_bytes()) for cache_entry in cache_entries_increasing_size: - if vram_bytes_freed >= vram_bytes_to_free: + # We do not fully trust the count of bytes freed, so we check again on each iteration. + vram_available = self._get_vram_available() + vram_bytes_to_free = vram_bytes_required - vram_available + if vram_bytes_to_free <= 0: break if cache_entry.is_locked: continue - - cache_entry_bytes_freed = self._move_model_to_ram(cache_entry, vram_bytes_to_free - vram_bytes_freed) + cache_entry_bytes_freed = self._move_model_to_ram(cache_entry, vram_bytes_to_free) if cache_entry_bytes_freed > 0: self._logger.debug( f"Unloaded {cache_entry.key} from VRAM to free {(cache_entry_bytes_freed/MB):.0f} MB."