Allow models to be locked in VRAM, even if they have been dropped from the RAM cache (related: #7513).
RyanJDick committed Jan 6, 2025
1 parent f4f7415 commit c579a21
Showing 2 changed files with 20 additions and 8 deletions.
8 changes: 4 additions & 4 deletions invokeai/backend/model_manager/load/load_base.py
@@ -57,20 +57,20 @@ def __init__(self, cache_record: CacheRecord, cache: ModelCache):
         self._cache = cache
 
     def __enter__(self) -> AnyModel:
-        self._cache.lock(self._cache_record.key)
+        self._cache.lock(self._cache_record)
         return self.model
 
     def __exit__(self, *args: Any, **kwargs: Any) -> None:
-        self._cache.unlock(self._cache_record.key)
+        self._cache.unlock(self._cache_record)
 
     @contextmanager
     def model_on_device(self) -> Generator[Tuple[Optional[Dict[str, torch.Tensor]], AnyModel], None, None]:
         """Return a tuple consisting of the model's state dict (if it exists) and the locked model on execution device."""
-        self._cache.lock(self._cache_record.key)
+        self._cache.lock(self._cache_record)
         try:
            yield (self._cache_record.state_dict, self._cache_record.model)
         finally:
-            self._cache.unlock(self._cache_record.key)
+            self._cache.unlock(self._cache_record)
 
     @property
     def model(self) -> AnyModel:
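For context, a minimal caller-side sketch of how these context managers are typically used (illustrative only; the function names, `loaded_model`, and `latents` are assumed here, not part of this commit). After this change, both `__enter__` and `model_on_device()` pass the `CacheRecord` object itself to the cache rather than re-looking it up by key:

def denoise(loaded_model, latents):
    # __enter__ locks the record and moves the model into VRAM;
    # __exit__ unlocks it again.
    with loaded_model as model:
        return model(latents)

def denoise_with_state_dict(loaded_model, latents):
    # Yields (state_dict or None, model on the execution device).
    with loaded_model.model_on_device() as (state_dict, model):
        return model(latents)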
20 changes: 16 additions & 4 deletions invokeai/backend/model_manager/load/model_cache/model_cache.py
@@ -194,9 +194,15 @@ def get(
 
         return cache_entry
 
-    def lock(self, key: str) -> None:
+    def lock(self, cache_entry: CacheRecord) -> None:
         """Lock a model for use and move it into VRAM."""
-        cache_entry = self._cached_models[key]
+        if cache_entry.key not in self._cached_models:
+            self._logger.info(
+                f"Locking model cache entry {cache_entry.key} ({cache_entry.model.__class__.__name__}), but it has "
+                "already been dropped from the RAM cache. This is a sign that the model loading order is non-optimal "
+                "in the invocation code (See https://github.com/invoke-ai/InvokeAI/issues/7513)."
+            )
+        # cache_entry = self._cached_models[key]
         cache_entry.lock()
 
         try:
@@ -214,9 +220,15 @@ def lock(self, key: str) -> None:
             cache_entry.unlock()
             raise
 
-    def unlock(self, key: str) -> None:
+    def unlock(self, cache_entry: CacheRecord) -> None:
         """Unlock a model."""
-        cache_entry = self._cached_models[key]
+        if cache_entry.key not in self._cached_models:
+            self._logger.info(
+                f"Unlocking model cache entry {cache_entry.key} ({cache_entry.model.__class__.__name__}), but it has "
+                "already been dropped from the RAM cache. This is a sign that the model loading order is non-optimal "
+                "in the invocation code (See https://github.com/invoke-ai/InvokeAI/issues/7513)."
+            )
+        # cache_entry = self._cached_models[key]
         cache_entry.unlock()
         if not self._lazy_offloading:
             self._offload_unlocked_models(0)
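To illustrate the behavioral change, here is a self-contained toy sketch (not InvokeAI code; all names are invented) of the pattern this commit adopts: because lock/unlock now receive the record object instead of a key, they no longer raise a KeyError when the entry has already been evicted from the RAM cache, and can log and proceed instead:

from dataclasses import dataclass

@dataclass
class ToyRecord:
    key: str
    locks: int = 0

    def lock(self) -> None:
        self.locks += 1

    def unlock(self) -> None:
        self.locks -= 1

class ToyCache:
    def __init__(self) -> None:
        self._cached_models: dict[str, ToyRecord] = {}

    def put(self, record: ToyRecord) -> None:
        self._cached_models[record.key] = record

    def drop(self, key: str) -> None:
        # Simulates eviction from the RAM cache.
        del self._cached_models[key]

    def lock(self, record: ToyRecord) -> None:
        if record.key not in self._cached_models:
            print(f"note: {record.key} was already dropped from the RAM cache")
        record.lock()  # operates on the record directly; no dict lookup needed

    def unlock(self, record: ToyRecord) -> None:
        if record.key not in self._cached_models:
            print(f"note: {record.key} was already dropped from the RAM cache")
        record.unlock()

cache = ToyCache()
record = ToyRecord(key="unet")
cache.put(record)
cache.drop("unet")    # evicted from the cache dict...
cache.lock(record)    # ...but locking still succeeds, with a logged note
cache.unlock(record)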
