# Partial Loading PR 3.5: Fix premature model drops from the RAM cache (#7522)

## Summary

This is an unplanned fix between PR3 and PR4 in the sequence of partial
loading (i.e. low-VRAM) PRs. This PR restores the 'Current Workaround'
documented in #7513. In
other words, to work around a flaw in the model cache API, this fix
allows models to be loaded into VRAM _even if_ they have been dropped
from the RAM cache.

This PR also adds an info-level log each time the workaround is hit. A future
PR (#7509) will eliminate the places in the application code that can trigger
this condition.
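
For context, the sketch below distills the pattern the fix relies on. These are simplified, hypothetical stand-ins rather than the actual InvokeAI classes: because `lock()`/`unlock()` receive the `CacheRecord` itself rather than a string key, they can still operate on an entry after it has been evicted from the cache's internal dictionary, logging the condition instead of raising a `KeyError`.

```python
import logging
from dataclasses import dataclass
from typing import Any, Dict

logger = logging.getLogger(__name__)


@dataclass
class CacheRecord:
    """Simplified stand-in for the real cache record."""

    key: str
    model: Any
    _lock_count: int = 0

    def lock(self) -> None:
        self._lock_count += 1

    def unlock(self) -> None:
        self._lock_count -= 1


class ModelCacheSketch:
    """Simplified stand-in for the model cache, showing only the lock/unlock guard."""

    def __init__(self) -> None:
        self._cached_models: Dict[str, CacheRecord] = {}

    def lock(self, cache_entry: CacheRecord) -> None:
        # The entry may already have been dropped from the RAM cache; log and
        # continue rather than raising a KeyError as a key-based lookup would.
        if cache_entry.key not in self._cached_models:
            logger.info("Locking %s after it was dropped from the RAM cache.", cache_entry.key)
        cache_entry.lock()

    def unlock(self, cache_entry: CacheRecord) -> None:
        if cache_entry.key not in self._cached_models:
            logger.info("Unlocking %s after it was dropped from the RAM cache.", cache_entry.key)
        cache_entry.unlock()
```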

## Related Issues / Discussions

- #7492 
- #7494
- #7500 
- #7513

## QA Instructions

- Set the RAM cache limit to a small value, e.g. `ram: 4` (see the config sketch after this list).
- Run FLUX text-to-image with the full T5 encoder, which exceeds 4GB. This will
trigger the error condition.
- Before the fix, this test configuration caused a `KeyError`. After the fix, we
should see an info-level log explaining that the condition was hit, and
generation should continue successfully.
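
For reference, a minimal `invokeai.yaml` excerpt for the first step above. This is a sketch; the exact layout depends on your config schema version, and the `ram` value is the cache size in GB.

```yaml
# invokeai.yaml (excerpt): shrink the RAM model cache to roughly 4 GB to force evictions
ram: 4
```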

## Merge Plan

No special instructions.

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a
changelog_
- [x] _Tests added / updated (if applicable)_
- [x] _Documentation added / updated (if applicable)_
- [ ] _Updated `What's New` copy (if doing a release after this PR)_
RyanJDick authored Jan 7, 2025
2 parents f4f7415 + c579a21 commit 782ee7a
Showing 2 changed files with 20 additions and 8 deletions.

`invokeai/backend/model_manager/load/load_base.py` (4 additions, 4 deletions)

```diff
@@ -57,20 +57,20 @@ def __init__(self, cache_record: CacheRecord, cache: ModelCache):
         self._cache = cache
 
     def __enter__(self) -> AnyModel:
-        self._cache.lock(self._cache_record.key)
+        self._cache.lock(self._cache_record)
         return self.model
 
     def __exit__(self, *args: Any, **kwargs: Any) -> None:
-        self._cache.unlock(self._cache_record.key)
+        self._cache.unlock(self._cache_record)
 
     @contextmanager
     def model_on_device(self) -> Generator[Tuple[Optional[Dict[str, torch.Tensor]], AnyModel], None, None]:
         """Return a tuple consisting of the model's state dict (if it exists) and the locked model on execution device."""
-        self._cache.lock(self._cache_record.key)
+        self._cache.lock(self._cache_record)
         try:
             yield (self._cache_record.state_dict, self._cache_record.model)
         finally:
-            self._cache.unlock(self._cache_record.key)
+            self._cache.unlock(self._cache_record)
 
     @property
     def model(self) -> AnyModel:
```

`invokeai/backend/model_manager/load/model_cache/model_cache.py` (16 additions, 4 deletions)

```diff
@@ -194,9 +194,15 @@ def get(
 
         return cache_entry
 
-    def lock(self, key: str) -> None:
+    def lock(self, cache_entry: CacheRecord) -> None:
         """Lock a model for use and move it into VRAM."""
-        cache_entry = self._cached_models[key]
+        if cache_entry.key not in self._cached_models:
+            self._logger.info(
+                f"Locking model cache entry {cache_entry.key} ({cache_entry.model.__class__.__name__}), but it has "
+                "already been dropped from the RAM cache. This is a sign that the model loading order is non-optimal "
+                "in the invocation code (See https://github.com/invoke-ai/InvokeAI/issues/7513)."
+            )
+        # cache_entry = self._cached_models[key]
         cache_entry.lock()
 
         try:
@@ -214,9 +220,15 @@ def lock(self, key: str) -> None:
             cache_entry.unlock()
             raise
 
-    def unlock(self, key: str) -> None:
+    def unlock(self, cache_entry: CacheRecord) -> None:
         """Unlock a model."""
-        cache_entry = self._cached_models[key]
+        if cache_entry.key not in self._cached_models:
+            self._logger.info(
+                f"Unlocking model cache entry {cache_entry.key} ({cache_entry.model.__class__.__name__}), but it has "
+                "already been dropped from the RAM cache. This is a sign that the model loading order is non-optimal "
+                "in the invocation code (See https://github.com/invoke-ai/InvokeAI/issues/7513)."
+            )
+        # cache_entry = self._cached_models[key]
         cache_entry.unlock()
         if not self._lazy_offloading:
             self._offload_unlocked_models(0)
```
