From 04087c38cee2ab463f72e2a5f2710f3c68e4dd16 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Tue, 14 Jan 2025 16:09:35 +0000 Subject: [PATCH 1/5] Add keep_ram_copy option to CachedModelWithPartialLoad. --- .../cached_model_with_partial_load.py | 24 ++++-- .../load/model_cache/model_cache.py | 2 +- .../test_cached_model_with_partial_load.py | 77 ++++++++++++++----- 3 files changed, 73 insertions(+), 30 deletions(-) diff --git a/invokeai/backend/model_manager/load/model_cache/cached_model/cached_model_with_partial_load.py b/invokeai/backend/model_manager/load/model_cache/cached_model/cached_model_with_partial_load.py index 3c069c975d9..7eaced7396e 100644 --- a/invokeai/backend/model_manager/load/model_cache/cached_model/cached_model_with_partial_load.py +++ b/invokeai/backend/model_manager/load/model_cache/cached_model/cached_model_with_partial_load.py @@ -14,33 +14,37 @@ class CachedModelWithPartialLoad: MPS memory, etc. """ - def __init__(self, model: torch.nn.Module, compute_device: torch.device): + def __init__(self, model: torch.nn.Module, compute_device: torch.device, keep_ram_copy: bool = False): self._model = model self._compute_device = compute_device - # A CPU read-only copy of the model's state dict. - self._cpu_state_dict: dict[str, torch.Tensor] = model.state_dict() + model_state_dict = model.state_dict() + # A CPU read-only copy of the model's state dict. Used for faster model unloads from VRAM, and to speed up LoRA + # patching. Set to `None` if keep_ram_copy is False. + self._cpu_state_dict: dict[str, torch.Tensor] | None = model_state_dict if keep_ram_copy else None # A dictionary of the size of each tensor in the state dict. # HACK(ryand): We use this dictionary any time we are doing byte tracking calculations. We do this for # consistency in case the application code has modified the model's size (e.g. by casting to a different # precision). Of course, this means that we are making model cache load/unload decisions based on model size # data that may not be fully accurate. - self._state_dict_bytes = {k: calc_tensor_size(v) for k, v in self._cpu_state_dict.items()} + self._state_dict_bytes = {k: calc_tensor_size(v) for k, v in model_state_dict.items()} self._total_bytes = sum(self._state_dict_bytes.values()) self._cur_vram_bytes: int | None = None self._modules_that_support_autocast = self._find_modules_that_support_autocast() - self._keys_in_modules_that_do_not_support_autocast = self._find_keys_in_modules_that_do_not_support_autocast() + self._keys_in_modules_that_do_not_support_autocast = self._find_keys_in_modules_that_do_not_support_autocast( + model_state_dict + ) def _find_modules_that_support_autocast(self) -> dict[str, torch.nn.Module]: """Find all modules that support autocasting.""" return {n: m for n, m in self._model.named_modules() if isinstance(m, CustomModuleMixin)} # type: ignore - def _find_keys_in_modules_that_do_not_support_autocast(self) -> set[str]: + def _find_keys_in_modules_that_do_not_support_autocast(self, state_dict: dict[str, torch.Tensor]) -> set[str]: keys_in_modules_that_do_not_support_autocast: set[str] = set() - for key in self._cpu_state_dict.keys(): + for key in state_dict.keys(): for module_name in self._modules_that_support_autocast.keys(): if key.startswith(module_name): break @@ -191,7 +195,11 @@ def partial_unload_from_vram(self, vram_bytes_to_free: int, keep_required_weight required_weights_in_vram += self._state_dict_bytes[key] continue - cur_state_dict[key] = self._cpu_state_dict[key] + if self._cpu_state_dict is not None: + cur_state_dict[key] = self._cpu_state_dict[key] + else: + cur_state_dict[key] = param.to("cpu") + vram_bytes_freed += self._state_dict_bytes[key] if vram_bytes_freed > 0: diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache.py b/invokeai/backend/model_manager/load/model_cache/model_cache.py index bf51b974ce3..cf33cc9cfe9 100644 --- a/invokeai/backend/model_manager/load/model_cache/model_cache.py +++ b/invokeai/backend/model_manager/load/model_cache/model_cache.py @@ -154,7 +154,7 @@ def put(self, key: str, model: AnyModel) -> None: # Wrap model. if isinstance(model, torch.nn.Module) and running_with_cuda and self._enable_partial_loading: - wrapped_model = CachedModelWithPartialLoad(model, self._execution_device) + wrapped_model = CachedModelWithPartialLoad(model, self._execution_device, keep_ram_copy=False) else: wrapped_model = CachedModelOnlyFullLoad(model, self._execution_device, size) diff --git a/tests/backend/model_manager/load/model_cache/cached_model/test_cached_model_with_partial_load.py b/tests/backend/model_manager/load/model_cache/cached_model/test_cached_model_with_partial_load.py index a3a1537c3dd..3b70d6b71d1 100644 --- a/tests/backend/model_manager/load/model_cache/cached_model/test_cached_model_with_partial_load.py +++ b/tests/backend/model_manager/load/model_cache/cached_model/test_cached_model_with_partial_load.py @@ -20,9 +20,15 @@ def model(): return model +parameterize_keep_ram_copy = pytest.mark.parametrize("keep_ram_copy", [True, False]) + + @parameterize_mps_and_cuda -def test_cached_model_total_bytes(device: str, model: DummyModule): - cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device)) +@parameterize_keep_ram_copy +def test_cached_model_total_bytes(device: str, model: DummyModule, keep_ram_copy: bool): + cached_model = CachedModelWithPartialLoad( + model=model, compute_device=torch.device(device), keep_ram_copy=keep_ram_copy + ) linear1_numel = 10 * 32 + 32 linear2_numel = 32 * 64 + 64 buffer1_numel = 64 @@ -31,9 +37,12 @@ def test_cached_model_total_bytes(device: str, model: DummyModule): @parameterize_mps_and_cuda -def test_cached_model_cur_vram_bytes(device: str, model: DummyModule): +@parameterize_keep_ram_copy +def test_cached_model_cur_vram_bytes(device: str, model: DummyModule, keep_ram_copy: bool): # Model starts in CPU memory. - cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device)) + cached_model = CachedModelWithPartialLoad( + model=model, compute_device=torch.device(device), keep_ram_copy=keep_ram_copy + ) assert cached_model.cur_vram_bytes() == 0 # Full load the model into VRAM. @@ -45,9 +54,12 @@ def test_cached_model_cur_vram_bytes(device: str, model: DummyModule): @parameterize_mps_and_cuda -def test_cached_model_partial_load(device: str, model: DummyModule): +@parameterize_keep_ram_copy +def test_cached_model_partial_load(device: str, model: DummyModule, keep_ram_copy: bool): # Model starts in CPU memory. - cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device)) + cached_model = CachedModelWithPartialLoad( + model=model, compute_device=torch.device(device), keep_ram_copy=keep_ram_copy + ) model_total_bytes = cached_model.total_bytes() assert cached_model.cur_vram_bytes() == 0 @@ -71,9 +83,12 @@ def test_cached_model_partial_load(device: str, model: DummyModule): @parameterize_mps_and_cuda -def test_cached_model_partial_unload(device: str, model: DummyModule): +@parameterize_keep_ram_copy +def test_cached_model_partial_unload(device: str, model: DummyModule, keep_ram_copy: bool): # Model starts in CPU memory. - cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device)) + cached_model = CachedModelWithPartialLoad( + model=model, compute_device=torch.device(device), keep_ram_copy=keep_ram_copy + ) model_total_bytes = cached_model.total_bytes() assert cached_model.cur_vram_bytes() == 0 @@ -99,9 +114,14 @@ def test_cached_model_partial_unload(device: str, model: DummyModule): @parameterize_mps_and_cuda -def test_cached_model_partial_unload_keep_required_weights_in_vram(device: str, model: DummyModule): +@parameterize_keep_ram_copy +def test_cached_model_partial_unload_keep_required_weights_in_vram( + device: str, model: DummyModule, keep_ram_copy: bool +): # Model starts in CPU memory. - cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device)) + cached_model = CachedModelWithPartialLoad( + model=model, compute_device=torch.device(device), keep_ram_copy=keep_ram_copy + ) model_total_bytes = cached_model.total_bytes() assert cached_model.cur_vram_bytes() == 0 @@ -130,8 +150,11 @@ def test_cached_model_partial_unload_keep_required_weights_in_vram(device: str, @parameterize_mps_and_cuda -def test_cached_model_full_load_and_unload(device: str, model: DummyModule): - cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device)) +@parameterize_keep_ram_copy +def test_cached_model_full_load_and_unload(device: str, model: DummyModule, keep_ram_copy: bool): + cached_model = CachedModelWithPartialLoad( + model=model, compute_device=torch.device(device), keep_ram_copy=keep_ram_copy + ) # Model starts in CPU memory. model_total_bytes = cached_model.total_bytes() @@ -162,8 +185,11 @@ def test_cached_model_full_load_and_unload(device: str, model: DummyModule): @parameterize_mps_and_cuda -def test_cached_model_full_load_from_partial(device: str, model: DummyModule): - cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device)) +@parameterize_keep_ram_copy +def test_cached_model_full_load_from_partial(device: str, model: DummyModule, keep_ram_copy: bool): + cached_model = CachedModelWithPartialLoad( + model=model, compute_device=torch.device(device), keep_ram_copy=keep_ram_copy + ) # Model starts in CPU memory. model_total_bytes = cached_model.total_bytes() @@ -190,8 +216,11 @@ def test_cached_model_full_load_from_partial(device: str, model: DummyModule): @parameterize_mps_and_cuda -def test_cached_model_full_unload_from_partial(device: str, model: DummyModule): - cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device)) +@parameterize_keep_ram_copy +def test_cached_model_full_unload_from_partial(device: str, model: DummyModule, keep_ram_copy: bool): + cached_model = CachedModelWithPartialLoad( + model=model, compute_device=torch.device(device), keep_ram_copy=keep_ram_copy + ) # Model starts in CPU memory. model_total_bytes = cached_model.total_bytes() @@ -219,7 +248,7 @@ def test_cached_model_full_unload_from_partial(device: str, model: DummyModule): @parameterize_mps_and_cuda def test_cached_model_get_cpu_state_dict(device: str, model: DummyModule): - cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device)) + cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device), keep_ram_copy=True) # Model starts in CPU memory. assert cached_model.cur_vram_bytes() == 0 @@ -242,8 +271,11 @@ def test_cached_model_get_cpu_state_dict(device: str, model: DummyModule): @parameterize_mps_and_cuda -def test_cached_model_full_load_and_inference(device: str, model: DummyModule): - cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device)) +@parameterize_keep_ram_copy +def test_cached_model_full_load_and_inference(device: str, model: DummyModule, keep_ram_copy: bool): + cached_model = CachedModelWithPartialLoad( + model=model, compute_device=torch.device(device), keep_ram_copy=keep_ram_copy + ) # Model starts in CPU memory. model_total_bytes = cached_model.total_bytes() assert cached_model.cur_vram_bytes() == 0 @@ -269,9 +301,12 @@ def test_cached_model_full_load_and_inference(device: str, model: DummyModule): @parameterize_mps_and_cuda -def test_cached_model_partial_load_and_inference(device: str, model: DummyModule): +@parameterize_keep_ram_copy +def test_cached_model_partial_load_and_inference(device: str, model: DummyModule, keep_ram_copy: bool): # Model starts in CPU memory. - cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device)) + cached_model = CachedModelWithPartialLoad( + model=model, compute_device=torch.device(device), keep_ram_copy=keep_ram_copy + ) model_total_bytes = cached_model.total_bytes() assert cached_model.cur_vram_bytes() == 0 From c76d08d1fd05e1a55fa71946b12e3339c29dba63 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Thu, 16 Jan 2025 15:08:23 +0000 Subject: [PATCH 2/5] Add keep_ram_copy option to CachedModelOnlyFullLoad. --- .../cached_model_only_full_load.py | 9 +++- .../test_cached_model_only_full_load.py | 45 ++++++++++++++----- .../test_cached_model_with_partial_load.py | 9 ++-- .../load/model_cache/cached_model/utils.py | 2 + 4 files changed, 47 insertions(+), 18 deletions(-) diff --git a/invokeai/backend/model_manager/load/model_cache/cached_model/cached_model_only_full_load.py b/invokeai/backend/model_manager/load/model_cache/cached_model/cached_model_only_full_load.py index 719a559dd02..be398fa1295 100644 --- a/invokeai/backend/model_manager/load/model_cache/cached_model/cached_model_only_full_load.py +++ b/invokeai/backend/model_manager/load/model_cache/cached_model/cached_model_only_full_load.py @@ -9,12 +9,17 @@ class CachedModelOnlyFullLoad: MPS memory, etc. """ - def __init__(self, model: torch.nn.Module | Any, compute_device: torch.device, total_bytes: int): + def __init__( + self, model: torch.nn.Module | Any, compute_device: torch.device, total_bytes: int, keep_ram_copy: bool = False + ): """Initialize a CachedModelOnlyFullLoad. Args: model (torch.nn.Module | Any): The model to wrap. Should be on the CPU. compute_device (torch.device): The compute device to move the model to. total_bytes (int): The total size (in bytes) of all the weights in the model. + keep_ram_copy (bool): Whether to keep a read-only copy of the model's state dict in RAM. Keeping a RAM copy + increases RAM usage, but speeds up model offload from VRAM and LoRA patching (assuming there is + sufficient RAM). """ # model is often a torch.nn.Module, but could be any model type. Throughout this class, we handle both cases. self._model = model @@ -23,7 +28,7 @@ def __init__(self, model: torch.nn.Module | Any, compute_device: torch.device, t # A CPU read-only copy of the model's state dict. self._cpu_state_dict: dict[str, torch.Tensor] | None = None - if isinstance(model, torch.nn.Module): + if isinstance(model, torch.nn.Module) and keep_ram_copy: self._cpu_state_dict = model.state_dict() self._total_bytes = total_bytes diff --git a/tests/backend/model_manager/load/model_cache/cached_model/test_cached_model_only_full_load.py b/tests/backend/model_manager/load/model_cache/cached_model/test_cached_model_only_full_load.py index 76a3774288c..509d6494a92 100644 --- a/tests/backend/model_manager/load/model_cache/cached_model/test_cached_model_only_full_load.py +++ b/tests/backend/model_manager/load/model_cache/cached_model/test_cached_model_only_full_load.py @@ -3,7 +3,11 @@ from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_only_full_load import ( CachedModelOnlyFullLoad, ) -from tests.backend.model_manager.load.model_cache.cached_model.utils import DummyModule, parameterize_mps_and_cuda +from tests.backend.model_manager.load.model_cache.cached_model.utils import ( + DummyModule, + parameterize_keep_ram_copy, + parameterize_mps_and_cuda, +) class NonTorchModel: @@ -17,16 +21,22 @@ def run_inference(self, x: torch.Tensor) -> torch.Tensor: @parameterize_mps_and_cuda -def test_cached_model_total_bytes(device: str): +@parameterize_keep_ram_copy +def test_cached_model_total_bytes(device: str, keep_ram_copy: bool): model = DummyModule() - cached_model = CachedModelOnlyFullLoad(model=model, compute_device=torch.device(device), total_bytes=100) + cached_model = CachedModelOnlyFullLoad( + model=model, compute_device=torch.device(device), total_bytes=100, keep_ram_copy=keep_ram_copy + ) assert cached_model.total_bytes() == 100 @parameterize_mps_and_cuda -def test_cached_model_is_in_vram(device: str): +@parameterize_keep_ram_copy +def test_cached_model_is_in_vram(device: str, keep_ram_copy: bool): model = DummyModule() - cached_model = CachedModelOnlyFullLoad(model=model, compute_device=torch.device(device), total_bytes=100) + cached_model = CachedModelOnlyFullLoad( + model=model, compute_device=torch.device(device), total_bytes=100, keep_ram_copy=keep_ram_copy + ) assert not cached_model.is_in_vram() assert cached_model.cur_vram_bytes() == 0 @@ -40,9 +50,12 @@ def test_cached_model_is_in_vram(device: str): @parameterize_mps_and_cuda -def test_cached_model_full_load_and_unload(device: str): +@parameterize_keep_ram_copy +def test_cached_model_full_load_and_unload(device: str, keep_ram_copy: bool): model = DummyModule() - cached_model = CachedModelOnlyFullLoad(model=model, compute_device=torch.device(device), total_bytes=100) + cached_model = CachedModelOnlyFullLoad( + model=model, compute_device=torch.device(device), total_bytes=100, keep_ram_copy=keep_ram_copy + ) assert cached_model.full_load_to_vram() == 100 assert cached_model.is_in_vram() assert all(p.device.type == device for p in cached_model.model.parameters()) @@ -55,7 +68,9 @@ def test_cached_model_full_load_and_unload(device: str): @parameterize_mps_and_cuda def test_cached_model_get_cpu_state_dict(device: str): model = DummyModule() - cached_model = CachedModelOnlyFullLoad(model=model, compute_device=torch.device(device), total_bytes=100) + cached_model = CachedModelOnlyFullLoad( + model=model, compute_device=torch.device(device), total_bytes=100, keep_ram_copy=True + ) assert not cached_model.is_in_vram() # The CPU state dict can be accessed and has the expected properties. @@ -76,9 +91,12 @@ def test_cached_model_get_cpu_state_dict(device: str): @parameterize_mps_and_cuda -def test_cached_model_full_load_and_inference(device: str): +@parameterize_keep_ram_copy +def test_cached_model_full_load_and_inference(device: str, keep_ram_copy: bool): model = DummyModule() - cached_model = CachedModelOnlyFullLoad(model=model, compute_device=torch.device(device), total_bytes=100) + cached_model = CachedModelOnlyFullLoad( + model=model, compute_device=torch.device(device), total_bytes=100, keep_ram_copy=keep_ram_copy + ) assert not cached_model.is_in_vram() # Run inference on the CPU. @@ -99,9 +117,12 @@ def test_cached_model_full_load_and_inference(device: str): @parameterize_mps_and_cuda -def test_non_torch_model(device: str): +@parameterize_keep_ram_copy +def test_non_torch_model(device: str, keep_ram_copy: bool): model = NonTorchModel() - cached_model = CachedModelOnlyFullLoad(model=model, compute_device=torch.device(device), total_bytes=100) + cached_model = CachedModelOnlyFullLoad( + model=model, compute_device=torch.device(device), total_bytes=100, keep_ram_copy=keep_ram_copy + ) assert not cached_model.is_in_vram() # The model does not have a CPU state dict. diff --git a/tests/backend/model_manager/load/model_cache/cached_model/test_cached_model_with_partial_load.py b/tests/backend/model_manager/load/model_cache/cached_model/test_cached_model_with_partial_load.py index 3b70d6b71d1..7ae45ff2f53 100644 --- a/tests/backend/model_manager/load/model_cache/cached_model/test_cached_model_with_partial_load.py +++ b/tests/backend/model_manager/load/model_cache/cached_model/test_cached_model_with_partial_load.py @@ -10,7 +10,11 @@ apply_custom_layers_to_model, ) from invokeai.backend.util.calc_tensor_size import calc_tensor_size -from tests.backend.model_manager.load.model_cache.cached_model.utils import DummyModule, parameterize_mps_and_cuda +from tests.backend.model_manager.load.model_cache.cached_model.utils import ( + DummyModule, + parameterize_keep_ram_copy, + parameterize_mps_and_cuda, +) @pytest.fixture @@ -20,9 +24,6 @@ def model(): return model -parameterize_keep_ram_copy = pytest.mark.parametrize("keep_ram_copy", [True, False]) - - @parameterize_mps_and_cuda @parameterize_keep_ram_copy def test_cached_model_total_bytes(device: str, model: DummyModule, keep_ram_copy: bool): diff --git a/tests/backend/model_manager/load/model_cache/cached_model/utils.py b/tests/backend/model_manager/load/model_cache/cached_model/utils.py index 9554299e066..845d3c90abd 100644 --- a/tests/backend/model_manager/load/model_cache/cached_model/utils.py +++ b/tests/backend/model_manager/load/model_cache/cached_model/utils.py @@ -29,3 +29,5 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: pytest.param("cuda", marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available.")), ], ) + +parameterize_keep_ram_copy = pytest.mark.parametrize("keep_ram_copy", [True, False]) From 36a3869af03ae46cbf53a28064293a038933b389 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Thu, 16 Jan 2025 15:14:52 +0000 Subject: [PATCH 3/5] Add keep_ram_copy_of_weights config option. --- invokeai/app/services/config/config_default.py | 2 ++ .../services/model_manager/model_manager_default.py | 1 + .../model_manager/load/model_cache/model_cache.py | 10 ++++++++-- tests/backend/model_manager/model_manager_fixtures.py | 1 + 4 files changed, 12 insertions(+), 2 deletions(-) diff --git a/invokeai/app/services/config/config_default.py b/invokeai/app/services/config/config_default.py index 6d95f2f7fa9..c3d78bc52c6 100644 --- a/invokeai/app/services/config/config_default.py +++ b/invokeai/app/services/config/config_default.py @@ -87,6 +87,7 @@ class InvokeAIAppConfig(BaseSettings): log_memory_usage: If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour. device_working_mem_gb: The amount of working memory to keep available on the compute device (in GB). Has no effect if running on CPU. If you are experiencing OOM errors, try increasing this value. enable_partial_loading: Enable partial loading of models. This enables models to run with reduced VRAM requirements (at the cost of slower speed) by streaming the model from RAM to VRAM as its used. In some edge cases, partial loading can cause models to run more slowly if they were previously being fully loaded into VRAM. + keep_ram_copy_of_weights: Whether to keep a full RAM copy of a model's weights when the model is loaded in VRAM. Keeping a RAM copy increases RAM usage, but speeds up model switching and LoRA patching (assuming there is sufficient RAM). Set this to False if RAM pressure is consistently high, set it to True for improved speed if there is RAM to spare. ram: DEPRECATED: This setting is no longer used. It has been replaced by `max_cache_ram_gb`, but most users will not need to use this config since automatic cache size limits should work well in most cases. This config setting will be removed once the new model cache behavior is stable. vram: DEPRECATED: This setting is no longer used. It has been replaced by `max_cache_vram_gb`, but most users will not need to use this config since automatic cache size limits should work well in most cases. This config setting will be removed once the new model cache behavior is stable. lazy_offload: DEPRECATED: This setting is no longer used. Lazy-offloading is enabled by default. This config setting will be removed once the new model cache behavior is stable. @@ -162,6 +163,7 @@ class InvokeAIAppConfig(BaseSettings): log_memory_usage: bool = Field(default=False, description="If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.") device_working_mem_gb: float = Field(default=3, description="The amount of working memory to keep available on the compute device (in GB). Has no effect if running on CPU. If you are experiencing OOM errors, try increasing this value.") enable_partial_loading: bool = Field(default=False, description="Enable partial loading of models. This enables models to run with reduced VRAM requirements (at the cost of slower speed) by streaming the model from RAM to VRAM as its used. In some edge cases, partial loading can cause models to run more slowly if they were previously being fully loaded into VRAM.") + keep_ram_copy_of_weights: bool = Field(default=True, description="Whether to keep a full RAM copy of a model's weights when the model is loaded in VRAM. Keeping a RAM copy increases average RAM usage, but speeds up model switching and LoRA patching (assuming there is sufficient RAM). Set this to False if RAM pressure is consistently high.") # Deprecated CACHE configs ram: Optional[float] = Field(default=None, gt=0, description="DEPRECATED: This setting is no longer used. It has been replaced by `max_cache_ram_gb`, but most users will not need to use this config since automatic cache size limits should work well in most cases. This config setting will be removed once the new model cache behavior is stable.") vram: Optional[float] = Field(default=None, ge=0, description="DEPRECATED: This setting is no longer used. It has been replaced by `max_cache_vram_gb`, but most users will not need to use this config since automatic cache size limits should work well in most cases. This config setting will be removed once the new model cache behavior is stable.") diff --git a/invokeai/app/services/model_manager/model_manager_default.py b/invokeai/app/services/model_manager/model_manager_default.py index cec3b0bc18b..9ad10c5e737 100644 --- a/invokeai/app/services/model_manager/model_manager_default.py +++ b/invokeai/app/services/model_manager/model_manager_default.py @@ -84,6 +84,7 @@ def build_model_manager( ram_cache = ModelCache( execution_device_working_mem_gb=app_config.device_working_mem_gb, enable_partial_loading=app_config.enable_partial_loading, + keep_ram_copy_of_weights=app_config.keep_ram_copy_of_weights, max_ram_cache_size_gb=app_config.max_cache_ram_gb, max_vram_cache_size_gb=app_config.max_cache_vram_gb, execution_device=execution_device or TorchDevice.choose_torch_device(), diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache.py b/invokeai/backend/model_manager/load/model_cache/model_cache.py index cf33cc9cfe9..5ef6aefe613 100644 --- a/invokeai/backend/model_manager/load/model_cache/model_cache.py +++ b/invokeai/backend/model_manager/load/model_cache/model_cache.py @@ -78,6 +78,7 @@ def __init__( self, execution_device_working_mem_gb: float, enable_partial_loading: bool, + keep_ram_copy_of_weights: bool, max_ram_cache_size_gb: float | None = None, max_vram_cache_size_gb: float | None = None, execution_device: torch.device | str = "cuda", @@ -105,6 +106,7 @@ def __init__( :param logger: InvokeAILogger to use (otherwise creates one) """ self._enable_partial_loading = enable_partial_loading + self._keep_ram_copy_of_weights = keep_ram_copy_of_weights self._execution_device_working_mem_gb = execution_device_working_mem_gb self._execution_device: torch.device = torch.device(execution_device) self._storage_device: torch.device = torch.device(storage_device) @@ -154,9 +156,13 @@ def put(self, key: str, model: AnyModel) -> None: # Wrap model. if isinstance(model, torch.nn.Module) and running_with_cuda and self._enable_partial_loading: - wrapped_model = CachedModelWithPartialLoad(model, self._execution_device, keep_ram_copy=False) + wrapped_model = CachedModelWithPartialLoad( + model, self._execution_device, keep_ram_copy=self._keep_ram_copy_of_weights + ) else: - wrapped_model = CachedModelOnlyFullLoad(model, self._execution_device, size) + wrapped_model = CachedModelOnlyFullLoad( + model, self._execution_device, size, keep_ram_copy=self._keep_ram_copy_of_weights + ) cache_record = CacheRecord(key=key, cached_model=wrapped_model) self._cached_models[key] = cache_record diff --git a/tests/backend/model_manager/model_manager_fixtures.py b/tests/backend/model_manager/model_manager_fixtures.py index 4a91ea70f4d..87d617662e2 100644 --- a/tests/backend/model_manager/model_manager_fixtures.py +++ b/tests/backend/model_manager/model_manager_fixtures.py @@ -94,6 +94,7 @@ def mm2_loader(mm2_app_config: InvokeAIAppConfig) -> ModelLoadServiceBase: ram_cache = ModelCache( execution_device_working_mem_gb=mm2_app_config.device_working_mem_gb, enable_partial_loading=mm2_app_config.enable_partial_loading, + keep_ram_copy_of_weights=mm2_app_config.keep_ram_copy_of_weights, max_ram_cache_size_gb=mm2_app_config.max_cache_ram_gb, max_vram_cache_size_gb=mm2_app_config.max_cache_vram_gb, execution_device=TorchDevice.choose_torch_device(), From da589b3f1f2ee50fbc2f18e446e90cc040901899 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Tue, 14 Jan 2025 21:36:55 +0000 Subject: [PATCH 4/5] Memory optimization to load state dicts one module at a time in CachedModelWithPartialLoad when we are not storing a CPU copy of the state dict (i.e. when keep_ram_copy_of_weights=False). --- .../cached_model_with_partial_load.py | 149 ++++++++++++++++-- 1 file changed, 137 insertions(+), 12 deletions(-) diff --git a/invokeai/backend/model_manager/load/model_cache/cached_model/cached_model_with_partial_load.py b/invokeai/backend/model_manager/load/model_cache/cached_model/cached_model_with_partial_load.py index 7eaced7396e..004943c0174 100644 --- a/invokeai/backend/model_manager/load/model_cache/cached_model/cached_model_with_partial_load.py +++ b/invokeai/backend/model_manager/load/model_cache/cached_model/cached_model_with_partial_load.py @@ -37,6 +37,7 @@ def __init__(self, model: torch.nn.Module, compute_device: torch.device, keep_ra self._keys_in_modules_that_do_not_support_autocast = self._find_keys_in_modules_that_do_not_support_autocast( model_state_dict ) + self._state_dict_keys_by_module_prefix = self._group_state_dict_keys_by_module_prefix(model_state_dict) def _find_modules_that_support_autocast(self) -> dict[str, torch.nn.Module]: """Find all modules that support autocasting.""" @@ -52,6 +53,47 @@ def _find_keys_in_modules_that_do_not_support_autocast(self, state_dict: dict[st keys_in_modules_that_do_not_support_autocast.add(key) return keys_in_modules_that_do_not_support_autocast + def _group_state_dict_keys_by_module_prefix(self, state_dict: dict[str, torch.Tensor]) -> dict[str, list[str]]: + """A helper function that groups state dict keys by module prefix. + + Example: + ``` + state_dict = { + "weight": ..., + "module.submodule.weight": ..., + "module.submodule.bias": ..., + "module.other_submodule.weight": ..., + "module.other_submodule.bias": ..., + } + + output = group_state_dict_keys_by_module_prefix(state_dict) + + # The output will be: + output = { + "": [ + "weight", + ], + "module.submodule": [ + "module.submodule.weight", + "module.submodule.bias", + ], + "module.other_submodule": [ + "module.other_submodule.weight", + "module.other_submodule.bias", + ], + } + ``` + """ + state_dict_keys_by_module_prefix: dict[str, list[str]] = {} + for key in state_dict.keys(): + split = key.rsplit(".", 1) + # `split` will have length 1 if the root module has parameters. + module_name = split[0] if len(split) > 1 else "" + if module_name not in state_dict_keys_by_module_prefix: + state_dict_keys_by_module_prefix[module_name] = [] + state_dict_keys_by_module_prefix[module_name].append(key) + return state_dict_keys_by_module_prefix + def _move_non_persistent_buffers_to_device(self, device: torch.device): """Move the non-persistent buffers to the target device. These buffers are not included in the state dict, so we need to move them manually. @@ -102,6 +144,82 @@ def full_unload_from_vram(self) -> int: """Unload all weights from VRAM.""" return self.partial_unload_from_vram(self.total_bytes()) + def _load_state_dict_with_device_conversion( + self, state_dict: dict[str, torch.Tensor], keys_to_convert: set[str], target_device: torch.device + ): + if self._cpu_state_dict is not None: + # Run the fast version. + self._load_state_dict_with_fast_device_conversion( + state_dict=state_dict, + keys_to_convert=keys_to_convert, + target_device=target_device, + cpu_state_dict=self._cpu_state_dict, + ) + else: + # Run the low-virtual-memory version. + self._load_state_dict_with_jit_device_conversion( + state_dict=state_dict, + keys_to_convert=keys_to_convert, + target_device=target_device, + ) + + def _load_state_dict_with_jit_device_conversion( + self, + state_dict: dict[str, torch.Tensor], + keys_to_convert: set[str], + target_device: torch.device, + ): + """A custom state dict loading implementation with good peak memory properties. + + This implementation has the important property that it copies parameters to the target device one module at a time + rather than applying all of the device conversions and then calling load_state_dict(). This is done to minimize the + peak virtual memory usage. Specifically, we want to avoid a case where we hold references to all of the CPU weights + and CUDA weights simultaneously, because Windows will reserve virtual memory for both. + """ + for module_name, module in self._model.named_modules(): + module_keys = self._state_dict_keys_by_module_prefix.get(module_name, []) + # Calculate the length of the module name prefix. + prefix_len = len(module_name) + if prefix_len > 0: + prefix_len += 1 + + module_state_dict = {} + for key in module_keys: + if key in keys_to_convert: + # It is important that we overwrite `state_dict[key]` to avoid keeping two copies of the same + # parameter. + state_dict[key] = state_dict[key].to(target_device) + # Note that we keep parameters that have not been moved to a new device in case the module implements + # weird custom state dict loading logic that requires all parameters to be present. + module_state_dict[key[prefix_len:]] = state_dict[key] + + if len(module_state_dict) > 0: + # We set strict=False, because if `module` has both parameters and child modules, then we are loading a + # state dict that only contains the parameters of `module` (not its children). + # We assume that it is rare for non-leaf modules to have parameters. Calling load_state_dict() on non-leaf + # modules will recurse through all of the children, so is a bit wasteful. + incompatible_keys = module.load_state_dict(module_state_dict, strict=False, assign=True) + # Missing keys are ok, unexpected keys are not. + assert len(incompatible_keys.unexpected_keys) == 0 + + def _load_state_dict_with_fast_device_conversion( + self, + state_dict: dict[str, torch.Tensor], + keys_to_convert: set[str], + target_device: torch.device, + cpu_state_dict: dict[str, torch.Tensor], + ): + """Convert parameters to the target device and load them into the model. Leverages the `cpu_state_dict` to speed + up transfers of weights to the CPU. + """ + for key in keys_to_convert: + if target_device.type == "cpu": + state_dict[key] = cpu_state_dict[key] + else: + state_dict[key] = state_dict[key].to(target_device) + + self._model.load_state_dict(state_dict, assign=True) + @torch.no_grad() def partial_load_to_vram(self, vram_bytes_to_load: int) -> int: """Load more weights into VRAM without exceeding vram_bytes_to_load. @@ -116,26 +234,33 @@ def partial_load_to_vram(self, vram_bytes_to_load: int) -> int: cur_state_dict = self._model.state_dict() + # Identify the keys that will be loaded into VRAM. + keys_to_load: set[str] = set() + # First, process the keys that *must* be loaded into VRAM. for key in self._keys_in_modules_that_do_not_support_autocast: param = cur_state_dict[key] if param.device.type == self._compute_device.type: continue + keys_to_load.add(key) param_size = self._state_dict_bytes[key] - cur_state_dict[key] = param.to(self._compute_device, copy=True) vram_bytes_loaded += param_size if vram_bytes_loaded > vram_bytes_to_load: logger = InvokeAILogger.get_logger() logger.warning( - f"Loaded {vram_bytes_loaded / 2**20} MB into VRAM, but only {vram_bytes_to_load / 2**20} MB were " + f"Loading {vram_bytes_loaded / 2**20} MB into VRAM, but only {vram_bytes_to_load / 2**20} MB were " "requested. This is the minimum set of weights in VRAM required to run the model." ) # Next, process the keys that can optionally be loaded into VRAM. fully_loaded = True for key, param in cur_state_dict.items(): + # Skip the keys that have already been processed above. + if key in keys_to_load: + continue + if param.device.type == self._compute_device.type: continue @@ -146,14 +271,14 @@ def partial_load_to_vram(self, vram_bytes_to_load: int) -> int: fully_loaded = False continue - cur_state_dict[key] = param.to(self._compute_device, copy=True) + keys_to_load.add(key) vram_bytes_loaded += param_size - if vram_bytes_loaded > 0: + if len(keys_to_load) > 0: # We load the entire state dict, not just the parameters that changed, in case there are modules that # override _load_from_state_dict() and do some funky stuff that requires the entire state dict. # Alternatively, in the future, grouping parameters by module could probably solve this problem. - self._model.load_state_dict(cur_state_dict, assign=True) + self._load_state_dict_with_device_conversion(cur_state_dict, keys_to_load, self._compute_device) if self._cur_vram_bytes is not None: self._cur_vram_bytes += vram_bytes_loaded @@ -184,6 +309,10 @@ def partial_unload_from_vram(self, vram_bytes_to_free: int, keep_required_weight offload_device = "cpu" cur_state_dict = self._model.state_dict() + + # Identify the keys that will be offloaded to CPU. + keys_to_offload: set[str] = set() + for key, param in cur_state_dict.items(): if vram_bytes_freed >= vram_bytes_to_free: break @@ -195,15 +324,11 @@ def partial_unload_from_vram(self, vram_bytes_to_free: int, keep_required_weight required_weights_in_vram += self._state_dict_bytes[key] continue - if self._cpu_state_dict is not None: - cur_state_dict[key] = self._cpu_state_dict[key] - else: - cur_state_dict[key] = param.to("cpu") - + keys_to_offload.add(key) vram_bytes_freed += self._state_dict_bytes[key] - if vram_bytes_freed > 0: - self._model.load_state_dict(cur_state_dict, assign=True) + if len(keys_to_offload) > 0: + self._load_state_dict_with_device_conversion(cur_state_dict, keys_to_offload, torch.device("cpu")) if self._cur_vram_bytes is not None: self._cur_vram_bytes -= vram_bytes_freed From e5e848d2399002f2fe16296f3bbb4be29a080786 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Thu, 16 Jan 2025 22:34:23 +0000 Subject: [PATCH 5/5] Update config docstring. --- invokeai/app/services/config/config_default.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/invokeai/app/services/config/config_default.py b/invokeai/app/services/config/config_default.py index c3d78bc52c6..4cc6aa720f6 100644 --- a/invokeai/app/services/config/config_default.py +++ b/invokeai/app/services/config/config_default.py @@ -87,7 +87,7 @@ class InvokeAIAppConfig(BaseSettings): log_memory_usage: If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour. device_working_mem_gb: The amount of working memory to keep available on the compute device (in GB). Has no effect if running on CPU. If you are experiencing OOM errors, try increasing this value. enable_partial_loading: Enable partial loading of models. This enables models to run with reduced VRAM requirements (at the cost of slower speed) by streaming the model from RAM to VRAM as its used. In some edge cases, partial loading can cause models to run more slowly if they were previously being fully loaded into VRAM. - keep_ram_copy_of_weights: Whether to keep a full RAM copy of a model's weights when the model is loaded in VRAM. Keeping a RAM copy increases RAM usage, but speeds up model switching and LoRA patching (assuming there is sufficient RAM). Set this to False if RAM pressure is consistently high, set it to True for improved speed if there is RAM to spare. + keep_ram_copy_of_weights: Whether to keep a full RAM copy of a model's weights when the model is loaded in VRAM. Keeping a RAM copy increases average RAM usage, but speeds up model switching and LoRA patching (assuming there is sufficient RAM). Set this to False if RAM pressure is consistently high. ram: DEPRECATED: This setting is no longer used. It has been replaced by `max_cache_ram_gb`, but most users will not need to use this config since automatic cache size limits should work well in most cases. This config setting will be removed once the new model cache behavior is stable. vram: DEPRECATED: This setting is no longer used. It has been replaced by `max_cache_vram_gb`, but most users will not need to use this config since automatic cache size limits should work well in most cases. This config setting will be removed once the new model cache behavior is stable. lazy_offload: DEPRECATED: This setting is no longer used. Lazy-offloading is enabled by default. This config setting will be removed once the new model cache behavior is stable.