Skip to content

Commit

Permalink
Add keep_ram_copy_of_weights config option (#7565)
Browse files Browse the repository at this point in the history
## Summary

This PR adds a `keep_ram_copy_of_weights` config option the default (and
legacy) behavior is `true`. The tradeoffs for this setting are as
follows:
- `keep_ram_copy_of_weights: true`: Faster model switching and LoRA
patching.
- `keep_ram_copy_of_weights: false`: Lower average RAM load (may not
help significantly with peak RAM).

## Related Issues / Discussions

- Helps with #7563
- The Low-VRAM docs are updated to include this feature in
#7566

## QA Instructions

- Test with `enable_partial_load: false` and `keep_ram_copy_of_weights:
false`.
  - [x] RAM usage when model is loaded is reduced.
  - [x] Model loading / unloading works as expected.
  - [x] LoRA patching still works.
- Test with `enable_partial_load: false` and `keep_ram_copy_of_weights:
true`.
  - [x] Behavior should be unchanged.
- Test with `enable_partial_load: true` and `keep_ram_copy_of_weights:
false`.
  - [x] RAM usage when model is loaded is reduced.
  - [x] Model loading / unloading works as expected.
  - [x] LoRA patching still works.
- Test with `enable_partial_load: true` and `keep_ram_copy_of_weights:
true`.
  - [x] Behavior should be unchanged.

- [x] Smoke test CPU-only and MPS with default configs.

## Merge Plan

- [x] Merge #7564 first and
change target branch.

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a
changelog_
- [x] _Tests added / updated (if applicable)_
- [ ] _Documentation added / updated (if applicable)_
- [ ] _Updated `What's New` copy (if doing a release after this PR)_
  • Loading branch information
RyanJDick authored Jan 17, 2025
2 parents 0abb5ea + e5e848d commit f7511bf
Show file tree
Hide file tree
Showing 9 changed files with 260 additions and 53 deletions.
2 changes: 2 additions & 0 deletions invokeai/app/services/config/config_default.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ class InvokeAIAppConfig(BaseSettings):
log_memory_usage: If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.
device_working_mem_gb: The amount of working memory to keep available on the compute device (in GB). Has no effect if running on CPU. If you are experiencing OOM errors, try increasing this value.
enable_partial_loading: Enable partial loading of models. This enables models to run with reduced VRAM requirements (at the cost of slower speed) by streaming the model from RAM to VRAM as its used. In some edge cases, partial loading can cause models to run more slowly if they were previously being fully loaded into VRAM.
keep_ram_copy_of_weights: Whether to keep a full RAM copy of a model's weights when the model is loaded in VRAM. Keeping a RAM copy increases average RAM usage, but speeds up model switching and LoRA patching (assuming there is sufficient RAM). Set this to False if RAM pressure is consistently high.
ram: DEPRECATED: This setting is no longer used. It has been replaced by `max_cache_ram_gb`, but most users will not need to use this config since automatic cache size limits should work well in most cases. This config setting will be removed once the new model cache behavior is stable.
vram: DEPRECATED: This setting is no longer used. It has been replaced by `max_cache_vram_gb`, but most users will not need to use this config since automatic cache size limits should work well in most cases. This config setting will be removed once the new model cache behavior is stable.
lazy_offload: DEPRECATED: This setting is no longer used. Lazy-offloading is enabled by default. This config setting will be removed once the new model cache behavior is stable.
Expand Down Expand Up @@ -162,6 +163,7 @@ class InvokeAIAppConfig(BaseSettings):
log_memory_usage: bool = Field(default=False, description="If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.")
device_working_mem_gb: float = Field(default=3, description="The amount of working memory to keep available on the compute device (in GB). Has no effect if running on CPU. If you are experiencing OOM errors, try increasing this value.")
enable_partial_loading: bool = Field(default=False, description="Enable partial loading of models. This enables models to run with reduced VRAM requirements (at the cost of slower speed) by streaming the model from RAM to VRAM as its used. In some edge cases, partial loading can cause models to run more slowly if they were previously being fully loaded into VRAM.")
keep_ram_copy_of_weights: bool = Field(default=True, description="Whether to keep a full RAM copy of a model's weights when the model is loaded in VRAM. Keeping a RAM copy increases average RAM usage, but speeds up model switching and LoRA patching (assuming there is sufficient RAM). Set this to False if RAM pressure is consistently high.")
# Deprecated CACHE configs
ram: Optional[float] = Field(default=None, gt=0, description="DEPRECATED: This setting is no longer used. It has been replaced by `max_cache_ram_gb`, but most users will not need to use this config since automatic cache size limits should work well in most cases. This config setting will be removed once the new model cache behavior is stable.")
vram: Optional[float] = Field(default=None, ge=0, description="DEPRECATED: This setting is no longer used. It has been replaced by `max_cache_vram_gb`, but most users will not need to use this config since automatic cache size limits should work well in most cases. This config setting will be removed once the new model cache behavior is stable.")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ def build_model_manager(
ram_cache = ModelCache(
execution_device_working_mem_gb=app_config.device_working_mem_gb,
enable_partial_loading=app_config.enable_partial_loading,
keep_ram_copy_of_weights=app_config.keep_ram_copy_of_weights,
max_ram_cache_size_gb=app_config.max_cache_ram_gb,
max_vram_cache_size_gb=app_config.max_cache_vram_gb,
execution_device=execution_device or TorchDevice.choose_torch_device(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,17 @@ class CachedModelOnlyFullLoad:
MPS memory, etc.
"""

def __init__(self, model: torch.nn.Module | Any, compute_device: torch.device, total_bytes: int):
def __init__(
self, model: torch.nn.Module | Any, compute_device: torch.device, total_bytes: int, keep_ram_copy: bool = False
):
"""Initialize a CachedModelOnlyFullLoad.
Args:
model (torch.nn.Module | Any): The model to wrap. Should be on the CPU.
compute_device (torch.device): The compute device to move the model to.
total_bytes (int): The total size (in bytes) of all the weights in the model.
keep_ram_copy (bool): Whether to keep a read-only copy of the model's state dict in RAM. Keeping a RAM copy
increases RAM usage, but speeds up model offload from VRAM and LoRA patching (assuming there is
sufficient RAM).
"""
# model is often a torch.nn.Module, but could be any model type. Throughout this class, we handle both cases.
self._model = model
Expand All @@ -23,7 +28,7 @@ def __init__(self, model: torch.nn.Module | Any, compute_device: torch.device, t

# A CPU read-only copy of the model's state dict.
self._cpu_state_dict: dict[str, torch.Tensor] | None = None
if isinstance(model, torch.nn.Module):
if isinstance(model, torch.nn.Module) and keep_ram_copy:
self._cpu_state_dict = model.state_dict()

self._total_bytes = total_bytes
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,40 +14,86 @@ class CachedModelWithPartialLoad:
MPS memory, etc.
"""

def __init__(self, model: torch.nn.Module, compute_device: torch.device):
def __init__(self, model: torch.nn.Module, compute_device: torch.device, keep_ram_copy: bool = False):
self._model = model
self._compute_device = compute_device

# A CPU read-only copy of the model's state dict.
self._cpu_state_dict: dict[str, torch.Tensor] = model.state_dict()
model_state_dict = model.state_dict()
# A CPU read-only copy of the model's state dict. Used for faster model unloads from VRAM, and to speed up LoRA
# patching. Set to `None` if keep_ram_copy is False.
self._cpu_state_dict: dict[str, torch.Tensor] | None = model_state_dict if keep_ram_copy else None

# A dictionary of the size of each tensor in the state dict.
# HACK(ryand): We use this dictionary any time we are doing byte tracking calculations. We do this for
# consistency in case the application code has modified the model's size (e.g. by casting to a different
# precision). Of course, this means that we are making model cache load/unload decisions based on model size
# data that may not be fully accurate.
self._state_dict_bytes = {k: calc_tensor_size(v) for k, v in self._cpu_state_dict.items()}
self._state_dict_bytes = {k: calc_tensor_size(v) for k, v in model_state_dict.items()}

self._total_bytes = sum(self._state_dict_bytes.values())
self._cur_vram_bytes: int | None = None

self._modules_that_support_autocast = self._find_modules_that_support_autocast()
self._keys_in_modules_that_do_not_support_autocast = self._find_keys_in_modules_that_do_not_support_autocast()
self._keys_in_modules_that_do_not_support_autocast = self._find_keys_in_modules_that_do_not_support_autocast(
model_state_dict
)
self._state_dict_keys_by_module_prefix = self._group_state_dict_keys_by_module_prefix(model_state_dict)

def _find_modules_that_support_autocast(self) -> dict[str, torch.nn.Module]:
"""Find all modules that support autocasting."""
return {n: m for n, m in self._model.named_modules() if isinstance(m, CustomModuleMixin)} # type: ignore

def _find_keys_in_modules_that_do_not_support_autocast(self) -> set[str]:
def _find_keys_in_modules_that_do_not_support_autocast(self, state_dict: dict[str, torch.Tensor]) -> set[str]:
keys_in_modules_that_do_not_support_autocast: set[str] = set()
for key in self._cpu_state_dict.keys():
for key in state_dict.keys():
for module_name in self._modules_that_support_autocast.keys():
if key.startswith(module_name):
break
else:
keys_in_modules_that_do_not_support_autocast.add(key)
return keys_in_modules_that_do_not_support_autocast

def _group_state_dict_keys_by_module_prefix(self, state_dict: dict[str, torch.Tensor]) -> dict[str, list[str]]:
"""A helper function that groups state dict keys by module prefix.
Example:
```
state_dict = {
"weight": ...,
"module.submodule.weight": ...,
"module.submodule.bias": ...,
"module.other_submodule.weight": ...,
"module.other_submodule.bias": ...,
}
output = group_state_dict_keys_by_module_prefix(state_dict)
# The output will be:
output = {
"": [
"weight",
],
"module.submodule": [
"module.submodule.weight",
"module.submodule.bias",
],
"module.other_submodule": [
"module.other_submodule.weight",
"module.other_submodule.bias",
],
}
```
"""
state_dict_keys_by_module_prefix: dict[str, list[str]] = {}
for key in state_dict.keys():
split = key.rsplit(".", 1)
# `split` will have length 1 if the root module has parameters.
module_name = split[0] if len(split) > 1 else ""
if module_name not in state_dict_keys_by_module_prefix:
state_dict_keys_by_module_prefix[module_name] = []
state_dict_keys_by_module_prefix[module_name].append(key)
return state_dict_keys_by_module_prefix

def _move_non_persistent_buffers_to_device(self, device: torch.device):
"""Move the non-persistent buffers to the target device. These buffers are not included in the state dict,
so we need to move them manually.
Expand Down Expand Up @@ -98,6 +144,82 @@ def full_unload_from_vram(self) -> int:
"""Unload all weights from VRAM."""
return self.partial_unload_from_vram(self.total_bytes())

def _load_state_dict_with_device_conversion(
self, state_dict: dict[str, torch.Tensor], keys_to_convert: set[str], target_device: torch.device
):
if self._cpu_state_dict is not None:
# Run the fast version.
self._load_state_dict_with_fast_device_conversion(
state_dict=state_dict,
keys_to_convert=keys_to_convert,
target_device=target_device,
cpu_state_dict=self._cpu_state_dict,
)
else:
# Run the low-virtual-memory version.
self._load_state_dict_with_jit_device_conversion(
state_dict=state_dict,
keys_to_convert=keys_to_convert,
target_device=target_device,
)

def _load_state_dict_with_jit_device_conversion(
self,
state_dict: dict[str, torch.Tensor],
keys_to_convert: set[str],
target_device: torch.device,
):
"""A custom state dict loading implementation with good peak memory properties.
This implementation has the important property that it copies parameters to the target device one module at a time
rather than applying all of the device conversions and then calling load_state_dict(). This is done to minimize the
peak virtual memory usage. Specifically, we want to avoid a case where we hold references to all of the CPU weights
and CUDA weights simultaneously, because Windows will reserve virtual memory for both.
"""
for module_name, module in self._model.named_modules():
module_keys = self._state_dict_keys_by_module_prefix.get(module_name, [])
# Calculate the length of the module name prefix.
prefix_len = len(module_name)
if prefix_len > 0:
prefix_len += 1

module_state_dict = {}
for key in module_keys:
if key in keys_to_convert:
# It is important that we overwrite `state_dict[key]` to avoid keeping two copies of the same
# parameter.
state_dict[key] = state_dict[key].to(target_device)
# Note that we keep parameters that have not been moved to a new device in case the module implements
# weird custom state dict loading logic that requires all parameters to be present.
module_state_dict[key[prefix_len:]] = state_dict[key]

if len(module_state_dict) > 0:
# We set strict=False, because if `module` has both parameters and child modules, then we are loading a
# state dict that only contains the parameters of `module` (not its children).
# We assume that it is rare for non-leaf modules to have parameters. Calling load_state_dict() on non-leaf
# modules will recurse through all of the children, so is a bit wasteful.
incompatible_keys = module.load_state_dict(module_state_dict, strict=False, assign=True)
# Missing keys are ok, unexpected keys are not.
assert len(incompatible_keys.unexpected_keys) == 0

def _load_state_dict_with_fast_device_conversion(
self,
state_dict: dict[str, torch.Tensor],
keys_to_convert: set[str],
target_device: torch.device,
cpu_state_dict: dict[str, torch.Tensor],
):
"""Convert parameters to the target device and load them into the model. Leverages the `cpu_state_dict` to speed
up transfers of weights to the CPU.
"""
for key in keys_to_convert:
if target_device.type == "cpu":
state_dict[key] = cpu_state_dict[key]
else:
state_dict[key] = state_dict[key].to(target_device)

self._model.load_state_dict(state_dict, assign=True)

@torch.no_grad()
def partial_load_to_vram(self, vram_bytes_to_load: int) -> int:
"""Load more weights into VRAM without exceeding vram_bytes_to_load.
Expand All @@ -112,26 +234,33 @@ def partial_load_to_vram(self, vram_bytes_to_load: int) -> int:

cur_state_dict = self._model.state_dict()

# Identify the keys that will be loaded into VRAM.
keys_to_load: set[str] = set()

# First, process the keys that *must* be loaded into VRAM.
for key in self._keys_in_modules_that_do_not_support_autocast:
param = cur_state_dict[key]
if param.device.type == self._compute_device.type:
continue

keys_to_load.add(key)
param_size = self._state_dict_bytes[key]
cur_state_dict[key] = param.to(self._compute_device, copy=True)
vram_bytes_loaded += param_size

if vram_bytes_loaded > vram_bytes_to_load:
logger = InvokeAILogger.get_logger()
logger.warning(
f"Loaded {vram_bytes_loaded / 2**20} MB into VRAM, but only {vram_bytes_to_load / 2**20} MB were "
f"Loading {vram_bytes_loaded / 2**20} MB into VRAM, but only {vram_bytes_to_load / 2**20} MB were "
"requested. This is the minimum set of weights in VRAM required to run the model."
)

# Next, process the keys that can optionally be loaded into VRAM.
fully_loaded = True
for key, param in cur_state_dict.items():
# Skip the keys that have already been processed above.
if key in keys_to_load:
continue

if param.device.type == self._compute_device.type:
continue

Expand All @@ -142,14 +271,14 @@ def partial_load_to_vram(self, vram_bytes_to_load: int) -> int:
fully_loaded = False
continue

cur_state_dict[key] = param.to(self._compute_device, copy=True)
keys_to_load.add(key)
vram_bytes_loaded += param_size

if vram_bytes_loaded > 0:
if len(keys_to_load) > 0:
# We load the entire state dict, not just the parameters that changed, in case there are modules that
# override _load_from_state_dict() and do some funky stuff that requires the entire state dict.
# Alternatively, in the future, grouping parameters by module could probably solve this problem.
self._model.load_state_dict(cur_state_dict, assign=True)
self._load_state_dict_with_device_conversion(cur_state_dict, keys_to_load, self._compute_device)

if self._cur_vram_bytes is not None:
self._cur_vram_bytes += vram_bytes_loaded
Expand Down Expand Up @@ -180,6 +309,10 @@ def partial_unload_from_vram(self, vram_bytes_to_free: int, keep_required_weight

offload_device = "cpu"
cur_state_dict = self._model.state_dict()

# Identify the keys that will be offloaded to CPU.
keys_to_offload: set[str] = set()

for key, param in cur_state_dict.items():
if vram_bytes_freed >= vram_bytes_to_free:
break
Expand All @@ -191,11 +324,11 @@ def partial_unload_from_vram(self, vram_bytes_to_free: int, keep_required_weight
required_weights_in_vram += self._state_dict_bytes[key]
continue

cur_state_dict[key] = self._cpu_state_dict[key]
keys_to_offload.add(key)
vram_bytes_freed += self._state_dict_bytes[key]

if vram_bytes_freed > 0:
self._model.load_state_dict(cur_state_dict, assign=True)
if len(keys_to_offload) > 0:
self._load_state_dict_with_device_conversion(cur_state_dict, keys_to_offload, torch.device("cpu"))

if self._cur_vram_bytes is not None:
self._cur_vram_bytes -= vram_bytes_freed
Expand Down
Loading

0 comments on commit f7511bf

Please sign in to comment.