Memory optimization to load state dicts one module at a time in CachedModelWithPartialLoad when we are not storing a CPU copy of the state dict (i.e. when keep_ram_copy_of_weights=False).
RyanJDick committed Jan 16, 2025
1 parent 36a3869 commit da589b3
Showing 1 changed file with 137 additions and 12 deletions.
@@ -37,6 +37,7 @@ def __init__(self, model: torch.nn.Module, compute_device: torch.device, keep_ra
         self._keys_in_modules_that_do_not_support_autocast = self._find_keys_in_modules_that_do_not_support_autocast(
             model_state_dict
         )
+        self._state_dict_keys_by_module_prefix = self._group_state_dict_keys_by_module_prefix(model_state_dict)

     def _find_modules_that_support_autocast(self) -> dict[str, torch.nn.Module]:
         """Find all modules that support autocasting."""
@@ -52,6 +53,47 @@ def _find_keys_in_modules_that_do_not_support_autocast(self, state_dict: dict[st
                 keys_in_modules_that_do_not_support_autocast.add(key)
         return keys_in_modules_that_do_not_support_autocast

+    def _group_state_dict_keys_by_module_prefix(self, state_dict: dict[str, torch.Tensor]) -> dict[str, list[str]]:
+        """A helper function that groups state dict keys by module prefix.
+
+        Example:
+        ```
+        state_dict = {
+            "weight": ...,
+            "module.submodule.weight": ...,
+            "module.submodule.bias": ...,
+            "module.other_submodule.weight": ...,
+            "module.other_submodule.bias": ...,
+        }
+
+        output = group_state_dict_keys_by_module_prefix(state_dict)
+
+        # The output will be:
+        output = {
+            "": [
+                "weight",
+            ],
+            "module.submodule": [
+                "module.submodule.weight",
+                "module.submodule.bias",
+            ],
+            "module.other_submodule": [
+                "module.other_submodule.weight",
+                "module.other_submodule.bias",
+            ],
+        }
+        ```
+        """
+        state_dict_keys_by_module_prefix: dict[str, list[str]] = {}
+        for key in state_dict.keys():
+            split = key.rsplit(".", 1)
+            # `split` will have length 1 if the root module has parameters.
+            module_name = split[0] if len(split) > 1 else ""
+            if module_name not in state_dict_keys_by_module_prefix:
+                state_dict_keys_by_module_prefix[module_name] = []
+            state_dict_keys_by_module_prefix[module_name].append(key)
+        return state_dict_keys_by_module_prefix
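For concreteness, here is a standalone sketch of the same grouping logic applied to a toy state dict (the key names are illustrative, not taken from a real model):

```
import torch

# Toy state dict: one root-level parameter plus one submodule's parameters.
state_dict = {
    "scale": torch.tensor(1.0),
    "encoder.linear.weight": torch.zeros(4, 4),
    "encoder.linear.bias": torch.zeros(4),
}

groups: dict[str, list[str]] = {}
for key in state_dict.keys():
    split = key.rsplit(".", 1)
    # A key with no "." belongs to the root module, whose prefix is "".
    module_name = split[0] if len(split) > 1 else ""
    groups.setdefault(module_name, []).append(key)

assert groups == {
    "": ["scale"],
    "encoder.linear": ["encoder.linear.weight", "encoder.linear.bias"],
}
```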

     def _move_non_persistent_buffers_to_device(self, device: torch.device):
         """Move the non-persistent buffers to the target device. These buffers are not included in the state dict,
         so we need to move them manually.
@@ -102,6 +144,82 @@ def full_unload_from_vram(self) -> int:
         """Unload all weights from VRAM."""
         return self.partial_unload_from_vram(self.total_bytes())

+    def _load_state_dict_with_device_conversion(
+        self, state_dict: dict[str, torch.Tensor], keys_to_convert: set[str], target_device: torch.device
+    ):
+        if self._cpu_state_dict is not None:
+            # Run the fast version.
+            self._load_state_dict_with_fast_device_conversion(
+                state_dict=state_dict,
+                keys_to_convert=keys_to_convert,
+                target_device=target_device,
+                cpu_state_dict=self._cpu_state_dict,
+            )
+        else:
+            # Run the low-virtual-memory version.
+            self._load_state_dict_with_jit_device_conversion(
+                state_dict=state_dict,
+                keys_to_convert=keys_to_convert,
+                target_device=target_device,
+            )
+
+    def _load_state_dict_with_jit_device_conversion(
+        self,
+        state_dict: dict[str, torch.Tensor],
+        keys_to_convert: set[str],
+        target_device: torch.device,
+    ):
+        """A custom state dict loading implementation with good peak memory properties.
+
+        This implementation has the important property that it copies parameters to the target device one module at a time
+        rather than applying all of the device conversions and then calling load_state_dict(). This is done to minimize the
+        peak virtual memory usage. Specifically, we want to avoid a case where we hold references to all of the CPU weights
+        and CUDA weights simultaneously, because Windows will reserve virtual memory for both.
+        """
+        for module_name, module in self._model.named_modules():
+            module_keys = self._state_dict_keys_by_module_prefix.get(module_name, [])
+            # Calculate the length of the module name prefix.
+            prefix_len = len(module_name)
+            if prefix_len > 0:
+                prefix_len += 1
+
+            module_state_dict = {}
+            for key in module_keys:
+                if key in keys_to_convert:
+                    # It is important that we overwrite `state_dict[key]` to avoid keeping two copies of the same
+                    # parameter.
+                    state_dict[key] = state_dict[key].to(target_device)
+                # Note that we keep parameters that have not been moved to a new device in case the module implements
+                # weird custom state dict loading logic that requires all parameters to be present.
+                module_state_dict[key[prefix_len:]] = state_dict[key]
+
+            if len(module_state_dict) > 0:
+                # We set strict=False, because if `module` has both parameters and child modules, then we are loading a
+                # state dict that only contains the parameters of `module` (not its children).
+                # We assume that it is rare for non-leaf modules to have parameters. Calling load_state_dict() on non-leaf
+                # modules will recurse through all of the children, so is a bit wasteful.
+                incompatible_keys = module.load_state_dict(module_state_dict, strict=False, assign=True)
+                # Missing keys are ok, unexpected keys are not.
+                assert len(incompatible_keys.unexpected_keys) == 0
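To make the peak-memory property concrete, here is a standalone sketch (not from the diff; the model and all names are illustrative, and `target_device` stands in for the compute device) contrasting the naive approach with the module-at-a-time approach:

```
import torch

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 8))
target_device = torch.device("cpu")  # stand-in for the compute device

# Naive approach: at the peak, `converted` holds new copies of all parameters
# while the model still references the old ones, so both sets stay allocated.
converted = {k: v.to(target_device, copy=True) for k, v in model.state_dict().items()}
model.load_state_dict(converted, assign=True)

# Module-at-a-time: overwrite each state dict entry as it is converted and load
# it into its owning module immediately, so at most one module's parameters are
# duplicated at any moment.
state_dict = model.state_dict()
for module_name, module in model.named_modules():
    prefix = module_name + "." if module_name else ""
    module_sd = {}
    for key in state_dict:
        # Direct parameters of this module have no "." after the prefix.
        if key.startswith(prefix) and "." not in key[len(prefix):]:
            state_dict[key] = state_dict[key].to(target_device, copy=True)
            module_sd[key[len(prefix):]] = state_dict[key]
    if module_sd:
        incompatible = module.load_state_dict(module_sd, strict=False, assign=True)
        assert not incompatible.unexpected_keys
```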

+    def _load_state_dict_with_fast_device_conversion(
+        self,
+        state_dict: dict[str, torch.Tensor],
+        keys_to_convert: set[str],
+        target_device: torch.device,
+        cpu_state_dict: dict[str, torch.Tensor],
+    ):
+        """Convert parameters to the target device and load them into the model. Leverages the `cpu_state_dict` to speed
+        up transfers of weights to the CPU.
+        """
+        for key in keys_to_convert:
+            if target_device.type == "cpu":
+                state_dict[key] = cpu_state_dict[key]
+            else:
+                state_dict[key] = state_dict[key].to(target_device)
+
+        self._model.load_state_dict(state_dict, assign=True)
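The CPU-offload direction is where the retained CPU copy helps: moving a weight back to the CPU becomes a dictionary lookup rather than a device-to-host copy plus a fresh allocation. A minimal standalone sketch of that idea (names illustrative):

```
import torch

# Retained CPU copy (present when keep_ram_copy_of_weights=True).
cpu_state_dict = {"w": torch.ones(4)}
# The live entry; imagine this copy residing on the compute device.
state_dict = {"w": cpu_state_dict["w"].clone()}

# Offloading to CPU is just re-pointing at the retained tensor: no new
# allocation and no device-to-host transfer.
state_dict["w"] = cpu_state_dict["w"]
assert state_dict["w"] is cpu_state_dict["w"]
```

Loading toward the compute device still needs a real `.to(target_device)` copy, which is why only the CPU direction gets the shortcut.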

     @torch.no_grad()
     def partial_load_to_vram(self, vram_bytes_to_load: int) -> int:
         """Load more weights into VRAM without exceeding vram_bytes_to_load.
@@ -116,26 +234,33 @@ def partial_load_to_vram(self, vram_bytes_to_load: int) -> int:

         cur_state_dict = self._model.state_dict()

+        # Identify the keys that will be loaded into VRAM.
+        keys_to_load: set[str] = set()
+
         # First, process the keys that *must* be loaded into VRAM.
         for key in self._keys_in_modules_that_do_not_support_autocast:
             param = cur_state_dict[key]
             if param.device.type == self._compute_device.type:
                 continue

+            keys_to_load.add(key)
             param_size = self._state_dict_bytes[key]
-            cur_state_dict[key] = param.to(self._compute_device, copy=True)
             vram_bytes_loaded += param_size

         if vram_bytes_loaded > vram_bytes_to_load:
             logger = InvokeAILogger.get_logger()
             logger.warning(
-                f"Loaded {vram_bytes_loaded / 2**20} MB into VRAM, but only {vram_bytes_to_load / 2**20} MB were "
+                f"Loading {vram_bytes_loaded / 2**20} MB into VRAM, but only {vram_bytes_to_load / 2**20} MB were "
                 "requested. This is the minimum set of weights in VRAM required to run the model."
             )

         # Next, process the keys that can optionally be loaded into VRAM.
         fully_loaded = True
         for key, param in cur_state_dict.items():
+            # Skip the keys that have already been processed above.
+            if key in keys_to_load:
+                continue
+
             if param.device.type == self._compute_device.type:
                 continue

@@ -146,14 +271,14 @@ def partial_load_to_vram(self, vram_bytes_to_load: int) -> int:
                 fully_loaded = False
                 continue

-            cur_state_dict[key] = param.to(self._compute_device, copy=True)
+            keys_to_load.add(key)
             vram_bytes_loaded += param_size

-        if vram_bytes_loaded > 0:
+        if len(keys_to_load) > 0:
             # We load the entire state dict, not just the parameters that changed, in case there are modules that
             # override _load_from_state_dict() and do some funky stuff that requires the entire state dict.
             # Alternatively, in the future, grouping parameters by module could probably solve this problem.
-            self._model.load_state_dict(cur_state_dict, assign=True)
+            self._load_state_dict_with_device_conversion(cur_state_dict, keys_to_load, self._compute_device)

         if self._cur_vram_bytes is not None:
             self._cur_vram_bytes += vram_bytes_loaded
@@ -184,6 +309,10 @@ def partial_unload_from_vram(self, vram_bytes_to_free: int, keep_required_weight

         offload_device = "cpu"
         cur_state_dict = self._model.state_dict()
+
+        # Identify the keys that will be offloaded to CPU.
+        keys_to_offload: set[str] = set()
+
         for key, param in cur_state_dict.items():
             if vram_bytes_freed >= vram_bytes_to_free:
                 break
@@ -195,15 +324,11 @@ def partial_unload_from_vram(self, vram_bytes_to_free: int, keep_required_weight
                 required_weights_in_vram += self._state_dict_bytes[key]
                 continue

-            if self._cpu_state_dict is not None:
-                cur_state_dict[key] = self._cpu_state_dict[key]
-            else:
-                cur_state_dict[key] = param.to("cpu")
-
+            keys_to_offload.add(key)
             vram_bytes_freed += self._state_dict_bytes[key]

-        if vram_bytes_freed > 0:
-            self._model.load_state_dict(cur_state_dict, assign=True)
+        if len(keys_to_offload) > 0:
+            self._load_state_dict_with_device_conversion(cur_state_dict, keys_to_offload, torch.device("cpu"))

         if self._cur_vram_bytes is not None:
             self._cur_vram_bytes -= vram_bytes_freed
