From cefcb340d921192ae69d4eb35a6f469a00c79144 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Tue, 10 Dec 2024 16:26:34 +0000 Subject: [PATCH 01/31] Add LoRAPatcher.smart_apply_lora_patches() --- invokeai/backend/patches/model_patcher.py | 110 +++++++++++++++++++++ tests/backend/patches/test_lora_patcher.py | 71 ++++++++++++- 2 files changed, 176 insertions(+), 5 deletions(-) diff --git a/invokeai/backend/patches/model_patcher.py b/invokeai/backend/patches/model_patcher.py index 14b92a26a88..aabed0ccdb0 100644 --- a/invokeai/backend/patches/model_patcher.py +++ b/invokeai/backend/patches/model_patcher.py @@ -14,6 +14,116 @@ class LayerPatcher: + @staticmethod + @torch.no_grad() + @contextmanager + def apply_smart_model_patches( + model: torch.nn.Module, + patches: Iterable[Tuple[ModelPatchRaw, float]], + prefix: str, + dtype: torch.dtype, + cached_weights: Optional[Dict[str, torch.Tensor]] = None, + ): + """Apply 'smart' model patching that chooses whether to use direct patching or a sidecar wrapper for each + module. + """ + + # original_weights are stored for unpatching layers that are directly patched. + original_weights = OriginalWeightsStorage(cached_weights) + # original_modules are stored for unpatching layers that are wrapped in a LoRASidecarWrapper. + original_modules: dict[str, torch.nn.Module] = {} + try: + for patch, patch_weight in patches: + LayerPatcher._apply_smart_model_patch( + model=model, + prefix=prefix, + patch=patch, + patch_weight=patch_weight, + original_weights=original_weights, + original_modules=original_modules, + dtype=dtype, + ) + + yield + finally: + # Restore directly patched layers. + for param_key, weight in original_weights.get_changed_weights(): + model.get_parameter(param_key).copy_(weight) + + # Restore LoRASidecarWrapper modules. + # Note: This logic assumes no nested modules in original_modules. 
+ for module_key, orig_module in original_modules.items(): + module_parent_key, module_name = LayerPatcher._split_parent_key(module_key) + parent_module = model.get_submodule(module_parent_key) + LayerPatcher._set_submodule(parent_module, module_name, orig_module) + + @staticmethod + @torch.no_grad() + def _apply_smart_model_patch( + model: torch.nn.Module, + prefix: str, + patch: ModelPatchRaw, + patch_weight: float, + original_weights: OriginalWeightsStorage, + original_modules: dict[str, torch.nn.Module], + dtype: torch.dtype, + ): + """Apply a single LoRA patch to a model using the 'smart' patching strategy that chooses whether to use direct + patching or a sidecar wrapper for each module. + """ + if patch_weight == 0: + return + + # If the layer keys contain a dot, then they are not flattened, and can be directly used to access model + # submodules. If the layer keys do not contain a dot, then they are flattened, meaning that all '.' have been + # replaced with '_'. Non-flattened keys are preferred, because they allow submodules to be accessed directly + # without searching, but some legacy code still uses flattened keys. + layer_keys_are_flattened = "." not in next(iter(patch.layers.keys())) + + prefix_len = len(prefix) + + for layer_key, layer in patch.layers.items(): + if not layer_key.startswith(prefix): + continue + + module_key, module = LayerPatcher._get_submodule( + model, layer_key[prefix_len:], layer_key_is_flattened=layer_keys_are_flattened + ) + + # Decide whether to use direct patching or a sidecar wrapper. + # Direct patching is preferred, because it results in better runtime speed. + # Reasons to use sidecar patching: + # - The module is already wrapped in a BaseSidecarWrapper. + # - The module is quantized. + # - The module is on the CPU (and we don't want to store a second full copy of the original weights on the + # CPU, since this would double the RAM usage) + # NOTE: For now, we don't check if the layer is quantized here. 
We assume that this is checked in the caller + # and that the caller will use the 'apply_model_sidecar_patches' method if the layer is quantized. + # TODO(ryand): Handle the case where we are running without a GPU. Should we set a config flag that allows + # forcing full patching even on the CPU? + if isinstance(module, BaseSidecarWrapper) or LayerPatcher._is_any_part_of_layer_on_cpu(module): + LayerPatcher._apply_model_layer_wrapper_patch( + model=model, + module_to_patch=module, + module_to_patch_key=module_key, + patch=layer, + patch_weight=patch_weight, + original_modules=original_modules, + dtype=dtype, + ) + else: + LayerPatcher._apply_model_layer_patch( + module_to_patch=module, + module_to_patch_key=module_key, + patch=layer, + patch_weight=patch_weight, + original_weights=original_weights, + ) + + @staticmethod + def _is_any_part_of_layer_on_cpu(layer: torch.nn.Module) -> bool: + return any(p.device.type == "cpu" for p in layer.parameters()) + @staticmethod @torch.no_grad() @contextmanager diff --git a/tests/backend/patches/test_lora_patcher.py b/tests/backend/patches/test_lora_patcher.py index 057504bb973..5561e882fbc 100644 --- a/tests/backend/patches/test_lora_patcher.py +++ b/tests/backend/patches/test_lora_patcher.py @@ -6,7 +6,7 @@ from invokeai.backend.patches.model_patcher import LayerPatcher -class DummyModule(torch.nn.Module): +class DummyModuleWithOneLayer(torch.nn.Module): def __init__(self, in_features: int, out_features: int, device: str, dtype: torch.dtype): super().__init__() self.linear_layer_1 = torch.nn.Linear(in_features, out_features, device=device, dtype=dtype) @@ -15,6 +15,16 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return self.linear_layer_1(x) +class DummyModuleWithTwoLayers(torch.nn.Module): + def __init__(self, in_features: int, out_features: int, device: str, dtype: torch.dtype): + super().__init__() + self.linear_layer_1 = torch.nn.Linear(in_features, out_features, device=device, dtype=dtype) + self.linear_layer_2 
= torch.nn.Linear(out_features, out_features, device=device, dtype=dtype) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.linear_layer_2(self.linear_layer_1(x)) + + @pytest.mark.parametrize( ["device", "num_layers"], [ @@ -33,7 +43,7 @@ def test_apply_lora_patches(device: str, num_layers: int): linear_in_features = 4 linear_out_features = 8 lora_rank = 2 - model = DummyModule(linear_in_features, linear_out_features, device=device, dtype=torch.float16) + model = DummyModuleWithOneLayer(linear_in_features, linear_out_features, device=device, dtype=torch.float16) # Initialize num_layers LoRA models with weights of 0.5. lora_weight = 0.5 @@ -79,7 +89,7 @@ def test_apply_lora_patches_change_device(): linear_out_features = 8 lora_dim = 2 # Initialize the model on the CPU. - model = DummyModule(linear_in_features, linear_out_features, device="cpu", dtype=torch.float16) + model = DummyModuleWithOneLayer(linear_in_features, linear_out_features, device="cpu", dtype=torch.float16) lora_layers = { "linear_layer_1": LoRALayer.from_state_dict_values( @@ -124,7 +134,7 @@ def test_apply_lora_sidecar_patches(device: str, num_layers: int): linear_in_features = 4 linear_out_features = 8 lora_rank = 2 - model = DummyModule(linear_in_features, linear_out_features, device=device, dtype=dtype) + model = DummyModuleWithOneLayer(linear_in_features, linear_out_features, device=device, dtype=dtype) # Initialize num_layers LoRA models with weights of 0.5. 
lora_weight = 0.5 @@ -159,6 +169,57 @@ def test_apply_lora_sidecar_patches(device: str, num_layers: int): assert torch.allclose(output_before_patch, output_after_patch) +@pytest.mark.parametrize( + ["device", "num_layers"], + [ + ("cpu", 1), + pytest.param("cuda", 1, marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA device")), + ("cpu", 2), + pytest.param("cuda", 2, marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA device")), + ], +) +@torch.no_grad() +def test_apply_smart_model_patches(device: str, num_layers: int): + """Test the basic behavior of ModelPatcher.apply_smart_model_patches(...). Check that unpatching works correctly.""" + dtype = torch.float16 + linear_in_features = 4 + linear_out_features = 8 + lora_rank = 2 + model = DummyModuleWithOneLayer(linear_in_features, linear_out_features, device=device, dtype=dtype) + + # Initialize num_layers LoRA models with weights of 0.5. + lora_weight = 0.5 + lora_models: list[tuple[ModelPatchRaw, float]] = [] + for _ in range(num_layers): + lora_layers = { + "linear_layer_1": LoRALayer.from_state_dict_values( + values={ + "lora_down.weight": torch.ones((lora_rank, linear_in_features), device="cpu", dtype=torch.float16), + "lora_up.weight": torch.ones((linear_out_features, lora_rank), device="cpu", dtype=torch.float16), + }, + ) + } + lora = ModelPatchRaw(lora_layers) + lora_models.append((lora, lora_weight)) + + # Run inference before patching the model. + input = torch.randn(1, linear_in_features, device=device, dtype=dtype) + output_before_patch = model(input) + + # Patch the model and run inference during the patch. + with LayerPatcher.apply_smart_model_patches(model=model, patches=lora_models, prefix="", dtype=dtype): + output_during_patch = model(input) + + # Run inference after unpatching. + output_after_patch = model(input) + + # Check that the output before patching is different from the output during patching. 
+ assert not torch.allclose(output_before_patch, output_during_patch) + + # Check that the output before patching is the same as the output after patching. + assert torch.allclose(output_before_patch, output_after_patch) + + @torch.no_grad() @pytest.mark.parametrize(["num_layers"], [(1,), (2,)]) def test_apply_lora_sidecar_patches_matches_apply_lora_patches(num_layers: int): @@ -167,7 +228,7 @@ def test_apply_lora_sidecar_patches_matches_apply_lora_patches(num_layers: int): linear_in_features = 4 linear_out_features = 8 lora_rank = 2 - model = DummyModule(linear_in_features, linear_out_features, device="cpu", dtype=dtype) + model = DummyModuleWithOneLayer(linear_in_features, linear_out_features, device="cpu", dtype=dtype) # Initialize num_layers LoRA models with weights of 0.5. lora_weight = 0.5 From d0f35fceed0b7cf72b33f7a1456c26eaadd3f4e0 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Tue, 10 Dec 2024 16:38:48 +0000 Subject: [PATCH 02/31] Add test_apply_smart_lora_patches_to_partially_loaded_model(...). 
--- tests/backend/patches/test_lora_patcher.py | 81 +++++++++++++++++++++- 1 file changed, 79 insertions(+), 2 deletions(-) diff --git a/tests/backend/patches/test_lora_patcher.py b/tests/backend/patches/test_lora_patcher.py index 5561e882fbc..dd250b6535d 100644 --- a/tests/backend/patches/test_lora_patcher.py +++ b/tests/backend/patches/test_lora_patcher.py @@ -1,9 +1,13 @@ import pytest import torch +from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_with_partial_load import ( + CachedModelWithPartialLoad, +) from invokeai.backend.patches.layers.lora_layer import LoRALayer from invokeai.backend.patches.model_patch_raw import ModelPatchRaw from invokeai.backend.patches.model_patcher import LayerPatcher +from invokeai.backend.patches.sidecar_wrappers.base_sidecar_wrapper import BaseSidecarWrapper class DummyModuleWithOneLayer(torch.nn.Module): @@ -220,10 +224,79 @@ def test_apply_smart_model_patches(device: str, num_layers: int): assert torch.allclose(output_before_patch, output_after_patch) +@pytest.mark.parametrize(["num_layers"], [(1,), (2,)]) +@torch.no_grad() +def test_apply_smart_lora_patches_to_partially_loaded_model(num_layers: int): + """Test the behavior of ModelPatcher.apply_smart_lora_patches(...) when it is applied to a + CachedModelWithPartialLoad that is partially loaded into VRAM. + """ + + if not torch.cuda.is_available(): + pytest.skip("requires CUDA device") + + # Initialize the model on the CPU. + dtype = torch.float16 + linear_in_features = 4 + linear_out_features = 8 + lora_rank = 2 + model = DummyModuleWithTwoLayers(linear_in_features, linear_out_features, device="cpu", dtype=dtype) + cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device("cuda")) + model_total_bytes = cached_model.total_bytes() + assert cached_model.cur_vram_bytes() == 0 + + # Partially load the model into VRAM. 
+ target_vram_bytes = int(model_total_bytes * 0.6) + _ = cached_model.partial_load_to_vram(target_vram_bytes) + assert cached_model.model.linear_layer_1.weight.device.type == "cuda" + assert cached_model.model.linear_layer_2.weight.device.type == "cpu" + + # Initialize num_layers LoRA models with weights of 0.5. + lora_weight = 0.5 + lora_models: list[tuple[ModelPatchRaw, float]] = [] + for _ in range(num_layers): + lora_layers = { + "linear_layer_1": LoRALayer.from_state_dict_values( + values={ + "lora_down.weight": torch.ones((lora_rank, linear_in_features), device="cpu", dtype=torch.float16), + "lora_up.weight": torch.ones((linear_out_features, lora_rank), device="cpu", dtype=torch.float16), + }, + ), + "linear_layer_2": LoRALayer.from_state_dict_values( + values={ + "lora_down.weight": torch.ones((lora_rank, linear_out_features), device="cpu", dtype=torch.float16), + "lora_up.weight": torch.ones((linear_out_features, lora_rank), device="cpu", dtype=torch.float16), + }, + ), + } + lora = ModelPatchRaw(lora_layers) + lora_models.append((lora, lora_weight)) + + # Run inference before patching the model. + input = torch.randn(1, linear_in_features, device="cuda", dtype=dtype) + output_before_patch = cached_model.model(input) + + # Patch the model and run inference during the patch. + with LayerPatcher.apply_smart_model_patches(model=cached_model.model, patches=lora_models, prefix="", dtype=dtype): + # Check that the second layer is wrapped in a LoRASidecarWrapper, but the first layer is not. + assert not isinstance(cached_model.model.linear_layer_1, BaseSidecarWrapper) + assert isinstance(cached_model.model.linear_layer_2, BaseSidecarWrapper) + + output_during_patch = cached_model.model(input) + + # Run inference after unpatching. + output_after_patch = cached_model.model(input) + + # Check that the output before patching is different from the output during patching. 
+ assert not torch.allclose(output_before_patch, output_during_patch) + + # Check that the output before patching is the same as the output after patching. + assert torch.allclose(output_before_patch, output_after_patch) + + @torch.no_grad() @pytest.mark.parametrize(["num_layers"], [(1,), (2,)]) -def test_apply_lora_sidecar_patches_matches_apply_lora_patches(num_layers: int): - """Test that apply_lora_sidecar_patches(...) produces the same model outputs as apply_lora_patches(...).""" +def test_apply_model_sidecar_patches_matches_apply_model_patches(num_layers: int): + """Test that apply_model_sidecar_patches(...) produces the same model outputs as apply__patches(...).""" dtype = torch.float32 linear_in_features = 4 linear_out_features = 8 @@ -253,6 +326,10 @@ def test_apply_lora_sidecar_patches_matches_apply_lora_patches(num_layers: int): with LayerPatcher.apply_model_sidecar_patches(model=model, patches=lora_models, prefix="", dtype=dtype): output_lora_sidecar_patches = model(input) + with LayerPatcher.apply_smart_model_patches(model=model, patches=lora_models, prefix="", dtype=dtype): + output_smart_lora_patches = model(input) + # Note: We set atol=1e-5 because the test failed occasionally with the default atol=1e-8. Slight numerical # differences are tolerable and expected due to the difference between sidecar vs. patching. assert torch.allclose(output_lora_patches, output_lora_sidecar_patches, atol=1e-5) + assert torch.allclose(output_lora_patches, output_smart_lora_patches, atol=1e-5) From 01485120386fbc44d35cb3bdd59478a614c64425 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Tue, 10 Dec 2024 16:41:52 +0000 Subject: [PATCH 03/31] (minor) Rename num_layers -> num_loras in unit tests. 
--- tests/backend/patches/test_lora_patcher.py | 44 +++++++++++----------- 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/tests/backend/patches/test_lora_patcher.py b/tests/backend/patches/test_lora_patcher.py index dd250b6535d..14a6b1c7f02 100644 --- a/tests/backend/patches/test_lora_patcher.py +++ b/tests/backend/patches/test_lora_patcher.py @@ -30,7 +30,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: @pytest.mark.parametrize( - ["device", "num_layers"], + ["device", "num_loras"], [ ("cpu", 1), pytest.param("cuda", 1, marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA device")), @@ -39,7 +39,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: ], ) @torch.no_grad() -def test_apply_lora_patches(device: str, num_layers: int): +def test_apply_lora_patches(device: str, num_loras: int): """Test the basic behavior of ModelPatcher.apply_lora_patches(...). Check that patching and unpatching produce the correct result, and that model/LoRA tensors are moved between devices as expected. """ @@ -49,10 +49,10 @@ def test_apply_lora_patches(device: str, num_layers: int): lora_rank = 2 model = DummyModuleWithOneLayer(linear_in_features, linear_out_features, device=device, dtype=torch.float16) - # Initialize num_layers LoRA models with weights of 0.5. + # Initialize num_loras LoRA models with weights of 0.5. 
lora_weight = 0.5 lora_models: list[tuple[ModelPatchRaw, float]] = [] - for _ in range(num_layers): + for _ in range(num_loras): lora_layers = { "linear_layer_1": LoRALayer.from_state_dict_values( values={ @@ -65,7 +65,7 @@ def test_apply_lora_patches(device: str, num_layers: int): lora_models.append((lora, lora_weight)) orig_linear_weight = model.linear_layer_1.weight.data.detach().clone() - expected_patched_linear_weight = orig_linear_weight + (lora_rank * lora_weight * num_layers) + expected_patched_linear_weight = orig_linear_weight + (lora_rank * lora_weight * num_loras) with LayerPatcher.apply_model_patches(model=model, patches=lora_models, prefix=""): # After patching, all LoRA layer weights should have been moved back to the cpu. @@ -124,7 +124,7 @@ def test_apply_lora_patches_change_device(): @pytest.mark.parametrize( - ["device", "num_layers"], + ["device", "num_loras"], [ ("cpu", 1), pytest.param("cuda", 1, marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA device")), @@ -132,7 +132,7 @@ def test_apply_lora_patches_change_device(): pytest.param("cuda", 2, marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA device")), ], ) -def test_apply_lora_sidecar_patches(device: str, num_layers: int): +def test_apply_lora_sidecar_patches(device: str, num_loras: int): """Test the basic behavior of ModelPatcher.apply_lora_sidecar_patches(...). Check that unpatching works correctly.""" dtype = torch.float16 linear_in_features = 4 @@ -140,10 +140,10 @@ def test_apply_lora_sidecar_patches(device: str, num_layers: int): lora_rank = 2 model = DummyModuleWithOneLayer(linear_in_features, linear_out_features, device=device, dtype=dtype) - # Initialize num_layers LoRA models with weights of 0.5. + # Initialize num_loras LoRA models with weights of 0.5. 
lora_weight = 0.5 lora_models: list[tuple[ModelPatchRaw, float]] = [] - for _ in range(num_layers): + for _ in range(num_loras): lora_layers = { "linear_layer_1": LoRALayer.from_state_dict_values( values={ @@ -174,7 +174,7 @@ def test_apply_lora_sidecar_patches(device: str, num_layers: int): @pytest.mark.parametrize( - ["device", "num_layers"], + ["device", "num_loras"], [ ("cpu", 1), pytest.param("cuda", 1, marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA device")), @@ -183,7 +183,7 @@ def test_apply_lora_sidecar_patches(device: str, num_layers: int): ], ) @torch.no_grad() -def test_apply_smart_model_patches(device: str, num_layers: int): +def test_apply_smart_model_patches(device: str, num_loras: int): """Test the basic behavior of ModelPatcher.apply_smart_model_patches(...). Check that unpatching works correctly.""" dtype = torch.float16 linear_in_features = 4 @@ -191,10 +191,10 @@ def test_apply_smart_model_patches(device: str, num_layers: int): lora_rank = 2 model = DummyModuleWithOneLayer(linear_in_features, linear_out_features, device=device, dtype=dtype) - # Initialize num_layers LoRA models with weights of 0.5. + # Initialize num_loras LoRA models with weights of 0.5. lora_weight = 0.5 lora_models: list[tuple[ModelPatchRaw, float]] = [] - for _ in range(num_layers): + for _ in range(num_loras): lora_layers = { "linear_layer_1": LoRALayer.from_state_dict_values( values={ @@ -224,9 +224,9 @@ def test_apply_smart_model_patches(device: str, num_layers: int): assert torch.allclose(output_before_patch, output_after_patch) -@pytest.mark.parametrize(["num_layers"], [(1,), (2,)]) +@pytest.mark.parametrize(["num_loras"], [(1,), (2,)]) @torch.no_grad() -def test_apply_smart_lora_patches_to_partially_loaded_model(num_layers: int): +def test_apply_smart_lora_patches_to_partially_loaded_model(num_loras: int): """Test the behavior of ModelPatcher.apply_smart_lora_patches(...) 
when it is applied to a CachedModelWithPartialLoad that is partially loaded into VRAM. """ @@ -250,10 +250,10 @@ def test_apply_smart_lora_patches_to_partially_loaded_model(num_layers: int): assert cached_model.model.linear_layer_1.weight.device.type == "cuda" assert cached_model.model.linear_layer_2.weight.device.type == "cpu" - # Initialize num_layers LoRA models with weights of 0.5. + # Initialize num_loras LoRA models with weights of 0.5. lora_weight = 0.5 lora_models: list[tuple[ModelPatchRaw, float]] = [] - for _ in range(num_layers): + for _ in range(num_loras): lora_layers = { "linear_layer_1": LoRALayer.from_state_dict_values( values={ @@ -294,19 +294,19 @@ def test_apply_smart_lora_patches_to_partially_loaded_model(num_layers: int): @torch.no_grad() -@pytest.mark.parametrize(["num_layers"], [(1,), (2,)]) -def test_apply_model_sidecar_patches_matches_apply_model_patches(num_layers: int): - """Test that apply_model_sidecar_patches(...) produces the same model outputs as apply__patches(...).""" +@pytest.mark.parametrize(["num_loras"], [(1,), (2,)]) +def test_all_patching_methods_produce_same_output(num_loras: int): + """Test that apply_lora_wrapper_patches(...) produces the same model outputs as apply_lora_patches(...).""" dtype = torch.float32 linear_in_features = 4 linear_out_features = 8 lora_rank = 2 model = DummyModuleWithOneLayer(linear_in_features, linear_out_features, device="cpu", dtype=dtype) - # Initialize num_layers LoRA models with weights of 0.5. + # Initialize num_loras LoRA models with weights of 0.5. lora_weight = 0.5 lora_models: list[tuple[ModelPatchRaw, float]] = [] - for _ in range(num_layers): + for _ in range(num_loras): lora_layers = { "linear_layer_1": LoRALayer.from_state_dict_values( values={ From 61253b91f1407c0ec8e5dcbb3ca7fb95f11d7e53 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Tue, 10 Dec 2024 17:27:33 +0000 Subject: [PATCH 04/31] Enable LoRAPatcher.apply_smart_lora_patches(...) throughout the stack. 
--- invokeai/app/invocations/compel.py | 8 +++++--- invokeai/app/invocations/denoise_latents.py | 3 ++- invokeai/app/invocations/flux_denoise.py | 3 ++- invokeai/app/invocations/flux_text_encoder.py | 4 +++- invokeai/app/invocations/sd3_text_encoder.py | 4 +++- .../invocations/tiled_multi_diffusion_denoise_latents.py | 4 +++- 6 files changed, 18 insertions(+), 8 deletions(-) diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py index 723cd93a109..2e8cf961076 100644 --- a/invokeai/app/invocations/compel.py +++ b/invokeai/app/invocations/compel.py @@ -82,10 +82,11 @@ def _lora_loader() -> Iterator[Tuple[ModelPatchRaw, float]]: # apply all patches while the model is on the target device text_encoder_info.model_on_device() as (cached_weights, text_encoder), tokenizer_info as tokenizer, - LayerPatcher.apply_model_patches( + LayerPatcher.apply_smart_model_patches( model=text_encoder, patches=_lora_loader(), prefix="lora_te_", + dtype=TorchDevice.choose_torch_dtype(), cached_weights=cached_weights, ), # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers. @@ -179,10 +180,11 @@ def _lora_loader() -> Iterator[Tuple[ModelPatchRaw, float]]: # apply all patches while the model is on the target device text_encoder_info.model_on_device() as (cached_weights, text_encoder), tokenizer_info as tokenizer, - LayerPatcher.apply_model_patches( - text_encoder, + LayerPatcher.apply_smart_model_patches( + model=text_encoder, patches=_lora_loader(), prefix=lora_prefix, + dtype=TorchDevice.choose_torch_dtype(), cached_weights=cached_weights, ), # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers. 
diff --git a/invokeai/app/invocations/denoise_latents.py b/invokeai/app/invocations/denoise_latents.py index 62ac6934c35..17caae8ba82 100644 --- a/invokeai/app/invocations/denoise_latents.py +++ b/invokeai/app/invocations/denoise_latents.py @@ -1003,10 +1003,11 @@ def _lora_loader() -> Iterator[Tuple[ModelPatchRaw, float]]: ModelPatcher.apply_freeu(unet, self.unet.freeu_config), SeamlessExt.static_patch_model(unet, self.unet.seamless_axes), # FIXME # Apply the LoRA after unet has been moved to its target device for faster patching. - LayerPatcher.apply_model_patches( + LayerPatcher.apply_smart_model_patches( model=unet, patches=_lora_loader(), prefix="lora_unet_", + dtype=unet.dtype, cached_weights=cached_weights, ), ): diff --git a/invokeai/app/invocations/flux_denoise.py b/invokeai/app/invocations/flux_denoise.py index 08bbd9f31c6..c4b84b95838 100644 --- a/invokeai/app/invocations/flux_denoise.py +++ b/invokeai/app/invocations/flux_denoise.py @@ -309,10 +309,11 @@ def _run_diffusion( if config.format in [ModelFormat.Checkpoint]: # The model is non-quantized, so we can apply the LoRA weights directly into the model. 
exit_stack.enter_context( - LayerPatcher.apply_model_patches( + LayerPatcher.apply_smart_model_patches( model=transformer, patches=self._lora_iterator(context), prefix=FLUX_LORA_TRANSFORMER_PREFIX, + dtype=inference_dtype, cached_weights=cached_weights, ) ) diff --git a/invokeai/app/invocations/flux_text_encoder.py b/invokeai/app/invocations/flux_text_encoder.py index c1113603f0a..b1714629776 100644 --- a/invokeai/app/invocations/flux_text_encoder.py +++ b/invokeai/app/invocations/flux_text_encoder.py @@ -22,6 +22,7 @@ from invokeai.backend.patches.model_patch_raw import ModelPatchRaw from invokeai.backend.patches.model_patcher import LayerPatcher from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, FLUXConditioningInfo +from invokeai.backend.util.devices import TorchDevice @invocation( @@ -111,10 +112,11 @@ def _clip_encode(self, context: InvocationContext) -> torch.Tensor: if clip_text_encoder_config.format in [ModelFormat.Diffusers]: # The model is non-quantized, so we can apply the LoRA weights directly into the model. exit_stack.enter_context( - LayerPatcher.apply_model_patches( + LayerPatcher.apply_smart_model_patches( model=clip_text_encoder, patches=self._clip_lora_iterator(context), prefix=FLUX_LORA_CLIP_PREFIX, + dtype=TorchDevice.choose_torch_dtype(), cached_weights=cached_weights, ) ) diff --git a/invokeai/app/invocations/sd3_text_encoder.py b/invokeai/app/invocations/sd3_text_encoder.py index f92977bd42d..5ca59337886 100644 --- a/invokeai/app/invocations/sd3_text_encoder.py +++ b/invokeai/app/invocations/sd3_text_encoder.py @@ -21,6 +21,7 @@ from invokeai.backend.patches.model_patch_raw import ModelPatchRaw from invokeai.backend.patches.model_patcher import LayerPatcher from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, SD3ConditioningInfo +from invokeai.backend.util.devices import TorchDevice # The SD3 T5 Max Sequence Length set based on the default in diffusers. 
SD3_T5_MAX_SEQ_LEN = 256 @@ -150,10 +151,11 @@ def _clip_encode( if clip_text_encoder_config.format in [ModelFormat.Diffusers]: # The model is non-quantized, so we can apply the LoRA weights directly into the model. exit_stack.enter_context( - LayerPatcher.apply_model_patches( + LayerPatcher.apply_smart_model_patches( model=clip_text_encoder, patches=self._clip_lora_iterator(context, clip_model), prefix=FLUX_LORA_CLIP_PREFIX, + dtype=TorchDevice.choose_torch_dtype(), cached_weights=cached_weights, ) ) diff --git a/invokeai/app/invocations/tiled_multi_diffusion_denoise_latents.py b/invokeai/app/invocations/tiled_multi_diffusion_denoise_latents.py index 761e73d2bf3..2d6ad1758ab 100644 --- a/invokeai/app/invocations/tiled_multi_diffusion_denoise_latents.py +++ b/invokeai/app/invocations/tiled_multi_diffusion_denoise_latents.py @@ -207,7 +207,9 @@ def _lora_loader() -> Iterator[Tuple[ModelPatchRaw, float]]: with ( ExitStack() as exit_stack, unet_info as unet, - LayerPatcher.apply_model_patches(model=unet, patches=_lora_loader(), prefix="lora_unet_"), + LayerPatcher.apply_smart_model_patches( + model=unet, patches=_lora_loader(), prefix="lora_unet_", dtype=unet.dtype + ), ): assert isinstance(unet, UNet2DConditionModel) latents = latents.to(device=unet.device, dtype=unet.dtype) From 6f926f05b0ae0d9fedf21fd09a0318524d39e7f2 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Tue, 17 Dec 2024 17:13:45 +0000 Subject: [PATCH 05/31] Update apply_smart_model_patches() so that layer restore matches the behavior of non-smart mode. --- invokeai/backend/patches/model_patcher.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/invokeai/backend/patches/model_patcher.py b/invokeai/backend/patches/model_patcher.py index aabed0ccdb0..0a7ed628b7f 100644 --- a/invokeai/backend/patches/model_patcher.py +++ b/invokeai/backend/patches/model_patcher.py @@ -48,7 +48,8 @@ def apply_smart_model_patches( finally: # Restore directly patched layers. 
for param_key, weight in original_weights.get_changed_weights(): - model.get_parameter(param_key).copy_(weight) + cur_param = model.get_parameter(param_key) + cur_param.data = weight.to(dtype=cur_param.dtype, device=cur_param.device, copy=True) # Restore LoRASidecarWrapper modules. # Note: This logic assumes no nested modules in original_modules. From 80db9537ff74ef324946a13ac163910e8d72c853 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Tue, 17 Dec 2024 17:19:12 +0000 Subject: [PATCH 06/31] Rename model_patcher.py -> layer_patcher.py. --- invokeai/app/invocations/compel.py | 2 +- invokeai/app/invocations/denoise_latents.py | 2 +- invokeai/app/invocations/flux_denoise.py | 2 +- invokeai/app/invocations/flux_text_encoder.py | 2 +- invokeai/app/invocations/sd3_text_encoder.py | 2 +- .../app/invocations/tiled_multi_diffusion_denoise_latents.py | 2 +- invokeai/backend/patches/{model_patcher.py => layer_patcher.py} | 0 invokeai/backend/stable_diffusion/extensions/lora.py | 2 +- .../patches/{test_lora_patcher.py => test_layer_patcher.py} | 2 +- 9 files changed, 8 insertions(+), 8 deletions(-) rename invokeai/backend/patches/{model_patcher.py => layer_patcher.py} (100%) rename tests/backend/patches/{test_lora_patcher.py => test_layer_patcher.py} (99%) diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py index 2e8cf961076..92d7f4638c0 100644 --- a/invokeai/app/invocations/compel.py +++ b/invokeai/app/invocations/compel.py @@ -20,8 +20,8 @@ from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.util.ti_utils import generate_ti_list from invokeai.backend.model_patcher import ModelPatcher +from invokeai.backend.patches.layer_patcher import LayerPatcher from invokeai.backend.patches.model_patch_raw import ModelPatchRaw -from invokeai.backend.patches.model_patcher import LayerPatcher from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ( BasicConditioningInfo, ConditioningFieldData, 
diff --git a/invokeai/app/invocations/denoise_latents.py b/invokeai/app/invocations/denoise_latents.py index 17caae8ba82..5aeeff57ad5 100644 --- a/invokeai/app/invocations/denoise_latents.py +++ b/invokeai/app/invocations/denoise_latents.py @@ -39,8 +39,8 @@ from invokeai.backend.ip_adapter.ip_adapter import IPAdapter from invokeai.backend.model_manager import BaseModelType, ModelVariantType from invokeai.backend.model_patcher import ModelPatcher +from invokeai.backend.patches.layer_patcher import LayerPatcher from invokeai.backend.patches.model_patch_raw import ModelPatchRaw -from invokeai.backend.patches.model_patcher import LayerPatcher from invokeai.backend.stable_diffusion import PipelineIntermediateState from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext, DenoiseInputs from invokeai.backend.stable_diffusion.diffusers_pipeline import ( diff --git a/invokeai/app/invocations/flux_denoise.py b/invokeai/app/invocations/flux_denoise.py index c4b84b95838..8eb90d8bc5d 100644 --- a/invokeai/app/invocations/flux_denoise.py +++ b/invokeai/app/invocations/flux_denoise.py @@ -48,9 +48,9 @@ ) from invokeai.backend.flux.text_conditioning import FluxTextConditioning from invokeai.backend.model_manager.config import ModelFormat +from invokeai.backend.patches.layer_patcher import LayerPatcher from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX from invokeai.backend.patches.model_patch_raw import ModelPatchRaw -from invokeai.backend.patches.model_patcher import LayerPatcher from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo from invokeai.backend.util.devices import TorchDevice diff --git a/invokeai/app/invocations/flux_text_encoder.py b/invokeai/app/invocations/flux_text_encoder.py index b1714629776..c3a752ab30d 100644 --- a/invokeai/app/invocations/flux_text_encoder.py 
+++ b/invokeai/app/invocations/flux_text_encoder.py @@ -18,9 +18,9 @@ from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.backend.flux.modules.conditioner import HFEncoder from invokeai.backend.model_manager.config import ModelFormat +from invokeai.backend.patches.layer_patcher import LayerPatcher from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_CLIP_PREFIX from invokeai.backend.patches.model_patch_raw import ModelPatchRaw -from invokeai.backend.patches.model_patcher import LayerPatcher from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, FLUXConditioningInfo from invokeai.backend.util.devices import TorchDevice diff --git a/invokeai/app/invocations/sd3_text_encoder.py b/invokeai/app/invocations/sd3_text_encoder.py index 5ca59337886..9103dbbb41f 100644 --- a/invokeai/app/invocations/sd3_text_encoder.py +++ b/invokeai/app/invocations/sd3_text_encoder.py @@ -17,9 +17,9 @@ from invokeai.app.invocations.primitives import SD3ConditioningOutput from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.backend.model_manager.config import ModelFormat +from invokeai.backend.patches.layer_patcher import LayerPatcher from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_CLIP_PREFIX from invokeai.backend.patches.model_patch_raw import ModelPatchRaw -from invokeai.backend.patches.model_patcher import LayerPatcher from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, SD3ConditioningInfo from invokeai.backend.util.devices import TorchDevice diff --git a/invokeai/app/invocations/tiled_multi_diffusion_denoise_latents.py b/invokeai/app/invocations/tiled_multi_diffusion_denoise_latents.py index 2d6ad1758ab..7c1442177f0 100644 --- a/invokeai/app/invocations/tiled_multi_diffusion_denoise_latents.py +++ 
b/invokeai/app/invocations/tiled_multi_diffusion_denoise_latents.py @@ -22,8 +22,8 @@ from invokeai.app.invocations.model import UNetField from invokeai.app.invocations.primitives import LatentsOutput from invokeai.app.services.shared.invocation_context import InvocationContext +from invokeai.backend.patches.layer_patcher import LayerPatcher from invokeai.backend.patches.model_patch_raw import ModelPatchRaw -from invokeai.backend.patches.model_patcher import LayerPatcher from invokeai.backend.stable_diffusion.diffusers_pipeline import ControlNetData, PipelineIntermediateState from invokeai.backend.stable_diffusion.multi_diffusion_pipeline import ( MultiDiffusionPipeline, diff --git a/invokeai/backend/patches/model_patcher.py b/invokeai/backend/patches/layer_patcher.py similarity index 100% rename from invokeai/backend/patches/model_patcher.py rename to invokeai/backend/patches/layer_patcher.py diff --git a/invokeai/backend/stable_diffusion/extensions/lora.py b/invokeai/backend/stable_diffusion/extensions/lora.py index 9e04f8e9412..b9d1d717687 100644 --- a/invokeai/backend/stable_diffusion/extensions/lora.py +++ b/invokeai/backend/stable_diffusion/extensions/lora.py @@ -5,8 +5,8 @@ from diffusers import UNet2DConditionModel +from invokeai.backend.patches.layer_patcher import LayerPatcher from invokeai.backend.patches.model_patch_raw import ModelPatchRaw -from invokeai.backend.patches.model_patcher import LayerPatcher from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase if TYPE_CHECKING: diff --git a/tests/backend/patches/test_lora_patcher.py b/tests/backend/patches/test_layer_patcher.py similarity index 99% rename from tests/backend/patches/test_lora_patcher.py rename to tests/backend/patches/test_layer_patcher.py index 14a6b1c7f02..4e281d3afdf 100644 --- a/tests/backend/patches/test_lora_patcher.py +++ b/tests/backend/patches/test_layer_patcher.py @@ -4,9 +4,9 @@ from 
invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_with_partial_load import ( CachedModelWithPartialLoad, ) +from invokeai.backend.patches.layer_patcher import LayerPatcher from invokeai.backend.patches.layers.lora_layer import LoRALayer from invokeai.backend.patches.model_patch_raw import ModelPatchRaw -from invokeai.backend.patches.model_patcher import LayerPatcher from invokeai.backend.patches.sidecar_wrappers.base_sidecar_wrapper import BaseSidecarWrapper From 6d7314ac0a526c3e688df0adaf741f5ce63037d4 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Tue, 17 Dec 2024 18:33:36 +0000 Subject: [PATCH 07/31] Consolidate the LayerPatching patching modes into a single implementation. --- invokeai/app/invocations/flux_denoise.py | 40 ++- invokeai/backend/patches/layer_patcher.py | 194 ++---------- .../stable_diffusion/extensions/lora.py | 6 +- tests/backend/patches/test_layer_patcher.py | 282 ++++++++---------- 4 files changed, 172 insertions(+), 350 deletions(-) diff --git a/invokeai/app/invocations/flux_denoise.py b/invokeai/app/invocations/flux_denoise.py index 8eb90d8bc5d..d8bc8135bc7 100644 --- a/invokeai/app/invocations/flux_denoise.py +++ b/invokeai/app/invocations/flux_denoise.py @@ -304,37 +304,33 @@ def _run_diffusion( config = transformer_info.config assert config is not None - # Apply LoRA models to the transformer. - # Note: We apply the LoRA after the transformer has been moved to its target device for faster patching. + # Determine if the model is quantized. + # If the model is quantized, then we need to apply the LoRA weights as sidecar layers. This results in + # slower inference than direct patching, but is agnostic to the quantization format. if config.format in [ModelFormat.Checkpoint]: - # The model is non-quantized, so we can apply the LoRA weights directly into the model. 
- exit_stack.enter_context( - LayerPatcher.apply_smart_model_patches( - model=transformer, - patches=self._lora_iterator(context), - prefix=FLUX_LORA_TRANSFORMER_PREFIX, - dtype=inference_dtype, - cached_weights=cached_weights, - ) - ) + model_is_quantized = False elif config.format in [ ModelFormat.BnbQuantizedLlmInt8b, ModelFormat.BnbQuantizednf4b, ModelFormat.GGUFQuantized, ]: - # The model is quantized, so apply the LoRA weights as sidecar layers. This results in slower inference, - # than directly patching the weights, but is agnostic to the quantization format. - exit_stack.enter_context( - LayerPatcher.apply_model_sidecar_patches( - model=transformer, - patches=self._lora_iterator(context), - prefix=FLUX_LORA_TRANSFORMER_PREFIX, - dtype=inference_dtype, - ) - ) + model_is_quantized = True else: raise ValueError(f"Unsupported model format: {config.format}") + # Apply LoRA models to the transformer. + # Note: We apply the LoRA after the transformer has been moved to its target device for faster patching. + exit_stack.enter_context( + LayerPatcher.apply_smart_model_patches( + model=transformer, + patches=self._lora_iterator(context), + prefix=FLUX_LORA_TRANSFORMER_PREFIX, + dtype=inference_dtype, + cached_weights=cached_weights, + force_sidecar_patching=model_is_quantized, + ) + ) + # Prepare IP-Adapter extensions. 
pos_ip_adapter_extensions, neg_ip_adapter_extensions = self._prep_ip_adapter_extensions( pos_image_prompt_clip_embeds=pos_image_prompt_clip_embeds, diff --git a/invokeai/backend/patches/layer_patcher.py b/invokeai/backend/patches/layer_patcher.py index 0a7ed628b7f..d7f6bea166b 100644 --- a/invokeai/backend/patches/layer_patcher.py +++ b/invokeai/backend/patches/layer_patcher.py @@ -23,6 +23,8 @@ def apply_smart_model_patches( prefix: str, dtype: torch.dtype, cached_weights: Optional[Dict[str, torch.Tensor]] = None, + force_direct_patching: bool = False, + force_sidecar_patching: bool = False, ): """Apply 'smart' model patching that chooses whether to use direct patching or a sidecar wrapper for each module. @@ -34,7 +36,7 @@ def apply_smart_model_patches( original_modules: dict[str, torch.nn.Module] = {} try: for patch, patch_weight in patches: - LayerPatcher._apply_smart_model_patch( + LayerPatcher.apply_smart_model_patch( model=model, prefix=prefix, patch=patch, @@ -42,6 +44,8 @@ def apply_smart_model_patches( original_weights=original_weights, original_modules=original_modules, dtype=dtype, + force_direct_patching=force_direct_patching, + force_sidecar_patching=force_sidecar_patching, ) yield @@ -60,7 +64,7 @@ def apply_smart_model_patches( @staticmethod @torch.no_grad() - def _apply_smart_model_patch( + def apply_smart_model_patch( model: torch.nn.Module, prefix: str, patch: ModelPatchRaw, @@ -68,6 +72,8 @@ def _apply_smart_model_patch( original_weights: OriginalWeightsStorage, original_modules: dict[str, torch.nn.Module], dtype: torch.dtype, + force_direct_patching: bool, + force_sidecar_patching: bool, ): """Apply a single LoRA patch to a model using the 'smart' patching strategy that chooses whether to use direct patching or a sidecar wrapper for each module. @@ -94,15 +100,27 @@ def _apply_smart_model_patch( # Decide whether to use direct patching or a sidecar wrapper. # Direct patching is preferred, because it results in better runtime speed. 
# Reasons to use sidecar patching: + # - The module is quantized, so the caller passed force_sidecar_patching=True. # - The module is already wrapped in a BaseSidecarWrapper. - # - The module is quantized. # - The module is on the CPU (and we don't want to store a second full copy of the original weights on the # CPU, since this would double the RAM usage) # NOTE: For now, we don't check if the layer is quantized here. We assume that this is checked in the caller - # and that the caller will use the 'apply_model_sidecar_patches' method if the layer is quantized. + # and that the caller will set force_sidecar_patching=True if the layer is quantized. # TODO(ryand): Handle the case where we are running without a GPU. Should we set a config flag that allows # forcing full patching even on the CPU? - if isinstance(module, BaseSidecarWrapper) or LayerPatcher._is_any_part_of_layer_on_cpu(module): + use_sidecar_patching = False + if force_direct_patching and force_sidecar_patching: + raise ValueError("Cannot force both direct and sidecar patching.") + elif force_direct_patching: + use_sidecar_patching = False + elif force_sidecar_patching: + use_sidecar_patching = True + elif isinstance(module, BaseSidecarWrapper): + use_sidecar_patching = True + elif LayerPatcher._is_any_part_of_layer_on_cpu(module): + use_sidecar_patching = True + + if use_sidecar_patching: LayerPatcher._apply_model_layer_wrapper_patch( model=model, module_to_patch=module, @@ -125,89 +143,6 @@ def _apply_smart_model_patch( def _is_any_part_of_layer_on_cpu(layer: torch.nn.Module) -> bool: return any(p.device.type == "cpu" for p in layer.parameters()) - @staticmethod - @torch.no_grad() - @contextmanager - def apply_model_patches( - model: torch.nn.Module, - patches: Iterable[Tuple[ModelPatchRaw, float]], - prefix: str, - cached_weights: Optional[Dict[str, torch.Tensor]] = None, - ): - """Apply one or more LoRA patches to a model within a context manager. 
- - Args: - model (torch.nn.Module): The model to patch. - patches (Iterable[Tuple[LoRAModelRaw, float]]): An iterator that returns tuples of LoRA patches and - associated weights. An iterator is used so that the LoRA patches do not need to be loaded into memory - all at once. - prefix (str): The keys in the patches will be filtered to only include weights with this prefix. - cached_weights (Optional[Dict[str, torch.Tensor]], optional): Read-only copy of the model's state dict in - CPU RAM, for efficient unpatching purposes. - """ - original_weights = OriginalWeightsStorage(cached_weights) - try: - for patch, patch_weight in patches: - LayerPatcher.apply_model_patch( - model=model, - prefix=prefix, - patch=patch, - patch_weight=patch_weight, - original_weights=original_weights, - ) - del patch - - yield - finally: - for param_key, weight in original_weights.get_changed_weights(): - cur_param = model.get_parameter(param_key) - cur_param.data = weight.to(dtype=cur_param.dtype, device=cur_param.device, copy=True) - - @staticmethod - @torch.no_grad() - def apply_model_patch( - model: torch.nn.Module, - prefix: str, - patch: ModelPatchRaw, - patch_weight: float, - original_weights: OriginalWeightsStorage, - ): - """Apply a single LoRA patch to a model. - - Args: - model (torch.nn.Module): The model to patch. - prefix (str): A string prefix that precedes keys used in the LoRAs weight layers. - patch (LoRAModelRaw): The LoRA model to patch in. - patch_weight (float): The weight of the LoRA patch. - original_weights (OriginalWeightsStorage): Storage for the original weights of the model, for unpatching. - """ - if patch_weight == 0: - return - - # If the layer keys contain a dot, then they are not flattened, and can be directly used to access model - # submodules. If the layer keys do not contain a dot, then they are flattened, meaning that all '.' have been - # replaced with '_'. 
Non-flattened keys are preferred, because they allow submodules to be accessed directly - # without searching, but some legacy code still uses flattened keys. - layer_keys_are_flattened = "." not in next(iter(patch.layers.keys())) - - prefix_len = len(prefix) - - for layer_key, layer in patch.layers.items(): - if not layer_key.startswith(prefix): - continue - - module_key, module = LayerPatcher._get_submodule( - model, layer_key[prefix_len:], layer_key_is_flattened=layer_keys_are_flattened - ) - - LayerPatcher._apply_model_layer_patch( - module_to_patch=module, - module_to_patch_key=module_key, - patch=layer, - patch_weight=patch_weight, - original_weights=original_weights, - ) - @staticmethod @torch.no_grad() def _apply_model_layer_patch( @@ -254,89 +189,6 @@ def _apply_model_layer_patch( patch.to(device=TorchDevice.CPU_DEVICE) - @staticmethod - @torch.no_grad() - @contextmanager - def apply_model_sidecar_patches( - model: torch.nn.Module, - patches: Iterable[Tuple[ModelPatchRaw, float]], - prefix: str, - dtype: torch.dtype, - ): - """Apply one or more LoRA sidecar patches to a model within a context manager. Sidecar patches incur some - overhead compared to normal LoRA patching, but they allow for LoRA layers to applied to base layers in any - quantization format. - - Args: - model (torch.nn.Module): The model to patch. - patches (Iterable[Tuple[LoRAModelRaw, float]]): An iterator that returns tuples of LoRA patches and - associated weights. An iterator is used so that the LoRA patches do not need to be loaded into memory - all at once. - prefix (str): The keys in the patches will be filtered to only include weights with this prefix. - dtype (torch.dtype): The compute dtype of the sidecar layers. This cannot easily be inferred from the model, - since the sidecar layers are typically applied on top of quantized layers whose weight dtype is - different from their compute dtype. 
- """ - original_modules: dict[str, torch.nn.Module] = {} - try: - for patch, patch_weight in patches: - LayerPatcher._apply_model_sidecar_patch( - model=model, - prefix=prefix, - patch=patch, - patch_weight=patch_weight, - original_modules=original_modules, - dtype=dtype, - ) - yield - finally: - # Restore original modules. - # Note: This logic assumes no nested modules in original_modules. - for module_key, orig_module in original_modules.items(): - module_parent_key, module_name = LayerPatcher._split_parent_key(module_key) - parent_module = model.get_submodule(module_parent_key) - LayerPatcher._set_submodule(parent_module, module_name, orig_module) - - @staticmethod - def _apply_model_sidecar_patch( - model: torch.nn.Module, - patch: ModelPatchRaw, - patch_weight: float, - prefix: str, - original_modules: dict[str, torch.nn.Module], - dtype: torch.dtype, - ): - """Apply a single LoRA sidecar patch to a model.""" - - if patch_weight == 0: - return - - # If the layer keys contain a dot, then they are not flattened, and can be directly used to access model - # submodules. If the layer keys do not contain a dot, then they are flattened, meaning that all '.' have been - # replaced with '_'. Non-flattened keys are preferred, because they allow submodules to be accessed directly - # without searching, but some legacy code still uses flattened keys. - layer_keys_are_flattened = "." 
not in next(iter(patch.layers.keys())) - - prefix_len = len(prefix) - - for layer_key, layer in patch.layers.items(): - if not layer_key.startswith(prefix): - continue - - module_key, module = LayerPatcher._get_submodule( - model, layer_key[prefix_len:], layer_key_is_flattened=layer_keys_are_flattened - ) - - LayerPatcher._apply_model_layer_wrapper_patch( - model=model, - module_to_patch=module, - module_to_patch_key=module_key, - patch=layer, - patch_weight=patch_weight, - original_modules=original_modules, - dtype=dtype, - ) - @staticmethod @torch.no_grad() def _apply_model_layer_wrapper_patch( diff --git a/invokeai/backend/stable_diffusion/extensions/lora.py b/invokeai/backend/stable_diffusion/extensions/lora.py index b9d1d717687..43986fad4d6 100644 --- a/invokeai/backend/stable_diffusion/extensions/lora.py +++ b/invokeai/backend/stable_diffusion/extensions/lora.py @@ -31,12 +31,16 @@ def __init__( def patch_unet(self, unet: UNet2DConditionModel, original_weights: OriginalWeightsStorage): lora_model = self._node_context.models.load(self._model_id).model assert isinstance(lora_model, ModelPatchRaw) - LayerPatcher.apply_model_patch( + LayerPatcher.apply_smart_model_patch( model=unet, prefix="lora_unet_", patch=lora_model, patch_weight=self._weight, original_weights=original_weights, + original_modules={}, + dtype=unet.dtype, + force_direct_patching=True, + force_sidecar_patching=False, ) del lora_model diff --git a/tests/backend/patches/test_layer_patcher.py b/tests/backend/patches/test_layer_patcher.py index 4e281d3afdf..06d64c05c27 100644 --- a/tests/backend/patches/test_layer_patcher.py +++ b/tests/backend/patches/test_layer_patcher.py @@ -30,160 +30,20 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: @pytest.mark.parametrize( - ["device", "num_loras"], + "device", [ - ("cpu", 1), - pytest.param("cuda", 1, marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA device")), - ("cpu", 2), - pytest.param("cuda", 2, 
marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA device")), + "cpu", + pytest.param("cuda", marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA device")), ], ) -@torch.no_grad() -def test_apply_lora_patches(device: str, num_loras: int): - """Test the basic behavior of ModelPatcher.apply_lora_patches(...). Check that patching and unpatching produce the - correct result, and that model/LoRA tensors are moved between devices as expected. - """ - - linear_in_features = 4 - linear_out_features = 8 - lora_rank = 2 - model = DummyModuleWithOneLayer(linear_in_features, linear_out_features, device=device, dtype=torch.float16) - - # Initialize num_loras LoRA models with weights of 0.5. - lora_weight = 0.5 - lora_models: list[tuple[ModelPatchRaw, float]] = [] - for _ in range(num_loras): - lora_layers = { - "linear_layer_1": LoRALayer.from_state_dict_values( - values={ - "lora_down.weight": torch.ones((lora_rank, linear_in_features), device="cpu", dtype=torch.float16), - "lora_up.weight": torch.ones((linear_out_features, lora_rank), device="cpu", dtype=torch.float16), - }, - ) - } - lora = ModelPatchRaw(lora_layers) - lora_models.append((lora, lora_weight)) - - orig_linear_weight = model.linear_layer_1.weight.data.detach().clone() - expected_patched_linear_weight = orig_linear_weight + (lora_rank * lora_weight * num_loras) - - with LayerPatcher.apply_model_patches(model=model, patches=lora_models, prefix=""): - # After patching, all LoRA layer weights should have been moved back to the cpu. - for lora, _ in lora_models: - assert lora.layers["linear_layer_1"].up.device.type == "cpu" - assert lora.layers["linear_layer_1"].down.device.type == "cpu" - - # After patching, the patched model should still be on its original device. 
- assert model.linear_layer_1.weight.data.device.type == device - - torch.testing.assert_close(model.linear_layer_1.weight.data, expected_patched_linear_weight) - - # After unpatching, the original model weights should have been restored on the original device. - assert model.linear_layer_1.weight.data.device.type == device - torch.testing.assert_close(model.linear_layer_1.weight.data, orig_linear_weight) - - -@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA device") -@torch.no_grad() -def test_apply_lora_patches_change_device(): - """Test that if LoRA patching is applied on the CPU, and then the patched model is moved to the GPU, unpatching - still behaves correctly. - """ - linear_in_features = 4 - linear_out_features = 8 - lora_dim = 2 - # Initialize the model on the CPU. - model = DummyModuleWithOneLayer(linear_in_features, linear_out_features, device="cpu", dtype=torch.float16) - - lora_layers = { - "linear_layer_1": LoRALayer.from_state_dict_values( - values={ - "lora_down.weight": torch.ones((lora_dim, linear_in_features), device="cpu", dtype=torch.float16), - "lora_up.weight": torch.ones((linear_out_features, lora_dim), device="cpu", dtype=torch.float16), - }, - ) - } - lora = ModelPatchRaw(lora_layers) - - orig_linear_weight = model.linear_layer_1.weight.data.detach().clone() - - with LayerPatcher.apply_model_patches(model=model, patches=[(lora, 0.5)], prefix=""): - # After patching, all LoRA layer weights should have been moved back to the cpu. - assert lora_layers["linear_layer_1"].up.device.type == "cpu" - assert lora_layers["linear_layer_1"].down.device.type == "cpu" - - # After patching, the patched model should still be on the CPU. - assert model.linear_layer_1.weight.data.device.type == "cpu" - - # Move the model to the GPU. - assert model.to("cuda") - - # After unpatching, the original model weights should have been restored on the GPU. 
- assert model.linear_layer_1.weight.data.device.type == "cuda" - torch.testing.assert_close(model.linear_layer_1.weight.data, orig_linear_weight, check_device=False) - - -@pytest.mark.parametrize( - ["device", "num_loras"], - [ - ("cpu", 1), - pytest.param("cuda", 1, marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA device")), - ("cpu", 2), - pytest.param("cuda", 2, marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA device")), - ], -) -def test_apply_lora_sidecar_patches(device: str, num_loras: int): - """Test the basic behavior of ModelPatcher.apply_lora_sidecar_patches(...). Check that unpatching works correctly.""" - dtype = torch.float16 - linear_in_features = 4 - linear_out_features = 8 - lora_rank = 2 - model = DummyModuleWithOneLayer(linear_in_features, linear_out_features, device=device, dtype=dtype) - - # Initialize num_loras LoRA models with weights of 0.5. - lora_weight = 0.5 - lora_models: list[tuple[ModelPatchRaw, float]] = [] - for _ in range(num_loras): - lora_layers = { - "linear_layer_1": LoRALayer.from_state_dict_values( - values={ - "lora_down.weight": torch.ones((lora_rank, linear_in_features), device="cpu", dtype=torch.float16), - "lora_up.weight": torch.ones((linear_out_features, lora_rank), device="cpu", dtype=torch.float16), - }, - ) - } - lora = ModelPatchRaw(lora_layers) - lora_models.append((lora, lora_weight)) - - # Run inference before patching the model. - input = torch.randn(1, linear_in_features, device=device, dtype=dtype) - output_before_patch = model(input) - - # Patch the model and run inference during the patch. - with LayerPatcher.apply_model_sidecar_patches(model=model, patches=lora_models, prefix="", dtype=dtype): - output_during_patch = model(input) - - # Run inference after unpatching. - output_after_patch = model(input) - - # Check that the output before patching is different from the output during patching. 
- assert not torch.allclose(output_before_patch, output_during_patch) - - # Check that the output before patching is the same as the output after patching. - assert torch.allclose(output_before_patch, output_after_patch) - - +@pytest.mark.parametrize("num_loras", [1, 2]) @pytest.mark.parametrize( - ["device", "num_loras"], - [ - ("cpu", 1), - pytest.param("cuda", 1, marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA device")), - ("cpu", 2), - pytest.param("cuda", 2, marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA device")), - ], + ["force_sidecar_patching", "force_direct_patching"], [(True, False), (False, True), (False, False)] ) @torch.no_grad() -def test_apply_smart_model_patches(device: str, num_loras: int): +def test_apply_smart_model_patches( + device: str, num_loras: int, force_sidecar_patching: bool, force_direct_patching: bool +): """Test the basic behavior of ModelPatcher.apply_smart_model_patches(...). Check that unpatching works correctly.""" dtype = torch.float16 linear_in_features = 4 @@ -206,12 +66,44 @@ def test_apply_smart_model_patches(device: str, num_loras: int): lora = ModelPatchRaw(lora_layers) lora_models.append((lora, lora_weight)) + orig_linear_weight = model.linear_layer_1.weight.data.detach().clone() + expected_patched_linear_weight = orig_linear_weight + (lora_rank * lora_weight * num_loras) + # Run inference before patching the model. input = torch.randn(1, linear_in_features, device=device, dtype=dtype) output_before_patch = model(input) + expect_sidecar_wrappers = device == "cpu" + if force_sidecar_patching: + expect_sidecar_wrappers = True + elif force_direct_patching: + expect_sidecar_wrappers = False + # Patch the model and run inference during the patch. 
- with LayerPatcher.apply_smart_model_patches(model=model, patches=lora_models, prefix="", dtype=dtype): + with LayerPatcher.apply_smart_model_patches( + model=model, + patches=lora_models, + prefix="", + dtype=dtype, + force_direct_patching=force_direct_patching, + force_sidecar_patching=force_sidecar_patching, + ): + if expect_sidecar_wrappers: + # There should be sidecar wrappers in the model. + assert isinstance(model.linear_layer_1, BaseSidecarWrapper) + else: + # There should be no sidecar wrappers in the model. + assert not isinstance(model.linear_layer_1, BaseSidecarWrapper) + torch.testing.assert_close(model.linear_layer_1.weight.data, expected_patched_linear_weight) + + # After patching, the patched model should still be on its original device. + assert model.linear_layer_1.weight.data.device.type == device + + # After patching, all LoRA layer weights should have been moved back to the cpu. + for lora, _ in lora_models: + assert lora.layers["linear_layer_1"].up.device.type == "cpu" + assert lora.layers["linear_layer_1"].down.device.type == "cpu" + output_during_patch = model(input) # Run inference after unpatching. 
@@ -320,16 +212,94 @@ def test_all_patching_methods_produce_same_output(num_loras: int): input = torch.randn(1, linear_in_features, device="cpu", dtype=dtype) - with LayerPatcher.apply_model_patches(model=model, patches=lora_models, prefix=""): - output_lora_patches = model(input) + with LayerPatcher.apply_smart_model_patches( + model=model, patches=lora_models, prefix="", dtype=dtype, force_direct_patching=True + ): + output_force_direct = model(input) - with LayerPatcher.apply_model_sidecar_patches(model=model, patches=lora_models, prefix="", dtype=dtype): - output_lora_sidecar_patches = model(input) + with LayerPatcher.apply_smart_model_patches( + model=model, patches=lora_models, prefix="", dtype=dtype, force_sidecar_patching=True + ): + output_force_sidecar = model(input) with LayerPatcher.apply_smart_model_patches(model=model, patches=lora_models, prefix="", dtype=dtype): - output_smart_lora_patches = model(input) + output_smart = model(input) # Note: We set atol=1e-5 because the test failed occasionally with the default atol=1e-8. Slight numerical # differences are tolerable and expected due to the difference between sidecar vs. patching. - assert torch.allclose(output_lora_patches, output_lora_sidecar_patches, atol=1e-5) - assert torch.allclose(output_lora_patches, output_smart_lora_patches, atol=1e-5) + assert torch.allclose(output_force_direct, output_force_sidecar, atol=1e-5) + assert torch.allclose(output_force_direct, output_smart, atol=1e-5) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA device") +@torch.no_grad() +def test_apply_smart_model_patches_change_device(): + """Test that if LoRA patching is applied on the CPU, and then the patched model is moved to the GPU, unpatching + still behaves correctly. + """ + linear_in_features = 4 + linear_out_features = 8 + lora_dim = 2 + # Initialize the model on the CPU. 
+ model = DummyModuleWithOneLayer(linear_in_features, linear_out_features, device="cpu", dtype=torch.float16) + + lora_layers = { + "linear_layer_1": LoRALayer.from_state_dict_values( + values={ + "lora_down.weight": torch.ones((lora_dim, linear_in_features), device="cpu", dtype=torch.float16), + "lora_up.weight": torch.ones((linear_out_features, lora_dim), device="cpu", dtype=torch.float16), + }, + ) + } + lora = ModelPatchRaw(lora_layers) + + orig_linear_weight = model.linear_layer_1.weight.data.detach().clone() + + with LayerPatcher.apply_smart_model_patches( + model=model, patches=[(lora, 0.5)], prefix="", dtype=torch.float16, force_direct_patching=True + ): + # After patching, all LoRA layer weights should have been moved back to the cpu. + assert lora_layers["linear_layer_1"].up.device.type == "cpu" + assert lora_layers["linear_layer_1"].down.device.type == "cpu" + + # After patching, the patched model should still be on the CPU. + assert model.linear_layer_1.weight.data.device.type == "cpu" + + # There should be no sidecar wrappers in the model. + assert not isinstance(model.linear_layer_1, BaseSidecarWrapper) + + # Move the model to the GPU. + assert model.to("cuda") + + # After unpatching, the original model weights should have been restored on the GPU. + assert model.linear_layer_1.weight.data.device.type == "cuda" + torch.testing.assert_close(model.linear_layer_1.weight.data, orig_linear_weight, check_device=False) + + +def test_apply_smart_model_patches_force_sidecar_and_direct_patching(): + """Test that ModelPatcher.apply_smart_model_patches(..., force_direct_patching=True, force_sidecar_patching=True) + raises an error. 
+ """ + linear_in_features = 4 + linear_out_features = 8 + lora_rank = 2 + model = DummyModuleWithOneLayer(linear_in_features, linear_out_features, device="cpu", dtype=torch.float16) + lora_layers = { + "linear_layer_1": LoRALayer.from_state_dict_values( + values={ + "lora_down.weight": torch.ones((lora_rank, linear_in_features), device="cpu", dtype=torch.float16), + "lora_up.weight": torch.ones((linear_out_features, lora_rank), device="cpu", dtype=torch.float16), + }, + ) + } + lora = ModelPatchRaw(lora_layers) + with pytest.raises(ValueError, match="Cannot force both direct and sidecar patching."): + with LayerPatcher.apply_smart_model_patches( + model=model, + patches=[(lora, 0.5)], + prefix="", + dtype=torch.float16, + force_direct_patching=True, + force_sidecar_patching=True, + ): + pass From 987c9ae0760d2ffc5809e38b8dfea66e7e0e9022 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Tue, 24 Dec 2024 22:21:31 +0000 Subject: [PATCH 08/31] Move custom autocast modules to separate files in a custom_modules/ directory. 
--- .../torch_module_autocast/autocast_modules.py | 50 ------------------- .../custom_modules/README.md | 8 +++ .../custom_modules/__init__.py | 0 .../custom_modules/custom_conv1d.py | 10 ++++ .../custom_modules/custom_conv2d.py | 10 ++++ .../custom_modules/custom_embedding.py | 17 +++++++ .../custom_modules/custom_group_norm.py | 10 ++++ .../custom_invoke_linear_8_bit_lt.py | 0 .../custom_invoke_linear_nf4.py | 0 .../custom_modules/custom_linear.py | 10 ++++ .../torch_module_autocast.py | 14 ++++-- .../test_cached_model_with_partial_load.py | 4 +- .../test_autocast_modules.py | 4 +- 13 files changed, 81 insertions(+), 56 deletions(-) delete mode 100644 invokeai/backend/model_manager/load/model_cache/torch_module_autocast/autocast_modules.py create mode 100644 invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/README.md create mode 100644 invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/__init__.py create mode 100644 invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_conv1d.py create mode 100644 invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_conv2d.py create mode 100644 invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_embedding.py create mode 100644 invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_group_norm.py rename invokeai/backend/model_manager/load/model_cache/torch_module_autocast/{ => custom_modules}/custom_invoke_linear_8_bit_lt.py (100%) rename invokeai/backend/model_manager/load/model_cache/torch_module_autocast/{ => custom_modules}/custom_invoke_linear_nf4.py (100%) create mode 100644 invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_linear.py diff --git a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/autocast_modules.py 
b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/autocast_modules.py deleted file mode 100644 index 8a1bacf6833..00000000000 --- a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/autocast_modules.py +++ /dev/null @@ -1,50 +0,0 @@ -import torch - -from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device - -# This file contains custom torch.nn.Module classes that support streaming of weights to the target device. -# Each class sub-classes the original module type that is is replacing, so the following properties are preserved: -# - isinstance(m, torch.nn.OrginalModule) should still work. -# - Patching the weights (e.g. for LoRA) should still work if non-quantized. - - -class CustomLinear(torch.nn.Linear): - def forward(self, input: torch.Tensor) -> torch.Tensor: - weight = cast_to_device(self.weight, input.device) - bias = cast_to_device(self.bias, input.device) - return torch.nn.functional.linear(input, weight, bias) - - -class CustomConv1d(torch.nn.Conv1d): - def forward(self, input: torch.Tensor) -> torch.Tensor: - weight = cast_to_device(self.weight, input.device) - bias = cast_to_device(self.bias, input.device) - return self._conv_forward(input, weight, bias) - - -class CustomConv2d(torch.nn.Conv2d): - def forward(self, input: torch.Tensor) -> torch.Tensor: - weight = cast_to_device(self.weight, input.device) - bias = cast_to_device(self.bias, input.device) - return self._conv_forward(input, weight, bias) - - -class CustomGroupNorm(torch.nn.GroupNorm): - def forward(self, input: torch.Tensor) -> torch.Tensor: - weight = cast_to_device(self.weight, input.device) - bias = cast_to_device(self.bias, input.device) - return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps) - - -class CustomEmbedding(torch.nn.Embedding): - def forward(self, input: torch.Tensor) -> torch.Tensor: - weight = cast_to_device(self.weight, input.device) - return 
torch.nn.functional.embedding( - input, - weight, - self.padding_idx, - self.max_norm, - self.norm_type, - self.scale_grad_by_freq, - self.sparse, - ) diff --git a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/README.md b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/README.md new file mode 100644 index 00000000000..cadb1b6dd5a --- /dev/null +++ b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/README.md @@ -0,0 +1,8 @@ + +This directory contains custom implementations of common torch.nn.Module classes that add support for: +- Streaming weights to the execution device +- Applying sidecar patches at execution time (e.g. sidecar LoRA layers) + +Each custom class sub-classes the original module type that it is replacing, so the following properties are preserved: +- `isinstance(m, torch.nn.OriginalModule)` should still work. +- Patching the weights directly (e.g. for LoRA) should still work. (Of course, this is not possible for quantized layers, hence the sidecar support.) 
diff --git a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/__init__.py b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_conv1d.py b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_conv1d.py new file mode 100644 index 00000000000..078930c5b5d --- /dev/null +++ b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_conv1d.py @@ -0,0 +1,10 @@ +import torch + +from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device + + +class CustomConv1d(torch.nn.Conv1d): + def forward(self, input: torch.Tensor) -> torch.Tensor: + weight = cast_to_device(self.weight, input.device) + bias = cast_to_device(self.bias, input.device) + return self._conv_forward(input, weight, bias) diff --git a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_conv2d.py b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_conv2d.py new file mode 100644 index 00000000000..99b7137184f --- /dev/null +++ b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_conv2d.py @@ -0,0 +1,10 @@ +import torch + +from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device + + +class CustomConv2d(torch.nn.Conv2d): + def forward(self, input: torch.Tensor) -> torch.Tensor: + weight = cast_to_device(self.weight, input.device) + bias = cast_to_device(self.bias, input.device) + return self._conv_forward(input, weight, bias) diff --git a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_embedding.py 
b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_embedding.py new file mode 100644 index 00000000000..8e6821aee63 --- /dev/null +++ b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_embedding.py @@ -0,0 +1,17 @@ +import torch + +from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device + + +class CustomEmbedding(torch.nn.Embedding): + def forward(self, input: torch.Tensor) -> torch.Tensor: + weight = cast_to_device(self.weight, input.device) + return torch.nn.functional.embedding( + input, + weight, + self.padding_idx, + self.max_norm, + self.norm_type, + self.scale_grad_by_freq, + self.sparse, + ) diff --git a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_group_norm.py b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_group_norm.py new file mode 100644 index 00000000000..8b8fff2508c --- /dev/null +++ b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_group_norm.py @@ -0,0 +1,10 @@ +import torch + +from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device + + +class CustomGroupNorm(torch.nn.GroupNorm): + def forward(self, input: torch.Tensor) -> torch.Tensor: + weight = cast_to_device(self.weight, input.device) + bias = cast_to_device(self.bias, input.device) + return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps) diff --git a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_invoke_linear_8_bit_lt.py b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_invoke_linear_8_bit_lt.py similarity index 100% rename from invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_invoke_linear_8_bit_lt.py rename to 
invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_invoke_linear_8_bit_lt.py diff --git a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_invoke_linear_nf4.py b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_invoke_linear_nf4.py similarity index 100% rename from invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_invoke_linear_nf4.py rename to invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_invoke_linear_nf4.py diff --git a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_linear.py b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_linear.py new file mode 100644 index 00000000000..4b3dc1a498e --- /dev/null +++ b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_linear.py @@ -0,0 +1,10 @@ +import torch + +from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device + + +class CustomLinear(torch.nn.Linear): + def forward(self, input: torch.Tensor) -> torch.Tensor: + weight = cast_to_device(self.weight, input.device) + bias = cast_to_device(self.bias, input.device) + return torch.nn.functional.linear(input, weight, bias) diff --git a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/torch_module_autocast.py b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/torch_module_autocast.py index 825eebf64e8..7bc05885d4e 100644 --- a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/torch_module_autocast.py +++ b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/torch_module_autocast.py @@ -1,10 +1,18 @@ import torch -from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.autocast_modules import ( +from 
invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_conv1d import ( CustomConv1d, +) +from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_conv2d import ( CustomConv2d, +) +from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_embedding import ( CustomEmbedding, +) +from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_group_norm import ( CustomGroupNorm, +) +from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_linear import ( CustomLinear, ) @@ -18,10 +26,10 @@ try: # These dependencies are not expected to be present on MacOS. - from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_invoke_linear_8_bit_lt import ( + from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_invoke_linear_8_bit_lt import ( CustomInvokeLinear8bitLt, ) - from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_invoke_linear_nf4 import ( + from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_invoke_linear_nf4 import ( CustomInvokeLinearNF4, ) from invokeai.backend.quantization.bnb_llm_int8 import InvokeLinear8bitLt diff --git a/tests/backend/model_manager/load/model_cache/cached_model/test_cached_model_with_partial_load.py b/tests/backend/model_manager/load/model_cache/cached_model/test_cached_model_with_partial_load.py index e3c99d0c34f..00ce27d580f 100644 --- a/tests/backend/model_manager/load/model_cache/cached_model/test_cached_model_with_partial_load.py +++ b/tests/backend/model_manager/load/model_cache/cached_model/test_cached_model_with_partial_load.py @@ -5,7 +5,9 @@ from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_with_partial_load import ( CachedModelWithPartialLoad, ) -from 
invokeai.backend.model_manager.load.model_cache.torch_module_autocast.autocast_modules import CustomLinear +from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_linear import ( + CustomLinear, +) from invokeai.backend.util.calc_tensor_size import calc_tensor_size from tests.backend.model_manager.load.model_cache.cached_model.utils import DummyModule, parameterize_mps_and_cuda diff --git a/tests/backend/model_manager/load/model_cache/torch_module_autocast/test_autocast_modules.py b/tests/backend/model_manager/load/model_cache/torch_module_autocast/test_autocast_modules.py index 38fa467c602..ba5f27cdaaf 100644 --- a/tests/backend/model_manager/load/model_cache/torch_module_autocast/test_autocast_modules.py +++ b/tests/backend/model_manager/load/model_cache/torch_module_autocast/test_autocast_modules.py @@ -4,10 +4,10 @@ if not torch.cuda.is_available(): pytest.skip("CUDA is not available", allow_module_level=True) else: - from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_invoke_linear_8_bit_lt import ( + from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_invoke_linear_8_bit_lt import ( CustomInvokeLinear8bitLt, ) - from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_invoke_linear_nf4 import ( + from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_invoke_linear_nf4 import ( CustomInvokeLinearNF4, ) from invokeai.backend.quantization.bnb_llm_int8 import InvokeLinear8bitLt From 03944191db049d80fff4ebf288117fe54409f662 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Tue, 24 Dec 2024 22:29:11 +0000 Subject: [PATCH 09/31] Split test_autocast_modules.py into separate test files to mirror the source file structure. 
--- .../test_custom_invoke_linear_8_bit_lt.py | 70 +++++++++++++++++ .../test_custom_invoke_linear_nf4.py} | 75 +------------------ 2 files changed, 74 insertions(+), 71 deletions(-) create mode 100644 tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_custom_invoke_linear_8_bit_lt.py rename tests/backend/model_manager/load/model_cache/torch_module_autocast/{test_autocast_modules.py => custom_modules/test_custom_invoke_linear_nf4.py} (50%) diff --git a/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_custom_invoke_linear_8_bit_lt.py b/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_custom_invoke_linear_8_bit_lt.py new file mode 100644 index 00000000000..9f07363f248 --- /dev/null +++ b/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_custom_invoke_linear_8_bit_lt.py @@ -0,0 +1,70 @@ +import pytest +import torch + +if not torch.cuda.is_available(): + pytest.skip("CUDA is not available", allow_module_level=True) +else: + from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_invoke_linear_8_bit_lt import ( + CustomInvokeLinear8bitLt, + ) + from invokeai.backend.quantization.bnb_llm_int8 import InvokeLinear8bitLt + + +@pytest.fixture +def linear_8bit_lt_layer(): + if not torch.cuda.is_available(): + pytest.skip("CUDA is not available") + + torch.manual_seed(1) + + orig_layer = torch.nn.Linear(32, 64) + orig_layer_state_dict = orig_layer.state_dict() + + # Prepare a quantized InvokeLinear8bitLt layer. + quantized_layer = InvokeLinear8bitLt(input_features=32, output_features=64, has_fp16_weights=False) + quantized_layer.load_state_dict(orig_layer_state_dict) + quantized_layer.to("cuda") + + # Assert that the InvokeLinear8bitLt layer is quantized. 
+ assert quantized_layer.weight.CB is not None + assert quantized_layer.weight.SCB is not None + assert quantized_layer.weight.CB.dtype == torch.int8 + + return quantized_layer + + +def test_custom_invoke_linear_8bit_lt_all_weights_on_cuda(linear_8bit_lt_layer: InvokeLinear8bitLt): + """Test CustomInvokeLinear8bitLt inference with all weights on the GPU.""" + # Run inference on the original layer. + x = torch.randn(1, 32).to("cuda") + y_quantized = linear_8bit_lt_layer(x) + + # Wrap the InvokeLinear8bitLt layer in a CustomInvokeLinear8bitLt layer, and run inference on it. + linear_8bit_lt_layer.__class__ = CustomInvokeLinear8bitLt + y_custom = linear_8bit_lt_layer(x) + + # Assert that the quantized and custom layers produce the same output. + assert torch.allclose(y_quantized, y_custom, atol=1e-5) + + +def test_custom_invoke_linear_8bit_lt_all_weights_on_cpu(linear_8bit_lt_layer: InvokeLinear8bitLt): + """Test CustomInvokeLinear8bitLt inference with all weights on the CPU (streaming to the GPU).""" + # Run inference on the original layer. + x = torch.randn(1, 32).to("cuda") + y_quantized = linear_8bit_lt_layer(x) + + # Copy the state dict to the CPU and reload it. + state_dict = linear_8bit_lt_layer.state_dict() + state_dict = {k: v.to("cpu") for k, v in state_dict.items()} + linear_8bit_lt_layer.load_state_dict(state_dict) + + # Inference of the original layer should fail. + with pytest.raises(RuntimeError): + linear_8bit_lt_layer(x) + + # Wrap the InvokeLinear8bitLt layer in a CustomInvokeLinear8bitLt layer, and run inference on it. + linear_8bit_lt_layer.__class__ = CustomInvokeLinear8bitLt + y_custom = linear_8bit_lt_layer(x) + + # Assert that the quantized and custom layers produce the same output. 
+ assert torch.allclose(y_quantized, y_custom, atol=1e-5) diff --git a/tests/backend/model_manager/load/model_cache/torch_module_autocast/test_autocast_modules.py b/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_custom_invoke_linear_nf4.py similarity index 50% rename from tests/backend/model_manager/load/model_cache/torch_module_autocast/test_autocast_modules.py rename to tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_custom_invoke_linear_nf4.py index ba5f27cdaaf..1b74a9d6561 100644 --- a/tests/backend/model_manager/load/model_cache/torch_module_autocast/test_autocast_modules.py +++ b/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_custom_invoke_linear_nf4.py @@ -1,77 +1,10 @@ import pytest import torch -if not torch.cuda.is_available(): - pytest.skip("CUDA is not available", allow_module_level=True) -else: - from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_invoke_linear_8_bit_lt import ( - CustomInvokeLinear8bitLt, - ) - from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_invoke_linear_nf4 import ( - CustomInvokeLinearNF4, - ) - from invokeai.backend.quantization.bnb_llm_int8 import InvokeLinear8bitLt - from invokeai.backend.quantization.bnb_nf4 import InvokeLinearNF4 - - -@pytest.fixture -def linear_8bit_lt_layer(): - if not torch.cuda.is_available(): - pytest.skip("CUDA is not available") - - torch.manual_seed(1) - - orig_layer = torch.nn.Linear(32, 64) - orig_layer_state_dict = orig_layer.state_dict() - - # Prepare a quantized InvokeLinear8bitLt layer. - quantized_layer = InvokeLinear8bitLt(input_features=32, output_features=64, has_fp16_weights=False) - quantized_layer.load_state_dict(orig_layer_state_dict) - quantized_layer.to("cuda") - - # Assert that the InvokeLinear8bitLt layer is quantized. 
- assert quantized_layer.weight.CB is not None - assert quantized_layer.weight.SCB is not None - assert quantized_layer.weight.CB.dtype == torch.int8 - - return quantized_layer - - -def test_custom_invoke_linear_8bit_lt_all_weights_on_cuda(linear_8bit_lt_layer: InvokeLinear8bitLt): - """Test CustomInvokeLinear8bitLt inference with all weights on the GPU.""" - # Run inference on the original layer. - x = torch.randn(1, 32).to("cuda") - y_quantized = linear_8bit_lt_layer(x) - - # Wrap the InvokeLinear8bitLt layer in a CustomInvokeLinear8bitLt layer, and run inference on it. - linear_8bit_lt_layer.__class__ = CustomInvokeLinear8bitLt - y_custom = linear_8bit_lt_layer(x) - - # Assert that the quantized and custom layers produce the same output. - assert torch.allclose(y_quantized, y_custom, atol=1e-5) - - -def test_custom_invoke_linear_8bit_lt_all_weights_on_cpu(linear_8bit_lt_layer: InvokeLinear8bitLt): - """Test CustomInvokeLinear8bitLt inference with all weights on the CPU (streaming to the GPU).""" - # Run inference on the original layer. - x = torch.randn(1, 32).to("cuda") - y_quantized = linear_8bit_lt_layer(x) - - # Copy the state dict to the CPU and reload it. - state_dict = linear_8bit_lt_layer.state_dict() - state_dict = {k: v.to("cpu") for k, v in state_dict.items()} - linear_8bit_lt_layer.load_state_dict(state_dict) - - # Inference of the original layer should fail. - with pytest.raises(RuntimeError): - linear_8bit_lt_layer(x) - - # Wrap the InvokeLinear8bitLt layer in a CustomInvokeLinear8bitLt layer, and run inference on it. - linear_8bit_lt_layer.__class__ = CustomInvokeLinear8bitLt - y_custom = linear_8bit_lt_layer(x) - - # Assert that the quantized and custom layers produce the same output. 
- assert torch.allclose(y_quantized, y_custom, atol=1e-5) +from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_invoke_linear_nf4 import ( + CustomInvokeLinearNF4, +) +from invokeai.backend.quantization.bnb_nf4 import InvokeLinearNF4 @pytest.fixture From a8b2c4c3d2c8ab6be46815e6fef95f2e593da365 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Thu, 26 Dec 2024 18:33:46 +0000 Subject: [PATCH 10/31] Add inference tests for all custom module types (i.e. to test autocasting from cpu to device). --- .../backend/quantization/gguf/ggml_tensor.py | 2 + .../custom_modules/test_all_custom_modules.py | 186 ++++++++++++++++++ .../test_custom_invoke_linear_8_bit_lt.py | 8 +- .../test_custom_invoke_linear_nf4.py | 8 +- 4 files changed, 200 insertions(+), 4 deletions(-) create mode 100644 tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_all_custom_modules.py diff --git a/invokeai/backend/quantization/gguf/ggml_tensor.py b/invokeai/backend/quantization/gguf/ggml_tensor.py index a9f5d68b76c..62be2bdb637 100644 --- a/invokeai/backend/quantization/gguf/ggml_tensor.py +++ b/invokeai/backend/quantization/gguf/ggml_tensor.py @@ -48,11 +48,13 @@ def apply_to_quantized_tensor(func, args, kwargs): # Ops to run on the quantized tensor. torch.ops.aten.detach.default: apply_to_quantized_tensor, # pyright: ignore torch.ops.aten._to_copy.default: apply_to_quantized_tensor, # pyright: ignore + torch.ops.aten.clone.default: apply_to_quantized_tensor, # pyright: ignore # Ops to run on dequantized tensors. 
torch.ops.aten.t.default: dequantize_and_run, # pyright: ignore torch.ops.aten.addmm.default: dequantize_and_run, # pyright: ignore torch.ops.aten.mul.Tensor: dequantize_and_run, # pyright: ignore torch.ops.aten.add.Tensor: dequantize_and_run, # pyright: ignore + torch.ops.aten.allclose.default: dequantize_and_run, # pyright: ignore } if torch.backends.mps.is_available(): diff --git a/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_all_custom_modules.py b/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_all_custom_modules.py new file mode 100644 index 00000000000..2014f1896ad --- /dev/null +++ b/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_all_custom_modules.py @@ -0,0 +1,186 @@ +import copy + +import gguf +import pytest +import torch + +from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.torch_module_autocast import ( + apply_custom_layers_to_model, +) +from tests.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.test_custom_invoke_linear_8_bit_lt import ( + build_linear_8bit_lt_layer, +) +from tests.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.test_custom_invoke_linear_nf4 import ( + build_linear_nf4_layer, +) +from tests.backend.quantization.gguf.test_ggml_tensor import quantize_tensor + + +def build_linear_layer_with_ggml_quantized_tensor(): + layer = torch.nn.Linear(32, 64) + ggml_quantized_weight = quantize_tensor(layer.weight, gguf.GGMLQuantizationType.Q8_0) + layer.weight = torch.nn.Parameter(ggml_quantized_weight) + ggml_quantized_bias = quantize_tensor(layer.bias, gguf.GGMLQuantizationType.Q8_0) + layer.bias = torch.nn.Parameter(ggml_quantized_bias) + return layer + + +parameterize_all_devices = pytest.mark.parametrize( + ("device"), + [ + pytest.param("cpu"), + pytest.param( + "mps", marks=pytest.mark.skipif(not torch.backends.mps.is_available(), reason="MPS 
is not available.") + ), + pytest.param("cuda", marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available.")), + ], +) + +parameterize_cuda_and_mps = pytest.mark.parametrize( + ("device"), + [ + pytest.param("cuda", marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available.")), + pytest.param( + "mps", marks=pytest.mark.skipif(not torch.backends.mps.is_available(), reason="MPS is not available.") + ), + ], +) + +parameterize_all_layer_types = pytest.mark.parametrize( + ("orig_layer", "layer_input", "supports_cpu_inference"), + [ + (torch.nn.Linear(8, 16), torch.randn(1, 8), True), + (torch.nn.Conv1d(8, 16, 3), torch.randn(1, 8, 5), True), + (torch.nn.Conv2d(8, 16, 3), torch.randn(1, 8, 5, 5), True), + (torch.nn.GroupNorm(2, 8), torch.randn(1, 8, 5), True), + (torch.nn.Embedding(4, 8), torch.tensor([0, 1], dtype=torch.long), True), + (build_linear_layer_with_ggml_quantized_tensor(), torch.randn(1, 32), True), + (build_linear_8bit_lt_layer(), torch.randn(1, 32), False), + (build_linear_nf4_layer(), torch.randn(1, 64), False), + ], +) + + +def layer_to_device_via_state_dict(layer: torch.nn.Module, device: str): + """A helper function to move a layer to a device by roundtripping through a state dict. This most closely matches + how models are moved in the app. Some of the quantization types have broken semantics around calling .to() on the + layer directly, so this is a workaround. + + We should fix this in the future. 
+ Relevant article: https://pytorch.org/tutorials/recipes/recipes/swap_tensors.html + """ + state_dict = layer.state_dict() + state_dict = {k: v.to(device) for k, v in state_dict.items()} + layer.load_state_dict(state_dict, assign=True) + + +@parameterize_all_devices +@parameterize_all_layer_types +def test_state_dict(device: str, orig_layer: torch.nn.Module, layer_input: torch.Tensor, supports_cpu_inference: bool): + """Test that .state_dict() behaves the same on the original layer and the wrapped layer.""" + # Get the original layer on the test device. + orig_layer.to(device) + orig_state_dict = orig_layer.state_dict() + + # Wrap the original layer. + custom_layer = copy.deepcopy(orig_layer) + apply_custom_layers_to_model(custom_layer) + + custom_state_dict = custom_layer.state_dict() + + assert set(orig_state_dict.keys()) == set(custom_state_dict.keys()) + for k in orig_state_dict: + assert orig_state_dict[k].shape == custom_state_dict[k].shape + assert orig_state_dict[k].dtype == custom_state_dict[k].dtype + assert orig_state_dict[k].device == custom_state_dict[k].device + assert torch.allclose(orig_state_dict[k], custom_state_dict[k]) + + +@parameterize_all_devices +@parameterize_all_layer_types +def test_load_state_dict( + device: str, orig_layer: torch.nn.Module, layer_input: torch.Tensor, supports_cpu_inference: bool +): + """Test that .load_state_dict() behaves the same on the original layer and the wrapped layer.""" + orig_layer.to(device) + + custom_layer = copy.deepcopy(orig_layer) + apply_custom_layers_to_model(custom_layer) + + # Do a state dict roundtrip. + orig_state_dict = orig_layer.state_dict() + custom_state_dict = custom_layer.state_dict() + + orig_layer.load_state_dict(custom_state_dict, assign=True) + custom_layer.load_state_dict(orig_state_dict, assign=True) + + orig_state_dict = orig_layer.state_dict() + custom_state_dict = custom_layer.state_dict() + + # Assert that the state dicts are the same after the roundtrip. 
+ assert set(orig_state_dict.keys()) == set(custom_state_dict.keys()) + for k in orig_state_dict: + assert orig_state_dict[k].shape == custom_state_dict[k].shape + assert orig_state_dict[k].dtype == custom_state_dict[k].dtype + assert orig_state_dict[k].device == custom_state_dict[k].device + assert torch.allclose(orig_state_dict[k], custom_state_dict[k]) + + +@parameterize_all_devices +@parameterize_all_layer_types +def test_inference_on_device( + device: str, orig_layer: torch.nn.Module, layer_input: torch.Tensor, supports_cpu_inference: bool +): + """Test that inference behaves the same on the original layer and the wrapped layer when all weights are on the + device. + """ + if device == "cpu" and not supports_cpu_inference: + pytest.skip("Layer does not support CPU inference.") + + layer_to_device_via_state_dict(orig_layer, device) + + custom_layer = copy.deepcopy(orig_layer) + apply_custom_layers_to_model(custom_layer) + + # Run inference with the original layer. + x = layer_input.to(device) + orig_output = orig_layer(x) + + # Run inference with the wrapped layer. + custom_output = custom_layer(x) + + assert torch.allclose(orig_output, custom_output) + + +@parameterize_cuda_and_mps +@parameterize_all_layer_types +def test_inference_autocast_from_cpu_to_device( + device: str, orig_layer: torch.nn.Module, layer_input: torch.Tensor, supports_cpu_inference: bool +): + """Test that inference behaves the same on the original layer and the wrapped layer when all weights are on the + device. + """ + # Make sure the original layer is on the device. + layer_to_device_via_state_dict(orig_layer, device) + + x = layer_input.to(device) + + # Run inference with the original layer on the device. + orig_output = orig_layer(x) + + # Move the original layer to the CPU. + layer_to_device_via_state_dict(orig_layer, "cpu") + + # Inference should fail with an input on the device. + with pytest.raises(RuntimeError): + _ = orig_layer(x) + + # Wrap the original layer. 
+ custom_layer = copy.deepcopy(orig_layer) + apply_custom_layers_to_model(custom_layer) + + # Run inference with the wrapped layer on the device. + custom_output = custom_layer(x) + assert custom_output.device.type == device + + assert torch.allclose(orig_output, custom_output) diff --git a/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_custom_invoke_linear_8_bit_lt.py b/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_custom_invoke_linear_8_bit_lt.py index 9f07363f248..f6ff7a6e2e2 100644 --- a/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_custom_invoke_linear_8_bit_lt.py +++ b/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_custom_invoke_linear_8_bit_lt.py @@ -10,8 +10,7 @@ from invokeai.backend.quantization.bnb_llm_int8 import InvokeLinear8bitLt -@pytest.fixture -def linear_8bit_lt_layer(): +def build_linear_8bit_lt_layer(): if not torch.cuda.is_available(): pytest.skip("CUDA is not available") @@ -33,6 +32,11 @@ def linear_8bit_lt_layer(): return quantized_layer +@pytest.fixture +def linear_8bit_lt_layer(): + return build_linear_8bit_lt_layer() + + def test_custom_invoke_linear_8bit_lt_all_weights_on_cuda(linear_8bit_lt_layer: InvokeLinear8bitLt): """Test CustomInvokeLinear8bitLt inference with all weights on the GPU.""" # Run inference on the original layer. 
diff --git a/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_custom_invoke_linear_nf4.py b/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_custom_invoke_linear_nf4.py index 1b74a9d6561..4cfe5f6aaf7 100644 --- a/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_custom_invoke_linear_nf4.py +++ b/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_custom_invoke_linear_nf4.py @@ -7,8 +7,7 @@ from invokeai.backend.quantization.bnb_nf4 import InvokeLinearNF4 -@pytest.fixture -def linear_nf4_layer(): +def build_linear_nf4_layer(): if not torch.cuda.is_available(): pytest.skip("CUDA is not available") @@ -28,6 +27,11 @@ def linear_nf4_layer(): return quantized_layer +@pytest.fixture +def linear_nf4_layer(): + return build_linear_nf4_layer() + + def test_custom_invoke_linear_nf4_all_weights_on_cuda(linear_nf4_layer: InvokeLinearNF4): """Test CustomInvokeLinearNF4 inference with all weights on the GPU.""" # Run inference on the original layer. From b0b699a01fdc23eff6332bf5d4b5125dbff67a26 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Thu, 26 Dec 2024 18:45:56 +0000 Subject: [PATCH 11/31] Add unit test to test that isinstance(...) behaves as expected with custom module types. 
--- .../custom_modules/test_all_custom_modules.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_all_custom_modules.py b/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_all_custom_modules.py index 2014f1896ad..66aa356d4ac 100644 --- a/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_all_custom_modules.py +++ b/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_all_custom_modules.py @@ -74,6 +74,17 @@ def layer_to_device_via_state_dict(layer: torch.nn.Module, device: str): layer.load_state_dict(state_dict, assign=True) +@parameterize_all_layer_types +def test_isinstance(orig_layer: torch.nn.Module, layer_input: torch.Tensor, supports_cpu_inference: bool): + """Test that isinstance() and type() behave as expected after wrapping a layer in a custom layer.""" + orig_type = type(orig_layer) + + apply_custom_layers_to_model(orig_layer) + + assert isinstance(orig_layer, orig_type) + assert type(orig_layer) is not orig_type + + @parameterize_all_devices @parameterize_all_layer_types def test_state_dict(device: str, orig_layer: torch.nn.Module, layer_input: torch.Tensor, supports_cpu_inference: bool): From 9692a36dd68f3a1363ec65c8c3465600e780065b Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Thu, 26 Dec 2024 19:41:25 +0000 Subject: [PATCH 12/31] Use a fixture to parameterize tests in test_all_custom_modules.py so that a fresh instance of the layer under test is initialized for each test. 
--- .../custom_modules/test_all_custom_modules.py | 81 ++++++++++++------- 1 file changed, 53 insertions(+), 28 deletions(-) diff --git a/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_all_custom_modules.py b/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_all_custom_modules.py index 66aa356d4ac..ef53a85c99a 100644 --- a/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_all_custom_modules.py +++ b/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_all_custom_modules.py @@ -46,19 +46,43 @@ def build_linear_layer_with_ggml_quantized_tensor(): ], ) -parameterize_all_layer_types = pytest.mark.parametrize( - ("orig_layer", "layer_input", "supports_cpu_inference"), - [ - (torch.nn.Linear(8, 16), torch.randn(1, 8), True), - (torch.nn.Conv1d(8, 16, 3), torch.randn(1, 8, 5), True), - (torch.nn.Conv2d(8, 16, 3), torch.randn(1, 8, 5, 5), True), - (torch.nn.GroupNorm(2, 8), torch.randn(1, 8, 5), True), - (torch.nn.Embedding(4, 8), torch.tensor([0, 1], dtype=torch.long), True), - (build_linear_layer_with_ggml_quantized_tensor(), torch.randn(1, 32), True), - (build_linear_8bit_lt_layer(), torch.randn(1, 32), False), - (build_linear_nf4_layer(), torch.randn(1, 64), False), - ], + +LayerUnderTest = tuple[torch.nn.Module, torch.Tensor, bool] + + +@pytest.fixture( + params=[ + "linear", + "conv1d", + "conv2d", + "group_norm", + "embedding", + "linear_with_ggml_quantized_tensor", + "invoke_linear_8_bit_lt", + "invoke_linear_nf4", + ] ) +def layer_under_test(request: pytest.FixtureRequest) -> LayerUnderTest: + """A fixture that returns a tuple of (layer, input, supports_cpu_inference) for the layer under test.""" + layer_type = request.param + if layer_type == "linear": + return (torch.nn.Linear(8, 16), torch.randn(1, 8), True) + elif layer_type == "conv1d": + return (torch.nn.Conv1d(8, 16, 3), torch.randn(1, 8, 5), True) + elif 
layer_type == "conv2d": + return (torch.nn.Conv2d(8, 16, 3), torch.randn(1, 8, 5, 5), True) + elif layer_type == "group_norm": + return (torch.nn.GroupNorm(2, 8), torch.randn(1, 8, 5), True) + elif layer_type == "embedding": + return (torch.nn.Embedding(4, 8), torch.tensor([0, 1], dtype=torch.long), True) + elif layer_type == "linear_with_ggml_quantized_tensor": + return (build_linear_layer_with_ggml_quantized_tensor(), torch.randn(1, 32), True) + elif layer_type == "invoke_linear_8_bit_lt": + return (build_linear_8bit_lt_layer(), torch.randn(1, 32), False) + elif layer_type == "invoke_linear_nf4": + return (build_linear_nf4_layer(), torch.randn(1, 64), False) + else: + raise ValueError(f"Unsupported layer_type: {layer_type}") def layer_to_device_via_state_dict(layer: torch.nn.Module, device: str): @@ -74,9 +98,9 @@ def layer_to_device_via_state_dict(layer: torch.nn.Module, device: str): layer.load_state_dict(state_dict, assign=True) -@parameterize_all_layer_types -def test_isinstance(orig_layer: torch.nn.Module, layer_input: torch.Tensor, supports_cpu_inference: bool): +def test_isinstance(layer_under_test: LayerUnderTest): """Test that isinstance() and type() behave as expected after wrapping a layer in a custom layer.""" + orig_layer, _, _ = layer_under_test orig_type = type(orig_layer) apply_custom_layers_to_model(orig_layer) @@ -86,9 +110,10 @@ def test_isinstance(orig_layer: torch.nn.Module, layer_input: torch.Tensor, supp @parameterize_all_devices -@parameterize_all_layer_types -def test_state_dict(device: str, orig_layer: torch.nn.Module, layer_input: torch.Tensor, supports_cpu_inference: bool): +def test_state_dict(device: str, layer_under_test: LayerUnderTest): """Test that .state_dict() behaves the same on the original layer and the wrapped layer.""" + orig_layer, _, _ = layer_under_test + # Get the original layer on the test device. 
orig_layer.to(device) orig_state_dict = orig_layer.state_dict() @@ -108,11 +133,10 @@ def test_state_dict(device: str, orig_layer: torch.nn.Module, layer_input: torch @parameterize_all_devices -@parameterize_all_layer_types -def test_load_state_dict( - device: str, orig_layer: torch.nn.Module, layer_input: torch.Tensor, supports_cpu_inference: bool -): +def test_load_state_dict(device: str, layer_under_test: LayerUnderTest): """Test that .load_state_dict() behaves the same on the original layer and the wrapped layer.""" + orig_layer, _, _ = layer_under_test + orig_layer.to(device) custom_layer = copy.deepcopy(orig_layer) @@ -138,13 +162,12 @@ def test_load_state_dict( @parameterize_all_devices -@parameterize_all_layer_types -def test_inference_on_device( - device: str, orig_layer: torch.nn.Module, layer_input: torch.Tensor, supports_cpu_inference: bool -): +def test_inference_on_device(device: str, layer_under_test: LayerUnderTest): """Test that inference behaves the same on the original layer and the wrapped layer when all weights are on the device. """ + orig_layer, layer_input, supports_cpu_inference = layer_under_test + if device == "cpu" and not supports_cpu_inference: pytest.skip("Layer does not support CPU inference.") @@ -164,13 +187,15 @@ def test_inference_on_device( @parameterize_cuda_and_mps -@parameterize_all_layer_types -def test_inference_autocast_from_cpu_to_device( - device: str, orig_layer: torch.nn.Module, layer_input: torch.Tensor, supports_cpu_inference: bool -): +def test_inference_autocast_from_cpu_to_device(device: str, layer_under_test: LayerUnderTest): """Test that inference behaves the same on the original layer and the wrapped layer when all weights are on the device. """ + orig_layer, layer_input, supports_cpu_inference = layer_under_test + + if device == "cpu" and not supports_cpu_inference: + pytest.skip("Layer does not support CPU inference.") + # Make sure the original layer is on the device. 
layer_to_device_via_state_dict(orig_layer, device) From 7d6ab0ceb223e4208cfedc09094d6a112c04f035 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Thu, 26 Dec 2024 20:08:30 +0000 Subject: [PATCH 13/31] Add a CustomModuleMixin class with a flag for enabling/disabling autocasting (since it incurs some runtime speed overhead.) --- .../custom_modules/custom_conv1d.py | 13 +++++++++++-- .../custom_modules/custom_conv2d.py | 13 +++++++++++-- .../custom_modules/custom_embedding.py | 13 +++++++++++-- .../custom_modules/custom_group_norm.py | 13 +++++++++++-- .../custom_modules/custom_invoke_linear_8_bit_lt.py | 13 +++++++++++-- .../custom_modules/custom_invoke_linear_nf4.py | 13 +++++++++++-- .../custom_modules/custom_linear.py | 13 +++++++++++-- .../custom_modules/custom_module_mixin.py | 11 +++++++++++ .../torch_module_autocast/torch_module_autocast.py | 2 ++ .../custom_modules/test_all_custom_modules.py | 6 ++++++ .../test_custom_invoke_linear_8_bit_lt.py | 1 + .../custom_modules/test_custom_invoke_linear_nf4.py | 2 ++ 12 files changed, 99 insertions(+), 14 deletions(-) create mode 100644 invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_module_mixin.py diff --git a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_conv1d.py b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_conv1d.py index 078930c5b5d..e0e38edd53c 100644 --- a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_conv1d.py +++ b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_conv1d.py @@ -1,10 +1,19 @@ import torch from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device +from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import ( + CustomModuleMixin, +) -class 
CustomConv1d(torch.nn.Conv1d): - def forward(self, input: torch.Tensor) -> torch.Tensor: +class CustomConv1d(torch.nn.Conv1d, CustomModuleMixin): + def _autocast_forward(self, input: torch.Tensor) -> torch.Tensor: weight = cast_to_device(self.weight, input.device) bias = cast_to_device(self.bias, input.device) return self._conv_forward(input, weight, bias) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + if self._device_autocasting_enabled: + return self._autocast_forward(input) + else: + return super().forward(input) diff --git a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_conv2d.py b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_conv2d.py index 99b7137184f..e8981e34c58 100644 --- a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_conv2d.py +++ b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_conv2d.py @@ -1,10 +1,19 @@ import torch from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device +from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import ( + CustomModuleMixin, +) -class CustomConv2d(torch.nn.Conv2d): - def forward(self, input: torch.Tensor) -> torch.Tensor: +class CustomConv2d(torch.nn.Conv2d, CustomModuleMixin): + def _autocast_forward(self, input: torch.Tensor) -> torch.Tensor: weight = cast_to_device(self.weight, input.device) bias = cast_to_device(self.bias, input.device) return self._conv_forward(input, weight, bias) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + if self._device_autocasting_enabled: + return self._autocast_forward(input) + else: + return super().forward(input) diff --git a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_embedding.py 
b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_embedding.py index 8e6821aee63..e6f0c5df21f 100644 --- a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_embedding.py +++ b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_embedding.py @@ -1,10 +1,13 @@ import torch from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device +from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import ( + CustomModuleMixin, +) -class CustomEmbedding(torch.nn.Embedding): - def forward(self, input: torch.Tensor) -> torch.Tensor: +class CustomEmbedding(torch.nn.Embedding, CustomModuleMixin): + def _autocast_forward(self, input: torch.Tensor) -> torch.Tensor: weight = cast_to_device(self.weight, input.device) return torch.nn.functional.embedding( input, @@ -15,3 +18,9 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: self.scale_grad_by_freq, self.sparse, ) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + if self._device_autocasting_enabled: + return self._autocast_forward(input) + else: + return super().forward(input) diff --git a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_group_norm.py b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_group_norm.py index 8b8fff2508c..66a46ac7ea2 100644 --- a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_group_norm.py +++ b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_group_norm.py @@ -1,10 +1,19 @@ import torch from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device +from 
invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import ( + CustomModuleMixin, +) -class CustomGroupNorm(torch.nn.GroupNorm): - def forward(self, input: torch.Tensor) -> torch.Tensor: +class CustomGroupNorm(torch.nn.GroupNorm, CustomModuleMixin): + def _autocast_forward(self, input: torch.Tensor) -> torch.Tensor: weight = cast_to_device(self.weight, input.device) bias = cast_to_device(self.bias, input.device) return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + if self._device_autocasting_enabled: + return self._autocast_forward(input) + else: + return super().forward(input) diff --git a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_invoke_linear_8_bit_lt.py b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_invoke_linear_8_bit_lt.py index 3941a2af6be..aa6acd31c51 100644 --- a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_invoke_linear_8_bit_lt.py +++ b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_invoke_linear_8_bit_lt.py @@ -2,11 +2,14 @@ import torch from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device +from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import ( + CustomModuleMixin, +) from invokeai.backend.quantization.bnb_llm_int8 import InvokeLinear8bitLt -class CustomInvokeLinear8bitLt(InvokeLinear8bitLt): - def forward(self, x: torch.Tensor) -> torch.Tensor: +class CustomInvokeLinear8bitLt(InvokeLinear8bitLt, CustomModuleMixin): + def _autocast_forward(self, x: torch.Tensor) -> torch.Tensor: matmul_state = bnb.MatmulLtState() matmul_state.threshold = self.state.threshold matmul_state.has_fp16_weights = 
self.state.has_fp16_weights @@ -25,3 +28,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # it's dtype field must be accessible, even though it's not used. We pass in self.weight even though it could be # on the wrong device. return bnb.matmul(x, self.weight, bias=cast_to_device(self.bias, x.device), state=matmul_state) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self._device_autocasting_enabled: + return self._autocast_forward(x) + else: + return super().forward(x) diff --git a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_invoke_linear_nf4.py b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_invoke_linear_nf4.py index c697b3c7b43..60e987b3f36 100644 --- a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_invoke_linear_nf4.py +++ b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_invoke_linear_nf4.py @@ -4,11 +4,14 @@ import torch from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device +from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import ( + CustomModuleMixin, +) from invokeai.backend.quantization.bnb_nf4 import InvokeLinearNF4 -class CustomInvokeLinearNF4(InvokeLinearNF4): - def forward(self, x: torch.Tensor) -> torch.Tensor: +class CustomInvokeLinearNF4(InvokeLinearNF4, CustomModuleMixin): + def _autocast_forward(self, x: torch.Tensor) -> torch.Tensor: bnb.nn.modules.fix_4bit_weight_quant_state_from_module(self) # weights are cast automatically as Int8Params, but the bias has to be cast manually @@ -43,3 +46,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: bias = cast_to_device(self.bias, x.device) return bnb.matmul_4bit(x, weight.t(), bias=bias, quant_state=weight.quant_state).to(inp_dtype) + + def forward(self, x: torch.Tensor) -> 
torch.Tensor: + if self._device_autocasting_enabled: + return self._autocast_forward(x) + else: + return super().forward(x) diff --git a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_linear.py b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_linear.py index 4b3dc1a498e..c1087c49079 100644 --- a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_linear.py +++ b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_linear.py @@ -1,10 +1,19 @@ import torch from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device +from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import ( + CustomModuleMixin, +) -class CustomLinear(torch.nn.Linear): - def forward(self, input: torch.Tensor) -> torch.Tensor: +class CustomLinear(torch.nn.Linear, CustomModuleMixin): + def _autocast_forward(self, input: torch.Tensor) -> torch.Tensor: weight = cast_to_device(self.weight, input.device) bias = cast_to_device(self.bias, input.device) return torch.nn.functional.linear(input, weight, bias) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + if self._device_autocasting_enabled: + return self._autocast_forward(input) + else: + return super().forward(input) diff --git a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_module_mixin.py b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_module_mixin.py new file mode 100644 index 00000000000..18d1e507853 --- /dev/null +++ b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_module_mixin.py @@ -0,0 +1,11 @@ +class CustomModuleMixin: + """A mixin class for custom modules that enables device autocasting of module parameters.""" + + 
_device_autocasting_enabled = False + + def set_device_autocasting_enabled(self, enabled: bool): + """Pass True to enable autocasting of module parameters to the same device as the input tensor. Pass False to + disable autocasting, which results in slightly faster execution speed when we know that device autocasting is + not needed. + """ + self._device_autocasting_enabled = enabled diff --git a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/torch_module_autocast.py b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/torch_module_autocast.py index 7bc05885d4e..48f95a75b6a 100644 --- a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/torch_module_autocast.py +++ b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/torch_module_autocast.py @@ -46,6 +46,8 @@ def apply_custom_layers(module: torch.nn.Module): override_type = AUTOCAST_MODULE_TYPE_MAPPING.get(type(module), None) if override_type is not None: module.__class__ = override_type + # TODO(ryand): In the future, we should manage this flag on a per-module basis. + module.set_device_autocasting_enabled(True) # model.apply(...) calls apply_custom_layers(...) on each module in the model. 
model.apply(apply_custom_layers) diff --git a/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_all_custom_modules.py b/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_all_custom_modules.py index ef53a85c99a..f8fccda2c11 100644 --- a/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_all_custom_modules.py +++ b/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_all_custom_modules.py @@ -215,7 +215,13 @@ def test_inference_autocast_from_cpu_to_device(device: str, layer_under_test: La custom_layer = copy.deepcopy(orig_layer) apply_custom_layers_to_model(custom_layer) + # Inference should still fail with autocasting disabled. + custom_layer.set_device_autocasting_enabled(False) + with pytest.raises(RuntimeError): + _ = custom_layer(x) + # Run inference with the wrapped layer on the device. + custom_layer.set_device_autocasting_enabled(True) custom_output = custom_layer(x) assert custom_output.device.type == device diff --git a/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_custom_invoke_linear_8_bit_lt.py b/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_custom_invoke_linear_8_bit_lt.py index f6ff7a6e2e2..8a2a0e1b614 100644 --- a/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_custom_invoke_linear_8_bit_lt.py +++ b/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_custom_invoke_linear_8_bit_lt.py @@ -68,6 +68,7 @@ def test_custom_invoke_linear_8bit_lt_all_weights_on_cpu(linear_8bit_lt_layer: I # Wrap the InvokeLinear8bitLt layer in a CustomInvokeLinear8bitLt layer, and run inference on it. 
linear_8bit_lt_layer.__class__ = CustomInvokeLinear8bitLt + linear_8bit_lt_layer.set_device_autocasting_enabled(True) y_custom = linear_8bit_lt_layer(x) # Assert that the quantized and custom layers produce the same output. diff --git a/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_custom_invoke_linear_nf4.py b/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_custom_invoke_linear_nf4.py index 4cfe5f6aaf7..154a75b8588 100644 --- a/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_custom_invoke_linear_nf4.py +++ b/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_custom_invoke_linear_nf4.py @@ -40,6 +40,7 @@ def test_custom_invoke_linear_nf4_all_weights_on_cuda(linear_nf4_layer: InvokeLi # Wrap the InvokeLinearNF4 layer in a CustomInvokeLinearNF4 layer, and run inference on it. linear_nf4_layer.__class__ = CustomInvokeLinearNF4 + linear_nf4_layer.set_device_autocasting_enabled(True) y_custom = linear_nf4_layer(x) # Assert that the quantized and custom layers produce the same output. @@ -66,6 +67,7 @@ def test_custom_invoke_linear_nf4_all_weights_on_cpu(linear_nf4_layer: InvokeLin # Wrap the InvokeLinearNF4 layer in a CustomInvokeLinearNF4 layer, and run inference on it. linear_nf4_layer.__class__ = CustomInvokeLinearNF4 + linear_nf4_layer.set_device_autocasting_enabled(True) y_custom = linear_nf4_layer(x) # Assert that the state dict (and the tensors that it references) are still on the CPU. From b06d61e3c03700e9ec503ba240554b64b9c56594 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Fri, 27 Dec 2024 16:29:48 +0000 Subject: [PATCH 14/31] Improve custom layer wrap/unwrap logic. 
--- .../custom_modules/custom_module_mixin.py | 3 +- .../torch_module_autocast.py | 63 ++++++++++++++----- .../custom_modules/test_all_custom_modules.py | 49 ++++++++++++--- .../test_custom_invoke_linear_8_bit_lt.py | 14 +++-- .../test_custom_invoke_linear_nf4.py | 23 ++++--- 5 files changed, 111 insertions(+), 41 deletions(-) diff --git a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_module_mixin.py b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_module_mixin.py index 18d1e507853..75663317107 100644 --- a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_module_mixin.py +++ b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_module_mixin.py @@ -1,7 +1,8 @@ class CustomModuleMixin: """A mixin class for custom modules that enables device autocasting of module parameters.""" - _device_autocasting_enabled = False + def __init__(self): + self._device_autocasting_enabled = False def set_device_autocasting_enabled(self, enabled: bool): """Pass True to enable autocasting of module parameters to the same device as the input tensor. 
Pass False to diff --git a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/torch_module_autocast.py b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/torch_module_autocast.py index 48f95a75b6a..b9fd58f464a 100644 --- a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/torch_module_autocast.py +++ b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/torch_module_autocast.py @@ -15,6 +15,9 @@ from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_linear import ( CustomLinear, ) +from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import ( + CustomModuleMixin, +) AUTOCAST_MODULE_TYPE_MAPPING: dict[type[torch.nn.Module], type[torch.nn.Module]] = { torch.nn.Linear: CustomLinear, @@ -41,26 +44,52 @@ pass -def apply_custom_layers_to_model(model: torch.nn.Module): - def apply_custom_layers(module: torch.nn.Module): - override_type = AUTOCAST_MODULE_TYPE_MAPPING.get(type(module), None) - if override_type is not None: - module.__class__ = override_type - # TODO(ryand): In the future, we should manage this flag on a per-module basis. - module.set_device_autocasting_enabled(True) +AUTOCAST_MODULE_TYPE_MAPPING_INVERSE = {v: k for k, v in AUTOCAST_MODULE_TYPE_MAPPING.items()} + + +def wrap_custom_layer(module_to_wrap: torch.nn.Module, custom_layer_type: type[torch.nn.Module]): + # HACK(ryand): We use custom initialization logic so that we can initialize a new custom layer instance from an + # existing layer instance without calling __init__() on the original layer class. We achieve this by copying + # the attributes from the original layer instance to the new instance. + custom_layer = custom_layer_type.__new__(custom_layer_type) + # Note that we share the __dict__. + # TODO(ryand): In the future, we may want to do a shallow copy of the __dict__. 
+ custom_layer.__dict__ = module_to_wrap.__dict__ - # model.apply(...) calls apply_custom_layers(...) on each module in the model. - model.apply(apply_custom_layers) + # Initialize the CustomModuleMixin fields. + CustomModuleMixin.__init__(custom_layer) # type: ignore + return custom_layer -def remove_custom_layers_from_model(model: torch.nn.Module): - # Invert AUTOCAST_MODULE_TYPE_MAPPING. - original_module_type_mapping = {v: k for k, v in AUTOCAST_MODULE_TYPE_MAPPING.items()} +def unwrap_custom_layer(custom_layer: torch.nn.Module, original_layer_type: type[torch.nn.Module]): + # HACK(ryand): We use custom initialization logic so that we can initialize a new custom layer instance from an + # existing layer instance without calling __init__() on the original layer class. We achieve this by copying + # the attributes from the original layer instance to the new instance. + original_layer = original_layer_type.__new__(original_layer_type) + # Note that we share the __dict__. + # TODO(ryand): In the future, we may want to do a shallow copy of the __dict__ and strip out the CustomModuleMixin + # fields. + original_layer.__dict__ = custom_layer.__dict__ + return original_layer - def remove_custom_layers(module: torch.nn.Module): - override_type = original_module_type_mapping.get(type(module), None) + +def apply_custom_layers_to_model(module: torch.nn.Module): + for name, submodule in module.named_children(): + override_type = AUTOCAST_MODULE_TYPE_MAPPING.get(type(submodule), None) if override_type is not None: - module.__class__ = override_type + custom_layer = wrap_custom_layer(submodule, override_type) + # TODO(ryand): In the future, we should manage this flag on a per-module basis. + custom_layer.set_device_autocasting_enabled(True) + setattr(module, name, custom_layer) + else: + # Recursively apply to submodules + apply_custom_layers_to_model(submodule) - # model.apply(...) calls remove_custom_layers(...) on each module in the model. 
- model.apply(remove_custom_layers) + +def remove_custom_layers_from_model(module: torch.nn.Module): + for name, submodule in module.named_children(): + override_type = AUTOCAST_MODULE_TYPE_MAPPING_INVERSE.get(type(submodule), None) + if override_type is not None: + setattr(module, name, unwrap_custom_layer(submodule, override_type)) + else: + remove_custom_layers_from_model(submodule) diff --git a/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_all_custom_modules.py b/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_all_custom_modules.py index f8fccda2c11..90940948f9d 100644 --- a/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_all_custom_modules.py +++ b/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_all_custom_modules.py @@ -5,7 +5,10 @@ import torch from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.torch_module_autocast import ( - apply_custom_layers_to_model, + AUTOCAST_MODULE_TYPE_MAPPING, + AUTOCAST_MODULE_TYPE_MAPPING_INVERSE, + unwrap_custom_layer, + wrap_custom_layer, ) from tests.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.test_custom_invoke_linear_8_bit_lt import ( build_linear_8bit_lt_layer, @@ -98,15 +101,45 @@ def layer_to_device_via_state_dict(layer: torch.nn.Module, device: str): layer.load_state_dict(state_dict, assign=True) +def wrap_single_custom_layer(layer: torch.nn.Module): + custom_layer_type = AUTOCAST_MODULE_TYPE_MAPPING[type(layer)] + return wrap_custom_layer(layer, custom_layer_type) + + +def unwrap_single_custom_layer(layer: torch.nn.Module): + orig_layer_type = AUTOCAST_MODULE_TYPE_MAPPING_INVERSE[type(layer)] + return unwrap_custom_layer(layer, orig_layer_type) + + def test_isinstance(layer_under_test: LayerUnderTest): """Test that isinstance() and type() behave as expected after wrapping a layer in a custom layer.""" 
orig_layer, _, _ = layer_under_test orig_type = type(orig_layer) - apply_custom_layers_to_model(orig_layer) + custom_layer = wrap_single_custom_layer(orig_layer) + + assert isinstance(custom_layer, orig_type) + assert type(custom_layer) is not orig_type + + +def test_wrap_and_unwrap(layer_under_test: LayerUnderTest): + """Test that wrapping and unwrapping a layer behaves as expected.""" + orig_layer, _, _ = layer_under_test + orig_type = type(orig_layer) + + # Wrap the original layer and assert that attributes of the custom layer can be accessed. + custom_layer = wrap_single_custom_layer(orig_layer) + custom_layer.set_device_autocasting_enabled(True) + assert custom_layer._device_autocasting_enabled - assert isinstance(orig_layer, orig_type) - assert type(orig_layer) is not orig_type + # Unwrap the custom layer. + # Assert that the methods of the wrapped layer are no longer accessible. + unwrapped_layer = unwrap_single_custom_layer(custom_layer) + with pytest.raises(AttributeError): + _ = unwrapped_layer.set_device_autocasting_enabled(True) + # For now, we have chosen to allow attributes to persist. We may revisit this in the future. + assert unwrapped_layer._device_autocasting_enabled + assert type(unwrapped_layer) is orig_type @parameterize_all_devices @@ -120,7 +153,7 @@ def test_state_dict(device: str, layer_under_test: LayerUnderTest): # Wrap the original layer. custom_layer = copy.deepcopy(orig_layer) - apply_custom_layers_to_model(custom_layer) + custom_layer = wrap_single_custom_layer(custom_layer) custom_state_dict = custom_layer.state_dict() @@ -140,7 +173,7 @@ def test_load_state_dict(device: str, layer_under_test: LayerUnderTest): orig_layer.to(device) custom_layer = copy.deepcopy(orig_layer) - apply_custom_layers_to_model(custom_layer) + custom_layer = wrap_single_custom_layer(custom_layer) # Do a state dict roundtrip. 
orig_state_dict = orig_layer.state_dict() @@ -174,7 +207,7 @@ def test_inference_on_device(device: str, layer_under_test: LayerUnderTest): layer_to_device_via_state_dict(orig_layer, device) custom_layer = copy.deepcopy(orig_layer) - apply_custom_layers_to_model(custom_layer) + custom_layer = wrap_single_custom_layer(custom_layer) # Run inference with the original layer. x = layer_input.to(device) @@ -213,7 +246,7 @@ def test_inference_autocast_from_cpu_to_device(device: str, layer_under_test: La # Wrap the original layer. custom_layer = copy.deepcopy(orig_layer) - apply_custom_layers_to_model(custom_layer) + custom_layer = wrap_single_custom_layer(custom_layer) # Inference should still fail with autocasting disabled. custom_layer.set_device_autocasting_enabled(False) diff --git a/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_custom_invoke_linear_8_bit_lt.py b/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_custom_invoke_linear_8_bit_lt.py index 8a2a0e1b614..e23cb25eb02 100644 --- a/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_custom_invoke_linear_8_bit_lt.py +++ b/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_custom_invoke_linear_8_bit_lt.py @@ -1,6 +1,10 @@ import pytest import torch +from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.torch_module_autocast import ( + wrap_custom_layer, +) + if not torch.cuda.is_available(): pytest.skip("CUDA is not available", allow_module_level=True) else: @@ -44,8 +48,8 @@ def test_custom_invoke_linear_8bit_lt_all_weights_on_cuda(linear_8bit_lt_layer: y_quantized = linear_8bit_lt_layer(x) # Wrap the InvokeLinear8bitLt layer in a CustomInvokeLinear8bitLt layer, and run inference on it. 
- linear_8bit_lt_layer.__class__ = CustomInvokeLinear8bitLt - y_custom = linear_8bit_lt_layer(x) + custom_linear_8bit_lt_layer = wrap_custom_layer(linear_8bit_lt_layer, CustomInvokeLinear8bitLt) + y_custom = custom_linear_8bit_lt_layer(x) # Assert that the quantized and custom layers produce the same output. assert torch.allclose(y_quantized, y_custom, atol=1e-5) @@ -67,9 +71,9 @@ def test_custom_invoke_linear_8bit_lt_all_weights_on_cpu(linear_8bit_lt_layer: I linear_8bit_lt_layer(x) # Wrap the InvokeLinear8bitLt layer in a CustomInvokeLinear8bitLt layer, and run inference on it. - linear_8bit_lt_layer.__class__ = CustomInvokeLinear8bitLt - linear_8bit_lt_layer.set_device_autocasting_enabled(True) - y_custom = linear_8bit_lt_layer(x) + custom_linear_8bit_lt_layer = wrap_custom_layer(linear_8bit_lt_layer, CustomInvokeLinear8bitLt) + custom_linear_8bit_lt_layer.set_device_autocasting_enabled(True) + y_custom = custom_linear_8bit_lt_layer(x) # Assert that the quantized and custom layers produce the same output. 
assert torch.allclose(y_quantized, y_custom, atol=1e-5) diff --git a/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_custom_invoke_linear_nf4.py b/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_custom_invoke_linear_nf4.py index 154a75b8588..17854597ec8 100644 --- a/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_custom_invoke_linear_nf4.py +++ b/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_custom_invoke_linear_nf4.py @@ -4,6 +4,9 @@ from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_invoke_linear_nf4 import ( CustomInvokeLinearNF4, ) +from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.torch_module_autocast import ( + wrap_custom_layer, +) from invokeai.backend.quantization.bnb_nf4 import InvokeLinearNF4 @@ -39,9 +42,9 @@ def test_custom_invoke_linear_nf4_all_weights_on_cuda(linear_nf4_layer: InvokeLi y_quantized = linear_nf4_layer(x) # Wrap the InvokeLinearNF4 layer in a CustomInvokeLinearNF4 layer, and run inference on it. - linear_nf4_layer.__class__ = CustomInvokeLinearNF4 - linear_nf4_layer.set_device_autocasting_enabled(True) - y_custom = linear_nf4_layer(x) + custom_linear_nf4_layer = wrap_custom_layer(linear_nf4_layer, CustomInvokeLinearNF4) + custom_linear_nf4_layer.set_device_autocasting_enabled(True) + y_custom = custom_linear_nf4_layer(x) # Assert that the quantized and custom layers produce the same output. assert torch.allclose(y_quantized, y_custom, atol=1e-5) @@ -66,18 +69,18 @@ def test_custom_invoke_linear_nf4_all_weights_on_cpu(linear_nf4_layer: InvokeLin linear_nf4_layer(x) # Wrap the InvokeLinearNF4 layer in a CustomInvokeLinearNF4 layer, and run inference on it. 
- linear_nf4_layer.__class__ = CustomInvokeLinearNF4 - linear_nf4_layer.set_device_autocasting_enabled(True) - y_custom = linear_nf4_layer(x) + custom_linear_nf4_layer = wrap_custom_layer(linear_nf4_layer, CustomInvokeLinearNF4) + custom_linear_nf4_layer.set_device_autocasting_enabled(True) + y_custom = custom_linear_nf4_layer(x) # Assert that the state dict (and the tensors that it references) are still on the CPU. assert all(v.device == torch.device("cpu") for v in state_dict.values()) # Assert that the weight, bias, and quant_state are all on the CPU. - assert linear_nf4_layer.weight.device == torch.device("cpu") - assert linear_nf4_layer.bias.device == torch.device("cpu") - assert linear_nf4_layer.weight.quant_state.absmax.device == torch.device("cpu") - assert linear_nf4_layer.weight.quant_state.code.device == torch.device("cpu") + assert custom_linear_nf4_layer.weight.device == torch.device("cpu") + assert custom_linear_nf4_layer.bias.device == torch.device("cpu") + assert custom_linear_nf4_layer.weight.quant_state.absmax.device == torch.device("cpu") + assert custom_linear_nf4_layer.weight.quant_state.code.device == torch.device("cpu") # Assert that the quantized and custom layers produce the same output. assert torch.allclose(y_quantized, y_custom, atol=1e-5) From e24e386a27ad1a964695fe6e6cb9df90666b0116 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Fri, 27 Dec 2024 18:57:13 +0000 Subject: [PATCH 15/31] Add support for patches to CustomModuleMixin and add a single unit test (more to come). 
--- .../custom_modules/custom_linear.py | 13 +++- .../custom_modules/custom_module_mixin.py | 33 ++++++++++ .../custom_modules/utils.py | 30 +++++++++ .../custom_modules/test_all_custom_modules.py | 61 +++++++++++++++++++ 4 files changed, 136 insertions(+), 1 deletion(-) create mode 100644 invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/utils.py diff --git a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_linear.py b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_linear.py index c1087c49079..58027b29512 100644 --- a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_linear.py +++ b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_linear.py @@ -4,16 +4,27 @@ from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import ( CustomModuleMixin, ) +from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.utils import ( + add_nullable_tensors, +) class CustomLinear(torch.nn.Linear, CustomModuleMixin): + def _autocast_forward_with_patches(self, input: torch.Tensor) -> torch.Tensor: + aggregated_param_residuals = self._aggregate_patch_parameters(self._patches_and_weights) + weight = add_nullable_tensors(self.weight, aggregated_param_residuals["weight"]) + bias = add_nullable_tensors(self.bias, aggregated_param_residuals.get("bias", None)) + return torch.nn.functional.linear(input, weight, bias) + def _autocast_forward(self, input: torch.Tensor) -> torch.Tensor: weight = cast_to_device(self.weight, input.device) bias = cast_to_device(self.bias, input.device) return torch.nn.functional.linear(input, weight, bias) def forward(self, input: torch.Tensor) -> torch.Tensor: - if self._device_autocasting_enabled: + if len(self._patches_and_weights) > 0: + return 
self._autocast_forward_with_patches(input) + elif self._device_autocasting_enabled: return self._autocast_forward(input) else: return super().forward(input) diff --git a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_module_mixin.py b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_module_mixin.py index 75663317107..03c6d81e2a8 100644 --- a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_module_mixin.py +++ b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_module_mixin.py @@ -1,8 +1,14 @@ +import torch + +from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch + + class CustomModuleMixin: """A mixin class for custom modules that enables device autocasting of module parameters.""" def __init__(self): self._device_autocasting_enabled = False + self._patches_and_weights: list[tuple[BaseLayerPatch, float]] = [] def set_device_autocasting_enabled(self, enabled: bool): """Pass True to enable autocasting of module parameters to the same device as the input tensor. Pass False to @@ -10,3 +16,30 @@ def set_device_autocasting_enabled(self, enabled: bool): not needed. """ self._device_autocasting_enabled = enabled + + def add_patch(self, patch: BaseLayerPatch, patch_weight: float): + """Add a patch to the sidecar wrapper.""" + self._patches_and_weights.append((patch, patch_weight)) + + def clear_patches(self): + """Clear all patches from the sidecar wrapper.""" + self._patches_and_weights = [] + + def _aggregate_patch_parameters( + self, patches_and_weights: list[tuple[BaseLayerPatch, float]] + ) -> dict[str, torch.Tensor]: + """Helper function that aggregates the parameters from all patches into a single dict.""" + params: dict[str, torch.Tensor] = {} + + for patch, patch_weight in patches_and_weights: + # TODO(ryand): self._orig_module could be quantized. 
Depending on what the patch is doing with the original + # module, this might fail or return incorrect results. + layer_params = patch.get_parameters(self, weight=patch_weight) + + for param_name, param_weight in layer_params.items(): + if param_name not in params: + params[param_name] = param_weight + else: + params[param_name] += param_weight + + return params diff --git a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/utils.py b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/utils.py new file mode 100644 index 00000000000..60294d9e0c3 --- /dev/null +++ b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/utils.py @@ -0,0 +1,30 @@ +from typing import overload + +import torch + + +@overload +def add_nullable_tensors(a: None, b: None) -> None: ... + + +@overload +def add_nullable_tensors(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: ... + + +@overload +def add_nullable_tensors(a: torch.Tensor, b: None) -> torch.Tensor: ... + + +@overload +def add_nullable_tensors(a: None, b: torch.Tensor) -> torch.Tensor: ... 
+ + +def add_nullable_tensors(a: torch.Tensor | None, b: torch.Tensor | None) -> torch.Tensor | None: + if a is None and b is None: + return None + elif a is None: + return b + elif b is None: + return a + else: + return a + b diff --git a/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_all_custom_modules.py b/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_all_custom_modules.py index 90940948f9d..6911cfbb553 100644 --- a/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_all_custom_modules.py +++ b/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_all_custom_modules.py @@ -10,6 +10,8 @@ unwrap_custom_layer, wrap_custom_layer, ) +from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch +from invokeai.backend.patches.layers.lora_layer import LoRALayer from tests.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.test_custom_invoke_linear_8_bit_lt import ( build_linear_8bit_lt_layer, ) @@ -259,3 +261,62 @@ def test_inference_autocast_from_cpu_to_device(device: str, layer_under_test: La assert custom_output.device.type == device assert torch.allclose(orig_output, custom_output) + + +LayerAndPatchUnderTest = tuple[torch.nn.Module, BaseLayerPatch, torch.Tensor, bool] + + +@pytest.fixture( + params=[ + "linear_lora", + ] +) +def layer_and_patch_under_test(request: pytest.FixtureRequest) -> LayerAndPatchUnderTest: + """A fixture that returns a tuple of (layer, input, supports_cpu_inference) for the layer under test.""" + layer_type = request.param + if layer_type == "linear_lora": + # Create a linear layer. + in_features = 10 + out_features = 20 + layer = torch.nn.Linear(in_features, out_features) + + # Create a LoRA layer. 
+ rank = 4 + down = torch.randn(rank, in_features) + up = torch.randn(out_features, rank) + bias = torch.randn(out_features) + lora_layer = LoRALayer(up=up, mid=None, down=down, alpha=1.0, bias=bias) + + input = torch.randn(1, in_features) + return (layer, lora_layer, input, True) + else: + raise ValueError(f"Unsupported layer_type: {layer_type}") + + +@parameterize_all_devices +def test_sidecar_patches(device: str, layer_and_patch_under_test: LayerAndPatchUnderTest): + layer, patch, input, supports_cpu_inference = layer_and_patch_under_test + + if device == "cpu" and not supports_cpu_inference: + pytest.skip("Layer does not support CPU inference.") + + # Move the layer, patch, and input to the device. + layer_to_device_via_state_dict(layer, device) + patch.to(torch.device(device)) + input = input.to(torch.device(device)) + + # Patch the LoRA layer into the linear layer. + weight = 0.7 + layer_patched = copy.deepcopy(layer) + parameters = patch.get_parameters(layer_patched, weight=weight) + for param_name, param_weight in parameters.items(): + getattr(layer_patched, param_name).data += param_weight + + # Wrap the original layer in a custom layer and add the patch to it as a sidecar. + custom_layer = wrap_single_custom_layer(layer) + custom_layer.add_patch(patch, weight) + + # Run inference with the original layer and the patched layer and assert they are equal. + output_patched = layer_patched(input) + output_custom = custom_layer(input) + assert torch.allclose(output_patched, output_custom) From 5ee7405f97805c9a5910e2616a37210127b8545e Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Fri, 27 Dec 2024 19:47:21 +0000 Subject: [PATCH 16/31] Add more unit tests for custom module LoRA patching: multiple LoRAs and ConcatenatedLoRALayers. 
--- .../custom_modules/test_all_custom_modules.py | 83 +++++++++++++++---- 1 file changed, 67 insertions(+), 16 deletions(-) diff --git a/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_all_custom_modules.py b/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_all_custom_modules.py index 6911cfbb553..2668ca61a44 100644 --- a/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_all_custom_modules.py +++ b/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_all_custom_modules.py @@ -11,6 +11,7 @@ wrap_custom_layer, ) from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch +from invokeai.backend.patches.layers.concatenated_lora_layer import ConcatenatedLoRALayer from invokeai.backend.patches.layers.lora_layer import LoRALayer from tests.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.test_custom_invoke_linear_8_bit_lt import ( build_linear_8bit_lt_layer, @@ -263,18 +264,22 @@ def test_inference_autocast_from_cpu_to_device(device: str, layer_under_test: La assert torch.allclose(orig_output, custom_output) -LayerAndPatchUnderTest = tuple[torch.nn.Module, BaseLayerPatch, torch.Tensor, bool] +LayerAndPatchUnderTest = tuple[torch.nn.Module, list[tuple[BaseLayerPatch, float]], torch.Tensor, bool] @pytest.fixture( params=[ - "linear_lora", + "linear_single_lora", + "linear_multiple_loras", + "linear_concatenated_lora", ] ) def layer_and_patch_under_test(request: pytest.FixtureRequest) -> LayerAndPatchUnderTest: """A fixture that returns a tuple of (layer, input, supports_cpu_inference) for the layer under test.""" layer_type = request.param - if layer_type == "linear_lora": + torch.manual_seed(0) + + if layer_type == "linear_single_lora": # Create a linear layer. 
in_features = 10 out_features = 20 @@ -282,39 +287,85 @@ def layer_and_patch_under_test(request: pytest.FixtureRequest) -> LayerAndPatchU # Create a LoRA layer. rank = 4 - down = torch.randn(rank, in_features) - up = torch.randn(out_features, rank) - bias = torch.randn(out_features) - lora_layer = LoRALayer(up=up, mid=None, down=down, alpha=1.0, bias=bias) + lora_layer = LoRALayer( + up=torch.randn(out_features, rank), + mid=None, + down=torch.randn(rank, in_features), + alpha=1.0, + bias=torch.randn(out_features), + ) + input = torch.randn(1, in_features) + return (layer, [(lora_layer, 0.7)], input, True) + elif layer_type == "linear_multiple_loras": + # Create a linear layer. + rank = 4 + in_features = 10 + out_features = 20 + layer = torch.nn.Linear(in_features, out_features) + + lora_layer = LoRALayer( + up=torch.randn(out_features, rank), + mid=None, + down=torch.randn(rank, in_features), + alpha=1.0, + bias=torch.randn(out_features), + ) + lora_layer_2 = LoRALayer( + up=torch.randn(out_features, rank), + mid=None, + down=torch.randn(rank, in_features), + alpha=1.0, + bias=torch.randn(out_features), + ) + + input = torch.randn(1, in_features) + return (layer, [(lora_layer, 1.0), (lora_layer_2, 0.5)], input, True) + elif layer_type == "linear_concatenated_lora": + # Create a linear layer. + in_features = 5 + sub_layer_out_features = [5, 10, 15] + layer = torch.nn.Linear(in_features, sum(sub_layer_out_features)) + + # Create a ConcatenatedLoRA layer. 
+ rank = 4 + sub_layers: list[LoRALayer] = [] + for out_features in sub_layer_out_features: + down = torch.randn(rank, in_features) + up = torch.randn(out_features, rank) + bias = torch.randn(out_features) + sub_layers.append(LoRALayer(up=up, mid=None, down=down, alpha=1.0, bias=bias)) + concatenated_lora_layer = ConcatenatedLoRALayer(sub_layers, concat_axis=0) input = torch.randn(1, in_features) - return (layer, lora_layer, input, True) + return (layer, [(concatenated_lora_layer, 0.7)], input, True) else: raise ValueError(f"Unsupported layer_type: {layer_type}") @parameterize_all_devices def test_sidecar_patches(device: str, layer_and_patch_under_test: LayerAndPatchUnderTest): - layer, patch, input, supports_cpu_inference = layer_and_patch_under_test + layer, patches, input, supports_cpu_inference = layer_and_patch_under_test if device == "cpu" and not supports_cpu_inference: pytest.skip("Layer does not support CPU inference.") - # Move the layer, patch, and input to the device. + # Move the layer and input to the device. layer_to_device_via_state_dict(layer, device) - patch.to(torch.device(device)) input = input.to(torch.device(device)) # Patch the LoRA layer into the linear layer. - weight = 0.7 layer_patched = copy.deepcopy(layer) - parameters = patch.get_parameters(layer_patched, weight=weight) - for param_name, param_weight in parameters.items(): - getattr(layer_patched, param_name).data += param_weight + for patch, weight in patches: + patch.to(torch.device(device)) + parameters = patch.get_parameters(layer_patched, weight=weight) + for param_name, param_weight in parameters.items(): + module_param = getattr(layer_patched, param_name) + module_param.data += param_weight # Wrap the original layer in a custom layer and add the patch to it as a sidecar. 
custom_layer = wrap_single_custom_layer(layer) - custom_layer.add_patch(patch, weight) + for patch, weight in patches: + custom_layer.add_patch(patch, weight) # Run inference with the original layer and the patched layer and assert they are equal. output_patched = layer_patched(input) From ef970a1cdc6997cff10b396cd22d3817adfee601 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Fri, 27 Dec 2024 21:00:47 +0000 Subject: [PATCH 17/31] Add support for FluxControlLoRALayer in CustomLinear layers and add a unit test for it. --- .../custom_modules/custom_linear.py | 77 +++++++++++++++++-- .../custom_modules/test_all_custom_modules.py | 38 +++++++-- 2 files changed, 102 insertions(+), 13 deletions(-) diff --git a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_linear.py b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_linear.py index 58027b29512..e8335911092 100644 --- a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_linear.py +++ b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_linear.py @@ -4,17 +4,80 @@ from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import ( CustomModuleMixin, ) -from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.utils import ( - add_nullable_tensors, -) +from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch +from invokeai.backend.patches.layers.concatenated_lora_layer import ConcatenatedLoRALayer +from invokeai.backend.patches.layers.flux_control_lora_layer import FluxControlLoRALayer +from invokeai.backend.patches.layers.lora_layer import LoRALayer + + +def linear_lora_forward(input: torch.Tensor, lora_layer: LoRALayer, lora_weight: float) -> torch.Tensor: + """An optimized implementation of the residual calculation for a sidecar linear LoRALayer.""" + x = 
torch.nn.functional.linear(input, lora_layer.down) + if lora_layer.mid is not None: + x = torch.nn.functional.linear(x, lora_layer.mid) + x = torch.nn.functional.linear(x, lora_layer.up, bias=lora_layer.bias) + x *= lora_weight * lora_layer.scale() + return x + + +def concatenated_lora_forward( + input: torch.Tensor, concatenated_lora_layer: ConcatenatedLoRALayer, lora_weight: float +) -> torch.Tensor: + """An optimized implementation of the residual calculation for a sidecar ConcatenatedLoRALayer.""" + x_chunks: list[torch.Tensor] = [] + for lora_layer in concatenated_lora_layer.lora_layers: + x_chunk = torch.nn.functional.linear(input, lora_layer.down) + if lora_layer.mid is not None: + x_chunk = torch.nn.functional.linear(x_chunk, lora_layer.mid) + x_chunk = torch.nn.functional.linear(x_chunk, lora_layer.up, bias=lora_layer.bias) + x_chunk *= lora_weight * lora_layer.scale() + x_chunks.append(x_chunk) + + # TODO(ryand): Generalize to support concat_axis != 0. + assert concatenated_lora_layer.concat_axis == 0 + x = torch.cat(x_chunks, dim=-1) + return x + + +def autocast_linear_forward_sidecar_patches( + orig_module: torch.nn.Linear, input: torch.Tensor, patches_and_weights: list[tuple[BaseLayerPatch, float]] +) -> torch.Tensor: + """A function that runs a linear layer (quantized or non-quantized) with sidecar patches for a linear layer. + Compatible with both quantized and non-quantized Linear layers. + """ + # First, apply the original linear layer. + # NOTE: We slice the input to match the original weight shape in order to work with FluxControlLoRAs, which + # change the linear layer's in_features. + orig_input = input + input = orig_input[..., : orig_module.in_features] + output = orig_module._autocast_forward(input) + + # Then, apply layers for which we have optimized implementations. 
+ unprocessed_patches_and_weights: list[tuple[BaseLayerPatch, float]] = [] + for patch, patch_weight in patches_and_weights: + if isinstance(patch, FluxControlLoRALayer): + # Note that we use the original input here, not the sliced input. + output += linear_lora_forward(orig_input, patch, patch_weight) + elif isinstance(patch, LoRALayer): + output += linear_lora_forward(input, patch, patch_weight) + elif isinstance(patch, ConcatenatedLoRALayer): + output += concatenated_lora_forward(input, patch, patch_weight) + else: + unprocessed_patches_and_weights.append((patch, patch_weight)) + + # Finally, apply any remaining patches. + if len(unprocessed_patches_and_weights) > 0: + aggregated_param_residuals = orig_module._aggregate_patch_parameters(unprocessed_patches_and_weights) + output += torch.nn.functional.linear( + input, aggregated_param_residuals["weight"], aggregated_param_residuals.get("bias", None) + ) + + return output class CustomLinear(torch.nn.Linear, CustomModuleMixin): def _autocast_forward_with_patches(self, input: torch.Tensor) -> torch.Tensor: - aggregated_param_residuals = self._aggregate_patch_parameters(self._patches_and_weights) - weight = add_nullable_tensors(self.weight, aggregated_param_residuals["weight"]) - bias = add_nullable_tensors(self.bias, aggregated_param_residuals.get("bias", None)) - return torch.nn.functional.linear(input, weight, bias) + return autocast_linear_forward_sidecar_patches(self, input, self._patches_and_weights) def _autocast_forward(self, input: torch.Tensor) -> torch.Tensor: weight = cast_to_device(self.weight, input.device) diff --git a/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_all_custom_modules.py b/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_all_custom_modules.py index 2668ca61a44..b01a744be65 100644 --- a/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_all_custom_modules.py +++ 
b/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_all_custom_modules.py @@ -10,9 +10,12 @@ unwrap_custom_layer, wrap_custom_layer, ) +from invokeai.backend.patches.layer_patcher import LayerPatcher from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch from invokeai.backend.patches.layers.concatenated_lora_layer import ConcatenatedLoRALayer +from invokeai.backend.patches.layers.flux_control_lora_layer import FluxControlLoRALayer from invokeai.backend.patches.layers.lora_layer import LoRALayer +from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage from tests.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.test_custom_invoke_linear_8_bit_lt import ( build_linear_8bit_lt_layer, ) @@ -272,6 +275,7 @@ def test_inference_autocast_from_cpu_to_device(device: str, layer_under_test: La "linear_single_lora", "linear_multiple_loras", "linear_concatenated_lora", + "linear_flux_control_lora", ] ) def layer_and_patch_under_test(request: pytest.FixtureRequest) -> LayerAndPatchUnderTest: @@ -338,6 +342,25 @@ def layer_and_patch_under_test(request: pytest.FixtureRequest) -> LayerAndPatchU input = torch.randn(1, in_features) return (layer, [(concatenated_lora_layer, 0.7)], input, True) + elif layer_type == "linear_flux_control_lora": + # Create a linear layer. + orig_in_features = 10 + out_features = 40 + layer = torch.nn.Linear(orig_in_features, out_features) + + # Create a FluxControlLoRALayer. 
+ patched_in_features = 20 + rank = 4 + lora_layer = FluxControlLoRALayer( + up=torch.randn(out_features, rank), + mid=None, + down=torch.randn(rank, patched_in_features), + alpha=1.0, + bias=torch.randn(out_features), + ) + + input = torch.randn(1, patched_in_features) + return (layer, [(lora_layer, 0.7)], input, True) else: raise ValueError(f"Unsupported layer_type: {layer_type}") @@ -356,18 +379,21 @@ def test_sidecar_patches(device: str, layer_and_patch_under_test: LayerAndPatchU # Patch the LoRA layer into the linear layer. layer_patched = copy.deepcopy(layer) for patch, weight in patches: - patch.to(torch.device(device)) - parameters = patch.get_parameters(layer_patched, weight=weight) - for param_name, param_weight in parameters.items(): - module_param = getattr(layer_patched, param_name) - module_param.data += param_weight + LayerPatcher._apply_model_layer_patch( + module_to_patch=layer_patched, + module_to_patch_key="", + patch=patch, + patch_weight=weight, + original_weights=OriginalWeightsStorage(), + ) # Wrap the original layer in a custom layer and add the patch to it as a sidecar. custom_layer = wrap_single_custom_layer(layer) for patch, weight in patches: + patch.to(torch.device(device)) custom_layer.add_patch(patch, weight) # Run inference with the original layer and the patched layer and assert they are equal. output_patched = layer_patched(input) output_custom = custom_layer(input) - assert torch.allclose(output_patched, output_custom) + assert torch.allclose(output_patched, output_custom, atol=1e-6) From f2981979f90070deff013a713db46a8b9c560790 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Fri, 27 Dec 2024 22:00:22 +0000 Subject: [PATCH 18/31] Get custom layer patches working with all quantized linear layer types. 
--- .../custom_invoke_linear_8_bit_lt.py | 10 +- .../custom_invoke_linear_nf4.py | 10 +- .../custom_modules/test_all_custom_modules.py | 142 ++++++++++++------ .../test_custom_invoke_linear_8_bit_lt.py | 9 +- .../test_custom_invoke_linear_nf4.py | 8 +- 5 files changed, 121 insertions(+), 58 deletions(-) diff --git a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_invoke_linear_8_bit_lt.py b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_invoke_linear_8_bit_lt.py index aa6acd31c51..2b9d8e9e98e 100644 --- a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_invoke_linear_8_bit_lt.py +++ b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_invoke_linear_8_bit_lt.py @@ -2,6 +2,9 @@ import torch from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device +from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_linear import ( + autocast_linear_forward_sidecar_patches, +) from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import ( CustomModuleMixin, ) @@ -9,6 +12,9 @@ class CustomInvokeLinear8bitLt(InvokeLinear8bitLt, CustomModuleMixin): + def _autocast_forward_with_patches(self, x: torch.Tensor) -> torch.Tensor: + return autocast_linear_forward_sidecar_patches(self, x, self._patches_and_weights) + def _autocast_forward(self, x: torch.Tensor) -> torch.Tensor: matmul_state = bnb.MatmulLtState() matmul_state.threshold = self.state.threshold @@ -30,7 +36,9 @@ def _autocast_forward(self, x: torch.Tensor) -> torch.Tensor: return bnb.matmul(x, self.weight, bias=cast_to_device(self.bias, x.device), state=matmul_state) def forward(self, x: torch.Tensor) -> torch.Tensor: - if self._device_autocasting_enabled: + if len(self._patches_and_weights) > 0: + return 
self._autocast_forward_with_patches(x) + elif self._device_autocasting_enabled: return self._autocast_forward(x) else: return super().forward(x) diff --git a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_invoke_linear_nf4.py b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_invoke_linear_nf4.py index 60e987b3f36..89284d5509a 100644 --- a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_invoke_linear_nf4.py +++ b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_invoke_linear_nf4.py @@ -4,6 +4,9 @@ import torch from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device +from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_linear import ( + autocast_linear_forward_sidecar_patches, +) from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import ( CustomModuleMixin, ) @@ -11,6 +14,9 @@ class CustomInvokeLinearNF4(InvokeLinearNF4, CustomModuleMixin): + def _autocast_forward_with_patches(self, x: torch.Tensor) -> torch.Tensor: + return autocast_linear_forward_sidecar_patches(self, x, self._patches_and_weights) + def _autocast_forward(self, x: torch.Tensor) -> torch.Tensor: bnb.nn.modules.fix_4bit_weight_quant_state_from_module(self) @@ -48,7 +54,9 @@ def _autocast_forward(self, x: torch.Tensor) -> torch.Tensor: return bnb.matmul_4bit(x, weight.t(), bias=bias, quant_state=weight.quant_state).to(inp_dtype) def forward(self, x: torch.Tensor) -> torch.Tensor: - if self._device_autocasting_enabled: + if len(self._patches_and_weights) > 0: + return self._autocast_forward_with_patches(x) + elif self._device_autocasting_enabled: return self._autocast_forward(x) else: return super().forward(x) diff --git 
a/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_all_custom_modules.py b/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_all_custom_modules.py index b01a744be65..666ea1d8cf2 100644 --- a/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_all_custom_modules.py +++ b/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_all_custom_modules.py @@ -25,13 +25,15 @@ from tests.backend.quantization.gguf.test_ggml_tensor import quantize_tensor -def build_linear_layer_with_ggml_quantized_tensor(): - layer = torch.nn.Linear(32, 64) - ggml_quantized_weight = quantize_tensor(layer.weight, gguf.GGMLQuantizationType.Q8_0) - layer.weight = torch.nn.Parameter(ggml_quantized_weight) - ggml_quantized_bias = quantize_tensor(layer.bias, gguf.GGMLQuantizationType.Q8_0) - layer.bias = torch.nn.Parameter(ggml_quantized_bias) - return layer +def build_linear_layer_with_ggml_quantized_tensor(orig_layer: torch.nn.Linear | None = None): + if orig_layer is None: + orig_layer = torch.nn.Linear(32, 64) + + ggml_quantized_weight = quantize_tensor(orig_layer.weight, gguf.GGMLQuantizationType.Q8_0) + orig_layer.weight = torch.nn.Parameter(ggml_quantized_weight) + ggml_quantized_bias = quantize_tensor(orig_layer.bias, gguf.GGMLQuantizationType.Q8_0) + orig_layer.bias = torch.nn.Parameter(ggml_quantized_bias) + return orig_layer parameterize_all_devices = pytest.mark.parametrize( @@ -267,30 +269,29 @@ def test_inference_autocast_from_cpu_to_device(device: str, layer_under_test: La assert torch.allclose(orig_output, custom_output) -LayerAndPatchUnderTest = tuple[torch.nn.Module, list[tuple[BaseLayerPatch, float]], torch.Tensor, bool] +PatchUnderTest = tuple[list[tuple[BaseLayerPatch, float]], torch.Tensor] @pytest.fixture( params=[ - "linear_single_lora", - "linear_multiple_loras", - "linear_concatenated_lora", - "linear_flux_control_lora", + 
"single_lora", + "multiple_loras", + "concatenated_lora", + "flux_control_lora", ] ) -def layer_and_patch_under_test(request: pytest.FixtureRequest) -> LayerAndPatchUnderTest: - """A fixture that returns a tuple of (layer, input, supports_cpu_inference) for the layer under test.""" +def patch_under_test(request: pytest.FixtureRequest) -> PatchUnderTest: + """A fixture that returns a tuple of (patches, input) for the patch under test.""" layer_type = request.param torch.manual_seed(0) - if layer_type == "linear_single_lora": - # Create a linear layer. - in_features = 10 - out_features = 20 - layer = torch.nn.Linear(in_features, out_features) + # The assumed in/out features of the base linear layer. + in_features = 32 + out_features = 64 - # Create a LoRA layer. - rank = 4 + rank = 4 + + if layer_type == "single_lora": lora_layer = LoRALayer( up=torch.randn(out_features, rank), mid=None, @@ -299,14 +300,8 @@ def layer_and_patch_under_test(request: pytest.FixtureRequest) -> LayerAndPatchU bias=torch.randn(out_features), ) input = torch.randn(1, in_features) - return (layer, [(lora_layer, 0.7)], input, True) - elif layer_type == "linear_multiple_loras": - # Create a linear layer. - rank = 4 - in_features = 10 - out_features = 20 - layer = torch.nn.Linear(in_features, out_features) - + return ([(lora_layer, 0.7)], input) + elif layer_type == "multiple_loras": lora_layer = LoRALayer( up=torch.randn(out_features, rank), mid=None, @@ -323,15 +318,11 @@ def layer_and_patch_under_test(request: pytest.FixtureRequest) -> LayerAndPatchU ) input = torch.randn(1, in_features) - return (layer, [(lora_layer, 1.0), (lora_layer_2, 0.5)], input, True) - elif layer_type == "linear_concatenated_lora": - # Create a linear layer. 
- in_features = 5 - sub_layer_out_features = [5, 10, 15] - layer = torch.nn.Linear(in_features, sum(sub_layer_out_features)) + return ([(lora_layer, 1.0), (lora_layer_2, 0.5)], input) + elif layer_type == "concatenated_lora": + sub_layer_out_features = [16, 16, 32] # Create a ConcatenatedLoRA layer. - rank = 4 sub_layers: list[LoRALayer] = [] for out_features in sub_layer_out_features: down = torch.randn(rank, in_features) @@ -341,16 +332,10 @@ def layer_and_patch_under_test(request: pytest.FixtureRequest) -> LayerAndPatchU concatenated_lora_layer = ConcatenatedLoRALayer(sub_layers, concat_axis=0) input = torch.randn(1, in_features) - return (layer, [(concatenated_lora_layer, 0.7)], input, True) - elif layer_type == "linear_flux_control_lora": - # Create a linear layer. - orig_in_features = 10 - out_features = 40 - layer = torch.nn.Linear(orig_in_features, out_features) - + return ([(concatenated_lora_layer, 0.7)], input) + elif layer_type == "flux_control_lora": # Create a FluxControlLoRALayer. - patched_in_features = 20 - rank = 4 + patched_in_features = 40 lora_layer = FluxControlLoRALayer( up=torch.randn(out_features, rank), mid=None, @@ -360,17 +345,17 @@ def layer_and_patch_under_test(request: pytest.FixtureRequest) -> LayerAndPatchU ) input = torch.randn(1, patched_in_features) - return (layer, [(lora_layer, 0.7)], input, True) + return ([(lora_layer, 0.7)], input) else: raise ValueError(f"Unsupported layer_type: {layer_type}") @parameterize_all_devices -def test_sidecar_patches(device: str, layer_and_patch_under_test: LayerAndPatchUnderTest): - layer, patches, input, supports_cpu_inference = layer_and_patch_under_test +def test_linear_sidecar_patches(device: str, layer_type: str, patch_under_test: PatchUnderTest): + patches, input = patch_under_test - if device == "cpu" and not supports_cpu_inference: - pytest.skip("Layer does not support CPU inference.") + # Build the base layer under test. 
+ layer = torch.nn.Linear(32, 64) # Move the layer and input to the device. layer_to_device_via_state_dict(layer, device) @@ -397,3 +382,60 @@ def test_sidecar_patches(device: str, layer_and_patch_under_test: LayerAndPatchU output_patched = layer_patched(input) output_custom = custom_layer(input) assert torch.allclose(output_patched, output_custom, atol=1e-6) + + +@pytest.fixture( + params=[ + "linear_ggml_quantized", + "invoke_linear_8_bit_lt", + "invoke_linear_nf4", + ] +) +def quantized_linear_layer_under_test(request: pytest.FixtureRequest): + in_features = 32 + out_features = 64 + torch.manual_seed(0) + layer_type = request.param + orig_layer = torch.nn.Linear(in_features, out_features) + if layer_type == "linear_ggml_quantized": + return orig_layer, build_linear_layer_with_ggml_quantized_tensor(orig_layer) + elif layer_type == "invoke_linear_8_bit_lt": + return orig_layer, build_linear_8bit_lt_layer(orig_layer) + elif layer_type == "invoke_linear_nf4": + return orig_layer, build_linear_nf4_layer(orig_layer) + else: + raise ValueError(f"Unsupported layer_type: {layer_type}") + + +@parameterize_cuda_and_mps +def test_quantized_linear_sidecar_patches( + device: str, + quantized_linear_layer_under_test: tuple[torch.nn.Module, torch.nn.Module], + patch_under_test: PatchUnderTest, +): + """Test that patches can be applied to quantized linear layers and that the output is the same as when the patch is + applied to a non-quantized linear layer. + """ + patches, input = patch_under_test + + linear_layer, quantized_linear_layer = quantized_linear_layer_under_test + + # Move everything to the device. + layer_to_device_via_state_dict(linear_layer, device) + layer_to_device_via_state_dict(quantized_linear_layer, device) + input = input.to(torch.device(device)) + + # Wrap both layers in custom layers. 
+ linear_layer_custom = wrap_single_custom_layer(linear_layer) + quantized_linear_layer_custom = wrap_single_custom_layer(quantized_linear_layer) + + # Apply the patches to the custom layers. + for patch, weight in patches: + patch.to(torch.device(device)) + linear_layer_custom.add_patch(patch, weight) + quantized_linear_layer_custom.add_patch(patch, weight) + + # Run inference with the original layer and the patched layer and assert they are equal. + output_linear_patched = linear_layer_custom(input) + output_quantized_patched = quantized_linear_layer_custom(input) + assert torch.allclose(output_linear_patched, output_quantized_patched, rtol=0.2, atol=0.2) diff --git a/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_custom_invoke_linear_8_bit_lt.py b/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_custom_invoke_linear_8_bit_lt.py index e23cb25eb02..9a225267fbf 100644 --- a/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_custom_invoke_linear_8_bit_lt.py +++ b/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_custom_invoke_linear_8_bit_lt.py @@ -14,17 +14,20 @@ from invokeai.backend.quantization.bnb_llm_int8 import InvokeLinear8bitLt -def build_linear_8bit_lt_layer(): +def build_linear_8bit_lt_layer(orig_layer: torch.nn.Linear | None = None): if not torch.cuda.is_available(): pytest.skip("CUDA is not available") torch.manual_seed(1) - orig_layer = torch.nn.Linear(32, 64) + if orig_layer is None: + orig_layer = torch.nn.Linear(32, 64) orig_layer_state_dict = orig_layer.state_dict() # Prepare a quantized InvokeLinear8bitLt layer. 
- quantized_layer = InvokeLinear8bitLt(input_features=32, output_features=64, has_fp16_weights=False) + quantized_layer = InvokeLinear8bitLt( + input_features=orig_layer.in_features, output_features=orig_layer.out_features, has_fp16_weights=False + ) quantized_layer.load_state_dict(orig_layer_state_dict) quantized_layer.to("cuda") diff --git a/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_custom_invoke_linear_nf4.py b/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_custom_invoke_linear_nf4.py index 17854597ec8..3559ddea6cb 100644 --- a/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_custom_invoke_linear_nf4.py +++ b/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_custom_invoke_linear_nf4.py @@ -10,17 +10,19 @@ from invokeai.backend.quantization.bnb_nf4 import InvokeLinearNF4 -def build_linear_nf4_layer(): +def build_linear_nf4_layer(orig_layer: torch.nn.Linear | None = None): if not torch.cuda.is_available(): pytest.skip("CUDA is not available") torch.manual_seed(1) - orig_layer = torch.nn.Linear(64, 16) + if orig_layer is None: + orig_layer = torch.nn.Linear(64, 16) + orig_layer_state_dict = orig_layer.state_dict() # Prepare a quantized InvokeLinearNF4 layer. - quantized_layer = InvokeLinearNF4(input_features=64, output_features=16) + quantized_layer = InvokeLinearNF4(input_features=orig_layer.in_features, output_features=orig_layer.out_features) quantized_layer.load_state_dict(orig_layer_state_dict) quantized_layer.to("cuda") From f692e217ea20ef7d2312c3bc92f0a0ed9f508a6e Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Fri, 27 Dec 2024 22:23:17 +0000 Subject: [PATCH 19/31] Add patch support to CustomConv1d and CustomConv2d (no unit tests yet). 
--- .../custom_modules/custom_conv1d.py | 13 ++++++++++++- .../custom_modules/custom_conv2d.py | 13 ++++++++++++- 2 files changed, 24 insertions(+), 2 deletions(-) diff --git a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_conv1d.py b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_conv1d.py index e0e38edd53c..d86a721e5a9 100644 --- a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_conv1d.py +++ b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_conv1d.py @@ -4,16 +4,27 @@ from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import ( CustomModuleMixin, ) +from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.utils import ( + add_nullable_tensors, +) class CustomConv1d(torch.nn.Conv1d, CustomModuleMixin): + def _autocast_forward_with_patches(self, input: torch.Tensor) -> torch.Tensor: + aggregated_param_residuals = self._aggregate_patch_parameters(self._patches_and_weights) + weight = add_nullable_tensors(self.weight, aggregated_param_residuals["weight"]) + bias = add_nullable_tensors(self.bias, aggregated_param_residuals["bias"]) + return torch.nn.functional.conv1d(input, weight, bias) + def _autocast_forward(self, input: torch.Tensor) -> torch.Tensor: weight = cast_to_device(self.weight, input.device) bias = cast_to_device(self.bias, input.device) return self._conv_forward(input, weight, bias) def forward(self, input: torch.Tensor) -> torch.Tensor: - if self._device_autocasting_enabled: + if len(self._patches_and_weights) > 0: + return self._autocast_forward_with_patches(input) + elif self._device_autocasting_enabled: return self._autocast_forward(input) else: return super().forward(input) diff --git 
a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_conv2d.py b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_conv2d.py index e8981e34c58..6067cef594c 100644 --- a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_conv2d.py +++ b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_conv2d.py @@ -4,16 +4,27 @@ from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import ( CustomModuleMixin, ) +from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.utils import ( + add_nullable_tensors, +) class CustomConv2d(torch.nn.Conv2d, CustomModuleMixin): + def _autocast_forward_with_patches(self, input: torch.Tensor) -> torch.Tensor: + aggregated_param_residuals = self._aggregate_patch_parameters(self._patches_and_weights) + weight = add_nullable_tensors(self.weight, aggregated_param_residuals["weight"]) + bias = add_nullable_tensors(self.bias, aggregated_param_residuals["bias"]) + return torch.nn.functional.conv2d(input, weight, bias) + def _autocast_forward(self, input: torch.Tensor) -> torch.Tensor: weight = cast_to_device(self.weight, input.device) bias = cast_to_device(self.bias, input.device) return self._conv_forward(input, weight, bias) def forward(self, input: torch.Tensor) -> torch.Tensor: - if self._device_autocasting_enabled: + if len(self._patches_and_weights) > 0: + return self._autocast_forward_with_patches(input) + elif self._device_autocasting_enabled: return self._autocast_forward(input) else: return super().forward(input) From 93e76b61d6002ccf15f62e29f67b4d26ea3b3069 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Sat, 28 Dec 2024 20:33:38 +0000 Subject: [PATCH 20/31] Add CustomFluxRMSNorm layer. 
--- .../custom_modules/custom_flux_rms_norm.py | 34 +++++++++++++++++++ .../torch_module_autocast.py | 5 +++ .../custom_modules/test_all_custom_modules.py | 6 +++- 3 files changed, 44 insertions(+), 1 deletion(-) create mode 100644 invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_flux_rms_norm.py diff --git a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_flux_rms_norm.py b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_flux_rms_norm.py new file mode 100644 index 00000000000..ba894433c90 --- /dev/null +++ b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_flux_rms_norm.py @@ -0,0 +1,34 @@ +import torch + +from invokeai.backend.flux.modules.layers import RMSNorm +from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device +from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import ( + CustomModuleMixin, +) +from invokeai.backend.patches.layers.set_parameter_layer import SetParameterLayer + + +class CustomFluxRMSNorm(RMSNorm, CustomModuleMixin): + def _autocast_forward_with_patches(self, x: torch.Tensor) -> torch.Tensor: + # Currently, CustomFluxRMSNorm layers only support patching with a single SetParameterLayer. + assert len(self._patches_and_weights) == 1 + patch, _patch_weight = self._patches_and_weights[0] + assert isinstance(patch, SetParameterLayer) + assert patch.param_name == "scale" + + # Apply the patch. + # NOTE(ryand): Currently, we ignore the patch weight when running as a sidecar. It's not clear how this should + # be handled. 
+ return torch.nn.functional.rms_norm(x, patch.weight.shape, patch.weight, eps=1e-6) + + def _autocast_forward(self, x: torch.Tensor) -> torch.Tensor: + scale = cast_to_device(self.scale, x.device) + return torch.nn.functional.rms_norm(x, scale.shape, scale, eps=1e-6) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if len(self._patches_and_weights) > 0: + return self._autocast_forward_with_patches(x) + elif self._device_autocasting_enabled: + return self._autocast_forward(x) + else: + return super().forward(x) diff --git a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/torch_module_autocast.py b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/torch_module_autocast.py index b9fd58f464a..2d85e32370f 100644 --- a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/torch_module_autocast.py +++ b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/torch_module_autocast.py @@ -1,5 +1,6 @@ import torch +from invokeai.backend.flux.modules.layers import RMSNorm from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_conv1d import ( CustomConv1d, ) @@ -9,6 +10,9 @@ from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_embedding import ( CustomEmbedding, ) +from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_flux_rms_norm import ( + CustomFluxRMSNorm, +) from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_group_norm import ( CustomGroupNorm, ) @@ -25,6 +29,7 @@ torch.nn.Conv2d: CustomConv2d, torch.nn.GroupNorm: CustomGroupNorm, torch.nn.Embedding: CustomEmbedding, + RMSNorm: CustomFluxRMSNorm, } try: diff --git a/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_all_custom_modules.py 
b/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_all_custom_modules.py index 666ea1d8cf2..25a881952c5 100644 --- a/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_all_custom_modules.py +++ b/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_all_custom_modules.py @@ -4,6 +4,7 @@ import pytest import torch +from invokeai.backend.flux.modules.layers import RMSNorm from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.torch_module_autocast import ( AUTOCAST_MODULE_TYPE_MAPPING, AUTOCAST_MODULE_TYPE_MAPPING_INVERSE, @@ -68,6 +69,7 @@ def build_linear_layer_with_ggml_quantized_tensor(orig_layer: torch.nn.Linear | "conv2d", "group_norm", "embedding", + "flux_rms_norm", "linear_with_ggml_quantized_tensor", "invoke_linear_8_bit_lt", "invoke_linear_nf4", @@ -86,6 +88,8 @@ def layer_under_test(request: pytest.FixtureRequest) -> LayerUnderTest: return (torch.nn.GroupNorm(2, 8), torch.randn(1, 8, 5), True) elif layer_type == "embedding": return (torch.nn.Embedding(4, 8), torch.tensor([0, 1], dtype=torch.long), True) + elif layer_type == "flux_rms_norm": + return (RMSNorm(8), torch.randn(1, 8), True) elif layer_type == "linear_with_ggml_quantized_tensor": return (build_linear_layer_with_ggml_quantized_tensor(), torch.randn(1, 32), True) elif layer_type == "invoke_linear_8_bit_lt": @@ -351,7 +355,7 @@ def patch_under_test(request: pytest.FixtureRequest) -> PatchUnderTest: @parameterize_all_devices -def test_linear_sidecar_patches(device: str, layer_type: str, patch_under_test: PatchUnderTest): +def test_linear_sidecar_patches(device: str, patch_under_test: PatchUnderTest): patches, input = patch_under_test # Build the base layer under test. 
From 918f541af8660d15be1bc4fae3e5b42743fef60e Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Sat, 28 Dec 2024 20:44:48 +0000 Subject: [PATCH 21/31] Add unit test for a SetParameterLayer patch applied to a CustomFluxRMSNorm layer. --- .../torch_module_autocast.py | 7 ++++- .../test_custom_flux_rms_norm.py | 31 +++++++++++++++++++ 2 files changed, 37 insertions(+), 1 deletion(-) create mode 100644 tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_custom_flux_rms_norm.py diff --git a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/torch_module_autocast.py b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/torch_module_autocast.py index 2d85e32370f..73d5ec1ee58 100644 --- a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/torch_module_autocast.py +++ b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/torch_module_autocast.py @@ -1,3 +1,5 @@ +from typing import TypeVar + import torch from invokeai.backend.flux.modules.layers import RMSNorm @@ -52,7 +54,10 @@ AUTOCAST_MODULE_TYPE_MAPPING_INVERSE = {v: k for k, v in AUTOCAST_MODULE_TYPE_MAPPING.items()} -def wrap_custom_layer(module_to_wrap: torch.nn.Module, custom_layer_type: type[torch.nn.Module]): +T = TypeVar("T", bound=torch.nn.Module) + + +def wrap_custom_layer(module_to_wrap: torch.nn.Module, custom_layer_type: type[T]) -> T: # HACK(ryand): We use custom initialization logic so that we can initialize a new custom layer instance from an # existing layer instance without calling __init__() on the original layer class. We achieve this by copying # the attributes from the original layer instance to the new instance. 
diff --git a/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_custom_flux_rms_norm.py b/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_custom_flux_rms_norm.py new file mode 100644 index 00000000000..05e15302d50 --- /dev/null +++ b/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_custom_flux_rms_norm.py @@ -0,0 +1,31 @@ +import torch + +from invokeai.backend.flux.modules.layers import RMSNorm +from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_flux_rms_norm import ( + CustomFluxRMSNorm, +) +from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.torch_module_autocast import ( + wrap_custom_layer, +) +from invokeai.backend.patches.layers.set_parameter_layer import SetParameterLayer + + +def test_custom_flux_rms_norm_patch(): + """Test a SetParameterLayer patch on a CustomFluxRMSNorm layer.""" + # Create a RMSNorm layer. + dim = 8 + rms_norm = RMSNorm(dim) + + # Create a SetParameterLayer. + new_scale = torch.randn(dim) + set_parameter_layer = SetParameterLayer("scale", new_scale) + + # Wrap the RMSNorm layer in a CustomFluxRMSNorm layer. + custom_flux_rms_norm = wrap_custom_layer(rms_norm, CustomFluxRMSNorm) + custom_flux_rms_norm.add_patch(set_parameter_layer, 1.0) + + # Run the CustomFluxRMSNorm layer. + input = torch.randn(1, dim) + expected_output = torch.nn.functional.rms_norm(input, new_scale.shape, new_scale, eps=1e-6) + output_custom = custom_flux_rms_norm(input) + assert torch.allclose(output_custom, expected_output, atol=1e-6) From 20acfc9a00c82fd7496785b4136ca64210755f32 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Sat, 28 Dec 2024 20:49:17 +0000 Subject: [PATCH 22/31] Raise in CustomEmbedding and CustomGroupNorm if a patch is applied. 
--- .../torch_module_autocast/custom_modules/custom_embedding.py | 3 +++ .../torch_module_autocast/custom_modules/custom_group_norm.py | 3 +++ 2 files changed, 6 insertions(+) diff --git a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_embedding.py b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_embedding.py index e6f0c5df21f..e622b678fa4 100644 --- a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_embedding.py +++ b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_embedding.py @@ -20,6 +20,9 @@ def _autocast_forward(self, input: torch.Tensor) -> torch.Tensor: ) def forward(self, input: torch.Tensor) -> torch.Tensor: + if len(self._patches_and_weights) > 0: + raise RuntimeError("Embedding layers do not support patches") + if self._device_autocasting_enabled: return self._autocast_forward(input) else: diff --git a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_group_norm.py b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_group_norm.py index 66a46ac7ea2..d02e2d533f1 100644 --- a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_group_norm.py +++ b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_group_norm.py @@ -13,6 +13,9 @@ def _autocast_forward(self, input: torch.Tensor) -> torch.Tensor: return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps) def forward(self, input: torch.Tensor) -> torch.Tensor: + if len(self._patches_and_weights) > 0: + raise RuntimeError("GroupNorm layers do not support patches") + if self._device_autocasting_enabled: return self._autocast_forward(input) else: From 2855bb6b41ca267a87a87ce06d2c425f72ad3b7d Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Sat, 28 Dec 
2024 21:12:53 +0000 Subject: [PATCH 23/31] Update BaseLayerPatch.get_parameters(...) to accept a dict of orig_parameters rather than orig_module. This will enable compatibility between patching and cpu->gpu streaming. --- .../custom_modules/custom_module_mixin.py | 6 +++--- invokeai/backend/patches/layer_patcher.py | 4 +++- invokeai/backend/patches/layers/base_layer_patch.py | 2 +- .../backend/patches/layers/concatenated_lora_layer.py | 2 +- .../backend/patches/layers/flux_control_lora_layer.py | 6 +++--- invokeai/backend/patches/layers/lora_layer_base.py | 10 +++++----- invokeai/backend/patches/layers/set_parameter_layer.py | 4 ++-- .../patches/sidecar_wrappers/base_sidecar_wrapper.py | 6 ++++-- .../patches/layers/test_flux_control_lora_layer.py | 2 +- tests/backend/patches/layers/test_lora_layer.py | 2 +- .../backend/patches/layers/test_set_parameter_layer.py | 2 +- 11 files changed, 25 insertions(+), 21 deletions(-) diff --git a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_module_mixin.py b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_module_mixin.py index 03c6d81e2a8..58b3a610a05 100644 --- a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_module_mixin.py +++ b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_module_mixin.py @@ -32,9 +32,9 @@ def _aggregate_patch_parameters( params: dict[str, torch.Tensor] = {} for patch, patch_weight in patches_and_weights: - # TODO(ryand): self._orig_module could be quantized. Depending on what the patch is doing with the original - # module, this might fail or return incorrect results. - layer_params = patch.get_parameters(self, weight=patch_weight) + # TODO(ryand): `self` could be a quantized module. Depending on what the patch is doing with the original + # parameters, this might fail or return incorrect results. 
+ layer_params = patch.get_parameters(dict(self.named_parameters(recurse=False)), weight=patch_weight) # type: ignore for param_name, param_weight in layer_params.items(): if param_name not in params: diff --git a/invokeai/backend/patches/layer_patcher.py b/invokeai/backend/patches/layer_patcher.py index d7f6bea166b..0eaad184e2c 100644 --- a/invokeai/backend/patches/layer_patcher.py +++ b/invokeai/backend/patches/layer_patcher.py @@ -166,7 +166,9 @@ def _apply_model_layer_patch( # TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA # devices here. Experimentally, it was found to be very slow on CPU. More investigation needed. - for param_name, param_weight in patch.get_parameters(module_to_patch, weight=patch_weight).items(): + for param_name, param_weight in patch.get_parameters( + dict(module_to_patch.named_parameters(recurse=False)), weight=patch_weight + ).items(): param_key = module_to_patch_key + "." + param_name module_param = module_to_patch.get_parameter(param_name) diff --git a/invokeai/backend/patches/layers/base_layer_patch.py b/invokeai/backend/patches/layers/base_layer_patch.py index 5eb04864c83..f6f0289a906 100644 --- a/invokeai/backend/patches/layers/base_layer_patch.py +++ b/invokeai/backend/patches/layers/base_layer_patch.py @@ -5,7 +5,7 @@ class BaseLayerPatch(ABC): @abstractmethod - def get_parameters(self, orig_module: torch.nn.Module, weight: float) -> dict[str, torch.Tensor]: + def get_parameters(self, orig_parameters: dict[str, torch.Tensor], weight: float) -> dict[str, torch.Tensor]: """Get the parameter residual updates that should be applied to the original parameters. Parameters omitted from the returned dict are not updated. 
""" diff --git a/invokeai/backend/patches/layers/concatenated_lora_layer.py b/invokeai/backend/patches/layers/concatenated_lora_layer.py index a098a9e61be..a699a47433d 100644 --- a/invokeai/backend/patches/layers/concatenated_lora_layer.py +++ b/invokeai/backend/patches/layers/concatenated_lora_layer.py @@ -30,7 +30,7 @@ def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor: layer_weights = [lora_layer.get_weight(None) * lora_layer.scale() for lora_layer in self.lora_layers] # pyright: ignore[reportArgumentType] return torch.cat(layer_weights, dim=self.concat_axis) - def get_bias(self, orig_bias: torch.Tensor) -> Optional[torch.Tensor]: + def get_bias(self, orig_bias: torch.Tensor | None) -> Optional[torch.Tensor]: # TODO(ryand): Currently, we pass orig_bias=None to the sub-layers. If we want to support sub-layers that # require this value, we will need to implement chunking of the original bias tensor here. # Note that we must apply the sub-layer scales here. diff --git a/invokeai/backend/patches/layers/flux_control_lora_layer.py b/invokeai/backend/patches/layers/flux_control_lora_layer.py index 142336a00a2..ad592456a9d 100644 --- a/invokeai/backend/patches/layers/flux_control_lora_layer.py +++ b/invokeai/backend/patches/layers/flux_control_lora_layer.py @@ -8,11 +8,11 @@ class FluxControlLoRALayer(LoRALayer): shapes don't match. 
""" - def get_parameters(self, orig_module: torch.nn.Module, weight: float) -> dict[str, torch.Tensor]: + def get_parameters(self, orig_parameters: dict[str, torch.Tensor], weight: float) -> dict[str, torch.Tensor]: """This overrides the base class behavior to skip the reshaping step.""" scale = self.scale() - params = {"weight": self.get_weight(orig_module.weight) * (weight * scale)} - bias = self.get_bias(orig_module.bias) + params = {"weight": self.get_weight(orig_parameters["weight"]) * (weight * scale)} + bias = self.get_bias(orig_parameters.get("bias", None)) if bias is not None: params["bias"] = bias * (weight * scale) diff --git a/invokeai/backend/patches/layers/lora_layer_base.py b/invokeai/backend/patches/layers/lora_layer_base.py index 13669ad5d3d..123e5afa2c4 100644 --- a/invokeai/backend/patches/layers/lora_layer_base.py +++ b/invokeai/backend/patches/layers/lora_layer_base.py @@ -54,19 +54,19 @@ def scale(self) -> float: def get_weight(self, orig_weight: torch.Tensor) -> torch.Tensor: raise NotImplementedError() - def get_bias(self, orig_bias: torch.Tensor) -> Optional[torch.Tensor]: + def get_bias(self, orig_bias: torch.Tensor | None) -> Optional[torch.Tensor]: return self.bias - def get_parameters(self, orig_module: torch.nn.Module, weight: float) -> dict[str, torch.Tensor]: + def get_parameters(self, orig_parameters: dict[str, torch.Tensor], weight: float) -> dict[str, torch.Tensor]: scale = self.scale() - params = {"weight": self.get_weight(orig_module.weight) * (weight * scale)} - bias = self.get_bias(orig_module.bias) + params = {"weight": self.get_weight(orig_parameters["weight"]) * (weight * scale)} + bias = self.get_bias(orig_parameters.get("bias", None)) if bias is not None: params["bias"] = bias * (weight * scale) # Reshape all params to match the original module's shape. 
for param_name, param_weight in params.items(): - orig_param = orig_module.get_parameter(param_name) + orig_param = orig_parameters[param_name] if param_weight.shape != orig_param.shape: params[param_name] = param_weight.reshape(orig_param.shape) diff --git a/invokeai/backend/patches/layers/set_parameter_layer.py b/invokeai/backend/patches/layers/set_parameter_layer.py index f0ae461f4d3..1b7fe94d366 100644 --- a/invokeai/backend/patches/layers/set_parameter_layer.py +++ b/invokeai/backend/patches/layers/set_parameter_layer.py @@ -14,10 +14,10 @@ def __init__(self, param_name: str, weight: torch.Tensor): self.weight = weight self.param_name = param_name - def get_parameters(self, orig_module: torch.nn.Module, weight: float) -> dict[str, torch.Tensor]: + def get_parameters(self, orig_parameters: dict[str, torch.Tensor], weight: float) -> dict[str, torch.Tensor]: # Note: We intentionally ignore the weight parameter here. This matches the behavior in the official FLUX # Control LoRA implementation. - diff = self.weight - orig_module.get_parameter(self.param_name) + diff = self.weight - orig_parameters[self.param_name] return {self.param_name: diff} def to(self, device: torch.device | None = None, dtype: torch.dtype | None = None): diff --git a/invokeai/backend/patches/sidecar_wrappers/base_sidecar_wrapper.py b/invokeai/backend/patches/sidecar_wrappers/base_sidecar_wrapper.py index c22525bc95a..46d69bbe915 100644 --- a/invokeai/backend/patches/sidecar_wrappers/base_sidecar_wrapper.py +++ b/invokeai/backend/patches/sidecar_wrappers/base_sidecar_wrapper.py @@ -39,8 +39,10 @@ def _aggregate_patch_parameters( for patch, patch_weight in patches_and_weights: # TODO(ryand): self._orig_module could be quantized. Depending on what the patch is doing with the original - # module, this might fail or return incorrect results. - layer_params = patch.get_parameters(self._orig_module, weight=patch_weight) + # parameters, this might fail or return incorrect results. 
+ layer_params = patch.get_parameters( + dict(self._orig_module.named_parameters(recurse=False)), weight=patch_weight + ) for param_name, param_weight in layer_params.items(): if param_name not in params: diff --git a/tests/backend/patches/layers/test_flux_control_lora_layer.py b/tests/backend/patches/layers/test_flux_control_lora_layer.py index 00590c35149..129fcfcb4e3 100644 --- a/tests/backend/patches/layers/test_flux_control_lora_layer.py +++ b/tests/backend/patches/layers/test_flux_control_lora_layer.py @@ -18,7 +18,7 @@ def test_flux_control_lora_layer_get_parameters(): orig_module = torch.nn.Linear(small_in_features, out_features) # Test that get_parameters() behaves as expected in spite of the difference in in_features shapes. - params = layer.get_parameters(orig_module, weight=1.0) + params = layer.get_parameters(dict(orig_module.named_parameters(recurse=False)), weight=1.0) assert "weight" in params assert params["weight"].shape == (out_features, big_in_features) assert params["weight"].allclose(torch.ones(out_features, big_in_features) * alpha) diff --git a/tests/backend/patches/layers/test_lora_layer.py b/tests/backend/patches/layers/test_lora_layer.py index 34f62c3bcf2..c0971fb9a14 100644 --- a/tests/backend/patches/layers/test_lora_layer.py +++ b/tests/backend/patches/layers/test_lora_layer.py @@ -107,7 +107,7 @@ def test_lora_layer_get_parameters(): # Create mock original module orig_module = torch.nn.Linear(in_features, out_features) - params = layer.get_parameters(orig_module, weight=1.0) + params = layer.get_parameters(dict(orig_module.named_parameters(recurse=False)), weight=1.0) assert "weight" in params assert params["weight"].shape == orig_module.weight.shape assert params["weight"].allclose(torch.ones(out_features, in_features) * alpha) diff --git a/tests/backend/patches/layers/test_set_parameter_layer.py b/tests/backend/patches/layers/test_set_parameter_layer.py index 0bca0293f53..bdf8e337494 100644 --- 
a/tests/backend/patches/layers/test_set_parameter_layer.py +++ b/tests/backend/patches/layers/test_set_parameter_layer.py @@ -10,7 +10,7 @@ def test_set_parameter_layer_get_parameters(): target_weight = torch.randn(8, 4) layer = SetParameterLayer(param_name="weight", weight=target_weight) - params = layer.get_parameters(orig_module, weight=1.0) + params = layer.get_parameters(dict(orig_module.named_parameters(recurse=False)), weight=1.0) assert len(params) == 1 new_weight = orig_module.weight + params["weight"] assert torch.allclose(new_weight, target_weight) From 0525f967c2de075e04152fdd7049d04e07c8be4b Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Sun, 29 Dec 2024 00:22:37 +0000 Subject: [PATCH 24/31] Fix the _autocast_forward_with_patches() function for CustomConv1d and CustomConv2d. --- .../torch_module_autocast/custom_modules/custom_conv1d.py | 6 +++--- .../torch_module_autocast/custom_modules/custom_conv2d.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_conv1d.py b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_conv1d.py index d86a721e5a9..ba643574062 100644 --- a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_conv1d.py +++ b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_conv1d.py @@ -12,9 +12,9 @@ class CustomConv1d(torch.nn.Conv1d, CustomModuleMixin): def _autocast_forward_with_patches(self, input: torch.Tensor) -> torch.Tensor: aggregated_param_residuals = self._aggregate_patch_parameters(self._patches_and_weights) - weight = add_nullable_tensors(self.weight, aggregated_param_residuals["weight"]) - bias = add_nullable_tensors(self.bias, aggregated_param_residuals["bias"]) - return torch.nn.functional.conv1d(input, weight, bias) + weight = add_nullable_tensors(self.weight, 
aggregated_param_residuals.get("weight", None)) + bias = add_nullable_tensors(self.bias, aggregated_param_residuals.get("bias", None)) + return self._conv_forward(input, weight, bias) def _autocast_forward(self, input: torch.Tensor) -> torch.Tensor: weight = cast_to_device(self.weight, input.device) diff --git a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_conv2d.py b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_conv2d.py index 6067cef594c..98b6c520167 100644 --- a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_conv2d.py +++ b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_conv2d.py @@ -12,9 +12,9 @@ class CustomConv2d(torch.nn.Conv2d, CustomModuleMixin): def _autocast_forward_with_patches(self, input: torch.Tensor) -> torch.Tensor: aggregated_param_residuals = self._aggregate_patch_parameters(self._patches_and_weights) - weight = add_nullable_tensors(self.weight, aggregated_param_residuals["weight"]) - bias = add_nullable_tensors(self.bias, aggregated_param_residuals["bias"]) - return torch.nn.functional.conv2d(input, weight, bias) + weight = add_nullable_tensors(self.weight, aggregated_param_residuals.get("weight", None)) + bias = add_nullable_tensors(self.bias, aggregated_param_residuals.get("bias", None)) + return self._conv_forward(input, weight, bias) def _autocast_forward(self, input: torch.Tensor) -> torch.Tensor: weight = cast_to_device(self.weight, input.device) From 6d49ee839c3d7673b01cff6131ffd878c6c9a004 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Sun, 29 Dec 2024 01:18:30 +0000 Subject: [PATCH 25/31] Switch the LayerPatcher to use 'custom modules' to manage layer patching. 
--- .../cached_model_with_partial_load.py | 23 +++--- .../load/model_cache/model_cache.py | 7 ++ .../custom_modules/custom_module_mixin.py | 12 +++- .../torch_module_autocast.py | 6 +- invokeai/backend/patches/layer_patcher.py | 39 ++++------ .../test_cached_model_with_partial_load.py | 72 +++++++++---------- .../custom_modules/test_all_custom_modules.py | 37 ++++++++++ .../test_torch_module_autocast.py | 4 +- tests/backend/patches/test_layer_patcher.py | 28 +++++--- 9 files changed, 137 insertions(+), 91 deletions(-) diff --git a/invokeai/backend/model_manager/load/model_cache/cached_model/cached_model_with_partial_load.py b/invokeai/backend/model_manager/load/model_cache/cached_model/cached_model_with_partial_load.py index ab1a62db461..a5e1e3d5398 100644 --- a/invokeai/backend/model_manager/load/model_cache/cached_model/cached_model_with_partial_load.py +++ b/invokeai/backend/model_manager/load/model_cache/cached_model/cached_model_with_partial_load.py @@ -1,9 +1,7 @@ import torch -from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.torch_module_autocast import ( - AUTOCAST_MODULE_TYPE_MAPPING, - apply_custom_layers_to_model, - remove_custom_layers_from_model, +from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import ( + CustomModuleMixin, ) from invokeai.backend.util.calc_tensor_size import calc_tensor_size from invokeai.backend.util.logging import InvokeAILogger @@ -45,10 +43,10 @@ def __init__(self, model: torch.nn.Module, compute_device: torch.device): def _find_modules_that_support_autocast(self) -> dict[str, torch.nn.Module]: """Find all modules that support autocasting.""" - return {n: m for n, m in self._model.named_modules() if type(m) in AUTOCAST_MODULE_TYPE_MAPPING} + return {n: m for n, m in self._model.named_modules() if isinstance(m, CustomModuleMixin)} # type: ignore def _find_keys_in_modules_that_do_not_support_autocast(self) -> set[str]: - 
keys_in_modules_that_do_not_support_autocast = set() + keys_in_modules_that_do_not_support_autocast: set[str] = set() for key in self._cpu_state_dict.keys(): for module_name in self._modules_that_support_autocast.keys(): if key.startswith(module_name): @@ -70,6 +68,11 @@ def _move_non_persistent_buffers_to_device(self, device: torch.device): if name in module._non_persistent_buffers_set: module._buffers[name] = buffer.to(device, copy=True) + def _set_autocast_enabled_in_all_modules(self, enabled: bool): + """Set autocast_enabled flag in all modules that support device autocasting.""" + for module in self._modules_that_support_autocast.values(): + module.set_device_autocasting_enabled(enabled) + @property def model(self) -> torch.nn.Module: return self._model @@ -114,7 +117,7 @@ def partial_load_to_vram(self, vram_bytes_to_load: int) -> int: cur_state_dict = self._model.state_dict() - # First, process the keys *must* be loaded into VRAM. + # First, process the keys that *must* be loaded into VRAM. for key in self._keys_in_modules_that_do_not_support_autocast: param = cur_state_dict[key] if param.device.type == self._compute_device.type: @@ -157,10 +160,10 @@ def partial_load_to_vram(self, vram_bytes_to_load: int) -> int: self._cur_vram_bytes += vram_bytes_loaded if fully_loaded: - remove_custom_layers_from_model(self._model) + self._set_autocast_enabled_in_all_modules(False) # TODO(ryand): Warn if the self.cur_vram_bytes() and self.total_bytes() are out of sync. else: - apply_custom_layers_to_model(self._model) + self._set_autocast_enabled_in_all_modules(True) # Move all non-persistent buffers to the compute device. These are a weird edge case and do not participate in # the vram_bytes_loaded tracking. @@ -197,5 +200,5 @@ def partial_unload_from_vram(self, vram_bytes_to_free: int) -> int: # We may have gone from a fully-loaded model to a partially-loaded model, so we need to reapply the custom # layers. 
- apply_custom_layers_to_model(self._model) + self._set_autocast_enabled_in_all_modules(True) return vram_bytes_freed diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache.py b/invokeai/backend/model_manager/load/model_cache/model_cache.py index f61e2963a76..dbc3670c95c 100644 --- a/invokeai/backend/model_manager/load/model_cache/model_cache.py +++ b/invokeai/backend/model_manager/load/model_cache/model_cache.py @@ -13,6 +13,9 @@ from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff from invokeai.backend.model_manager.load.model_cache.cache_record import CacheRecord from invokeai.backend.model_manager.load.model_cache.cache_stats import CacheStats +from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.torch_module_autocast import ( + apply_custom_layers_to_model, +) from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data from invokeai.backend.util.devices import TorchDevice from invokeai.backend.util.logging import InvokeAILogger @@ -143,6 +146,10 @@ def put( size = calc_model_size_by_data(self._logger, model) self.make_room(size) + # Inject custom modules into the model. 
+ if isinstance(model, torch.nn.Module): + apply_custom_layers_to_model(model) + running_on_cpu = self._execution_device == torch.device("cpu") state_dict = model.state_dict() if isinstance(model, torch.nn.Module) and not running_on_cpu else None cache_record = CacheRecord(key=key, model=model, device=self._storage_device, state_dict=state_dict, size=size) diff --git a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_module_mixin.py b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_module_mixin.py index 58b3a610a05..494d6f0dd47 100644 --- a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_module_mixin.py +++ b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_module_mixin.py @@ -17,14 +17,22 @@ def set_device_autocasting_enabled(self, enabled: bool): """ self._device_autocasting_enabled = enabled + def is_device_autocasting_enabled(self) -> bool: + """Check if device autocasting is enabled for the module.""" + return self._device_autocasting_enabled + def add_patch(self, patch: BaseLayerPatch, patch_weight: float): - """Add a patch to the sidecar wrapper.""" + """Add a patch to the module.""" self._patches_and_weights.append((patch, patch_weight)) def clear_patches(self): - """Clear all patches from the sidecar wrapper.""" + """Clear all patches from the module.""" self._patches_and_weights = [] + def get_num_patches(self) -> int: + """Get the number of patches in the module.""" + return len(self._patches_and_weights) + def _aggregate_patch_parameters( self, patches_and_weights: list[tuple[BaseLayerPatch, float]] ) -> dict[str, torch.Tensor]: diff --git a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/torch_module_autocast.py b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/torch_module_autocast.py index 73d5ec1ee58..0e271eaec5a 100644 --- 
a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/torch_module_autocast.py +++ b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/torch_module_autocast.py @@ -83,17 +83,17 @@ def unwrap_custom_layer(custom_layer: torch.nn.Module, original_layer_type: type return original_layer -def apply_custom_layers_to_model(module: torch.nn.Module): +def apply_custom_layers_to_model(module: torch.nn.Module, device_autocasting_enabled: bool = False): for name, submodule in module.named_children(): override_type = AUTOCAST_MODULE_TYPE_MAPPING.get(type(submodule), None) if override_type is not None: custom_layer = wrap_custom_layer(submodule, override_type) # TODO(ryand): In the future, we should manage this flag on a per-module basis. - custom_layer.set_device_autocasting_enabled(True) + custom_layer.set_device_autocasting_enabled(device_autocasting_enabled) setattr(module, name, custom_layer) else: # Recursively apply to submodules - apply_custom_layers_to_model(submodule) + apply_custom_layers_to_model(submodule, device_autocasting_enabled) def remove_custom_layers_from_model(module: torch.nn.Module): diff --git a/invokeai/backend/patches/layer_patcher.py b/invokeai/backend/patches/layer_patcher.py index 0eaad184e2c..463d753b9dd 100644 --- a/invokeai/backend/patches/layer_patcher.py +++ b/invokeai/backend/patches/layer_patcher.py @@ -7,8 +7,6 @@ from invokeai.backend.patches.layers.flux_control_lora_layer import FluxControlLoRALayer from invokeai.backend.patches.model_patch_raw import ModelPatchRaw from invokeai.backend.patches.pad_with_zeros import pad_with_zeros -from invokeai.backend.patches.sidecar_wrappers.base_sidecar_wrapper import BaseSidecarWrapper -from invokeai.backend.patches.sidecar_wrappers.utils import wrap_module_with_sidecar_wrapper from invokeai.backend.util.devices import TorchDevice from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage @@ -32,7 +30,7 @@ def apply_smart_model_patches( # 
original_weights are stored for unpatching layers that are directly patched. original_weights = OriginalWeightsStorage(cached_weights) - # original_modules are stored for unpatching layers that are wrapped in a LoRASidecarWrapper. + # original_modules are stored for unpatching layers that are wrapped. original_modules: dict[str, torch.nn.Module] = {} try: for patch, patch_weight in patches: @@ -55,12 +53,10 @@ def apply_smart_model_patches( cur_param = model.get_parameter(param_key) cur_param.data = weight.to(dtype=cur_param.dtype, device=cur_param.device, copy=True) - # Restore LoRASidecarWrapper modules. + # Clear patches from all patched modules. # Note: This logic assumes no nested modules in original_modules. - for module_key, orig_module in original_modules.items(): - module_parent_key, module_name = LayerPatcher._split_parent_key(module_key) - parent_module = model.get_submodule(module_parent_key) - LayerPatcher._set_submodule(parent_module, module_name, orig_module) + for orig_module in original_modules.values(): + orig_module.clear_patches() @staticmethod @torch.no_grad() @@ -97,11 +93,11 @@ def apply_smart_model_patch( model, layer_key[prefix_len:], layer_key_is_flattened=layer_keys_are_flattened ) - # Decide whether to use direct patching or a sidecar wrapper. + # Decide whether to use direct patching or a sidecar patch. # Direct patching is preferred, because it results in better runtime speed. # Reasons to use sidecar patching: # - The module is quantized, so the caller passed force_sidecar_patching=True. - # - The module is already wrapped in a BaseSidecarWrapper. + # - The module already has sidecar patches. # - The module is on the CPU (and we don't want to store a second full copy of the original weights on the # CPU, since this would double the RAM usage) # NOTE: For now, we don't check if the layer is quantized here. 
We assume that this is checked in the caller @@ -115,14 +111,13 @@ def apply_smart_model_patch( use_sidecar_patching = False elif force_sidecar_patching: use_sidecar_patching = True - elif isinstance(module, BaseSidecarWrapper): + elif module.get_num_patches() > 0: use_sidecar_patching = True elif LayerPatcher._is_any_part_of_layer_on_cpu(module): use_sidecar_patching = True if use_sidecar_patching: LayerPatcher._apply_model_layer_wrapper_patch( - model=model, module_to_patch=module, module_to_patch_key=module_key, patch=layer, @@ -194,7 +189,6 @@ def _apply_model_layer_patch( @staticmethod @torch.no_grad() def _apply_model_layer_wrapper_patch( - model: torch.nn.Module, module_to_patch: torch.nn.Module, module_to_patch_key: str, patch: BaseLayerPatch, @@ -202,25 +196,16 @@ def _apply_model_layer_wrapper_patch( original_modules: dict[str, torch.nn.Module], dtype: torch.dtype, ): - """Apply a single LoRA wrapper patch to a model.""" - # Replace the original module with a BaseSidecarWrapper if it has not already been done. - if not isinstance(module_to_patch, BaseSidecarWrapper): - wrapped_module = wrap_module_with_sidecar_wrapper(orig_module=module_to_patch) - original_modules[module_to_patch_key] = module_to_patch - module_parent_key, module_name = LayerPatcher._split_parent_key(module_to_patch_key) - module_parent = model.get_submodule(module_parent_key) - LayerPatcher._set_submodule(module_parent, module_name, wrapped_module) - else: - assert module_to_patch_key in original_modules - wrapped_module = module_to_patch - + """Apply a single LoRA wrapper patch to a module.""" # Move the LoRA layer to the same device/dtype as the orig module. first_param = next(module_to_patch.parameters()) device = first_param.device patch.to(device=device, dtype=dtype) - # Add the patch to the sidecar wrapper. 
- wrapped_module.add_patch(patch, patch_weight) + if module_to_patch_key not in original_modules: + original_modules[module_to_patch_key] = module_to_patch + + module_to_patch.add_patch(patch, patch_weight) @staticmethod def _split_parent_key(module_key: str) -> tuple[str, str]: diff --git a/tests/backend/model_manager/load/model_cache/cached_model/test_cached_model_with_partial_load.py b/tests/backend/model_manager/load/model_cache/cached_model/test_cached_model_with_partial_load.py index 00ce27d580f..4fae046cf88 100644 --- a/tests/backend/model_manager/load/model_cache/cached_model/test_cached_model_with_partial_load.py +++ b/tests/backend/model_manager/load/model_cache/cached_model/test_cached_model_with_partial_load.py @@ -1,20 +1,27 @@ import itertools +import pytest import torch from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_with_partial_load import ( CachedModelWithPartialLoad, ) -from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_linear import ( - CustomLinear, +from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.torch_module_autocast import ( + apply_custom_layers_to_model, ) from invokeai.backend.util.calc_tensor_size import calc_tensor_size from tests.backend.model_manager.load.model_cache.cached_model.utils import DummyModule, parameterize_mps_and_cuda -@parameterize_mps_and_cuda -def test_cached_model_total_bytes(device: str): +@pytest.fixture +def model(): model = DummyModule() + apply_custom_layers_to_model(model) + return model + + +@parameterize_mps_and_cuda +def test_cached_model_total_bytes(device: str, model: DummyModule): cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device)) linear1_numel = 10 * 32 + 32 linear2_numel = 32 * 64 + 64 @@ -24,8 +31,7 @@ def test_cached_model_total_bytes(device: str): @parameterize_mps_and_cuda -def test_cached_model_cur_vram_bytes(device: str): - model = DummyModule() +def 
test_cached_model_cur_vram_bytes(device: str, model: DummyModule): # Model starts in CPU memory. cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device)) assert cached_model.cur_vram_bytes() == 0 @@ -39,8 +45,7 @@ def test_cached_model_cur_vram_bytes(device: str): @parameterize_mps_and_cuda -def test_cached_model_partial_load(device: str): - model = DummyModule() +def test_cached_model_partial_load(device: str, model: DummyModule): # Model starts in CPU memory. cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device)) model_total_bytes = cached_model.total_bytes() @@ -60,14 +65,13 @@ def test_cached_model_partial_load(device: str): if p.device.type == device and n != "buffer2" ) - # Check that the model's modules have been patched with CustomLinear layers. - assert type(model.linear1) is CustomLinear - assert type(model.linear2) is CustomLinear + # Check that the model's modules have device autocasting enabled. + assert model.linear1.is_device_autocasting_enabled() + assert model.linear2.is_device_autocasting_enabled() @parameterize_mps_and_cuda -def test_cached_model_partial_unload(device: str): - model = DummyModule() +def test_cached_model_partial_unload(device: str, model: DummyModule): # Model starts in CPU memory. cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device)) model_total_bytes = cached_model.total_bytes() @@ -89,14 +93,13 @@ def test_cached_model_partial_unload(device: str): calc_tensor_size(p) for p in itertools.chain(model.parameters(), model.buffers()) if p.device.type == "cpu" ) - # Check that the model's modules are still patched with CustomLinear layers. - assert type(model.linear1) is CustomLinear - assert type(model.linear2) is CustomLinear + # Check that the model's modules still have device autocasting enabled. 
+ assert model.linear1.is_device_autocasting_enabled() + assert model.linear2.is_device_autocasting_enabled() @parameterize_mps_and_cuda -def test_cached_model_full_load_and_unload(device: str): - model = DummyModule() +def test_cached_model_full_load_and_unload(device: str, model: DummyModule): cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device)) # Model starts in CPU memory. @@ -109,8 +112,8 @@ def test_cached_model_full_load_and_unload(device: str): assert loaded_bytes == model_total_bytes assert loaded_bytes == cached_model.cur_vram_bytes() assert all(p.device.type == device for p in itertools.chain(model.parameters(), model.buffers())) - assert type(model.linear1) is torch.nn.Linear - assert type(model.linear2) is torch.nn.Linear + assert not model.linear1.is_device_autocasting_enabled() + assert not model.linear2.is_device_autocasting_enabled() # Full unload the model from VRAM. unloaded_bytes = cached_model.full_unload_from_vram() @@ -128,8 +131,7 @@ def test_cached_model_full_load_and_unload(device: str): @parameterize_mps_and_cuda -def test_cached_model_full_load_from_partial(device: str): - model = DummyModule() +def test_cached_model_full_load_from_partial(device: str, model: DummyModule): cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device)) # Model starts in CPU memory. @@ -142,8 +144,8 @@ def test_cached_model_full_load_from_partial(device: str): assert loaded_bytes > 0 assert loaded_bytes < model_total_bytes assert loaded_bytes == cached_model.cur_vram_bytes() - assert type(model.linear1) is CustomLinear - assert type(model.linear2) is CustomLinear + assert model.linear1.is_device_autocasting_enabled() + assert model.linear2.is_device_autocasting_enabled() # Full load the rest of the model into VRAM. 
loaded_bytes_2 = cached_model.full_load_to_vram() @@ -152,13 +154,12 @@ def test_cached_model_full_load_from_partial(device: str): assert loaded_bytes + loaded_bytes_2 == cached_model.cur_vram_bytes() assert loaded_bytes + loaded_bytes_2 == model_total_bytes assert all(p.device.type == device for p in itertools.chain(model.parameters(), model.buffers())) - assert type(model.linear1) is torch.nn.Linear - assert type(model.linear2) is torch.nn.Linear + assert not model.linear1.is_device_autocasting_enabled() + assert not model.linear2.is_device_autocasting_enabled() @parameterize_mps_and_cuda -def test_cached_model_full_unload_from_partial(device: str): - model = DummyModule() +def test_cached_model_full_unload_from_partial(device: str, model: DummyModule): cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device)) # Model starts in CPU memory. @@ -186,8 +187,7 @@ def test_cached_model_full_unload_from_partial(device: str): @parameterize_mps_and_cuda -def test_cached_model_get_cpu_state_dict(device: str): - model = DummyModule() +def test_cached_model_get_cpu_state_dict(device: str, model: DummyModule): cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device)) # Model starts in CPU memory. @@ -211,8 +211,7 @@ def test_cached_model_get_cpu_state_dict(device: str): @parameterize_mps_and_cuda -def test_cached_model_full_load_and_inference(device: str): - model = DummyModule() +def test_cached_model_full_load_and_inference(device: str, model: DummyModule): cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device)) # Model starts in CPU memory. 
model_total_bytes = cached_model.total_bytes() @@ -239,8 +238,7 @@ def test_cached_model_full_load_and_inference(device: str): @parameterize_mps_and_cuda -def test_cached_model_partial_load_and_inference(device: str): - model = DummyModule() +def test_cached_model_partial_load_and_inference(device: str, model: DummyModule): # Model starts in CPU memory. cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device)) model_total_bytes = cached_model.total_bytes() @@ -264,9 +262,9 @@ def test_cached_model_partial_load_and_inference(device: str): for n, p in itertools.chain(model.named_parameters(), model.named_buffers()) if p.device.type == device and n != "buffer2" ) - # Check that the model's modules have been patched with CustomLinear layers. - assert type(model.linear1) is CustomLinear - assert type(model.linear2) is CustomLinear + # Check that the model's modules have device autocasting enabled. + assert model.linear1.is_device_autocasting_enabled() + assert model.linear2.is_device_autocasting_enabled() # Run inference on the GPU. output2 = model(x.to(device)) diff --git a/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_all_custom_modules.py b/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_all_custom_modules.py index 25a881952c5..102ea3691f4 100644 --- a/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_all_custom_modules.py +++ b/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_all_custom_modules.py @@ -388,6 +388,43 @@ def test_linear_sidecar_patches(device: str, patch_under_test: PatchUnderTest): assert torch.allclose(output_patched, output_custom, atol=1e-6) +@parameterize_cuda_and_mps +# def test_linear_sidecar_patches_with_autocast_from_cpu_to_device(device: str, patch_under_test: PatchUnderTest): +# patches, input = patch_under_test + +# # Build the base layer under test. 
+# layer = torch.nn.Linear(32, 64) + +# # Move the layer and input to the device. +# layer_to_device_via_state_dict(layer, device) +# input = input.to(torch.device(device)) + +# # Wrap the original layer in a custom layer and add the patch to it. +# custom_layer = wrap_single_custom_layer(layer) +# for patch, weight in patches: +# patch.to(torch.device(device)) +# custom_layer.add_patch(patch, weight) + +# # Run inference with the custom layer on the device. +# expected_output = custom_layer(input) + +# # Move the custom layer to the CPU. +# layer_to_device_via_state_dict(custom_layer, "cpu") + +# # Move the patches to the CPU. +# custom_layer.clear_patches() +# for patch, weight in patches: +# patch.to(torch.device("cpu")) +# custom_layer.add_patch(patch, weight) + +# # Run inference with an input on the device, and all layer weights on the CPU. The weights should be autocasted to +# # the device. +# autocast_output = custom_layer(input) +# assert autocast_output.device.type == device + +# assert torch.allclose(expected_output, autocast_output, atol=1e-6) + + @pytest.fixture( params=[ "linear_ggml_quantized", diff --git a/tests/backend/model_manager/load/model_cache/torch_module_autocast/test_torch_module_autocast.py b/tests/backend/model_manager/load/model_cache/torch_module_autocast/test_torch_module_autocast.py index 65b9f66066d..1861597a633 100644 --- a/tests/backend/model_manager/load/model_cache/torch_module_autocast/test_torch_module_autocast.py +++ b/tests/backend/model_manager/load/model_cache/torch_module_autocast/test_torch_module_autocast.py @@ -72,7 +72,7 @@ def test_torch_module_autocast_linear_layer(device: torch.device, model: torch.n assert expected.device.type == "cpu" # Apply the custom layers to the model. - apply_custom_layers_to_model(model) + apply_custom_layers_to_model(model, device_autocasting_enabled=True) # Run the model on the device. 
autocast_result = model(x.to(device)) @@ -122,7 +122,7 @@ def test_torch_module_autocast_bnb_llm_int8_linear_layer(): # Move the model back to the CPU and add the custom layers to the model. model.to("cpu") - apply_custom_layers_to_model(model) + apply_custom_layers_to_model(model, device_autocasting_enabled=True) # Run inference with weights being streamed to the GPU. autocast_result = model(x.to("cuda")) diff --git a/tests/backend/patches/test_layer_patcher.py b/tests/backend/patches/test_layer_patcher.py index 06d64c05c27..84741fc60ae 100644 --- a/tests/backend/patches/test_layer_patcher.py +++ b/tests/backend/patches/test_layer_patcher.py @@ -4,10 +4,12 @@ from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_with_partial_load import ( CachedModelWithPartialLoad, ) +from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.torch_module_autocast import ( + apply_custom_layers_to_model, +) from invokeai.backend.patches.layer_patcher import LayerPatcher from invokeai.backend.patches.layers.lora_layer import LoRALayer from invokeai.backend.patches.model_patch_raw import ModelPatchRaw -from invokeai.backend.patches.sidecar_wrappers.base_sidecar_wrapper import BaseSidecarWrapper class DummyModuleWithOneLayer(torch.nn.Module): @@ -50,6 +52,7 @@ def test_apply_smart_model_patches( linear_out_features = 8 lora_rank = 2 model = DummyModuleWithOneLayer(linear_in_features, linear_out_features, device=device, dtype=dtype) + apply_custom_layers_to_model(model) # Initialize num_loras LoRA models with weights of 0.5. lora_weight = 0.5 @@ -89,11 +92,11 @@ def test_apply_smart_model_patches( force_sidecar_patching=force_sidecar_patching, ): if expect_sidecar_wrappers: - # There should be sidecar wrappers in the model. - assert isinstance(model.linear_layer_1, BaseSidecarWrapper) + # There should be sidecar patches in the model. 
+ assert model.linear_layer_1.get_num_patches() == num_loras else: - # There should be no sidecar wrappers in the model. - assert not isinstance(model.linear_layer_1, BaseSidecarWrapper) + # There should be no sidecar patches in the model. + assert model.linear_layer_1.get_num_patches() == 0 torch.testing.assert_close(model.linear_layer_1.weight.data, expected_patched_linear_weight) # After patching, the patched model should still be on its original device. @@ -132,6 +135,7 @@ def test_apply_smart_lora_patches_to_partially_loaded_model(num_loras: int): linear_out_features = 8 lora_rank = 2 model = DummyModuleWithTwoLayers(linear_in_features, linear_out_features, device="cpu", dtype=dtype) + apply_custom_layers_to_model(model) cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device("cuda")) model_total_bytes = cached_model.total_bytes() assert cached_model.cur_vram_bytes() == 0 @@ -169,9 +173,9 @@ def test_apply_smart_lora_patches_to_partially_loaded_model(num_loras: int): # Patch the model and run inference during the patch. with LayerPatcher.apply_smart_model_patches(model=cached_model.model, patches=lora_models, prefix="", dtype=dtype): - # Check that the second layer is wrapped in a LoRASidecarWrapper, but the first layer is not. - assert not isinstance(cached_model.model.linear_layer_1, BaseSidecarWrapper) - assert isinstance(cached_model.model.linear_layer_2, BaseSidecarWrapper) + # Check that the second layer has sidecar patches, but the first layer does not. 
+ assert cached_model.model.linear_layer_1.get_num_patches() == 0 + assert cached_model.model.linear_layer_2.get_num_patches() == num_loras output_during_patch = cached_model.model(input) @@ -194,6 +198,7 @@ def test_all_patching_methods_produce_same_output(num_loras: int): linear_out_features = 8 lora_rank = 2 model = DummyModuleWithOneLayer(linear_in_features, linear_out_features, device="cpu", dtype=dtype) + apply_custom_layers_to_model(model) # Initialize num_loras LoRA models with weights of 0.5. lora_weight = 0.5 @@ -242,6 +247,7 @@ def test_apply_smart_model_patches_change_device(): lora_dim = 2 # Initialize the model on the CPU. model = DummyModuleWithOneLayer(linear_in_features, linear_out_features, device="cpu", dtype=torch.float16) + apply_custom_layers_to_model(model) lora_layers = { "linear_layer_1": LoRALayer.from_state_dict_values( @@ -265,8 +271,8 @@ def test_apply_smart_model_patches_change_device(): # After patching, the patched model should still be on the CPU. assert model.linear_layer_1.weight.data.device.type == "cpu" - # There should be no sidecar wrappers in the model. - assert not isinstance(model.linear_layer_1, BaseSidecarWrapper) + # There should be no sidecar patches in the model. + assert model.linear_layer_1.get_num_patches() == 0 # Move the model to the GPU. assert model.to("cuda") @@ -284,6 +290,8 @@ def test_apply_smart_model_patches_force_sidecar_and_direct_patching(): linear_out_features = 8 lora_rank = 2 model = DummyModuleWithOneLayer(linear_in_features, linear_out_features, device="cpu", dtype=torch.float16) + apply_custom_layers_to_model(model) + lora_layers = { "linear_layer_1": LoRALayer.from_state_dict_values( values={ From a8bef596994b991d67bea4fa44c0c805be82757c Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Sun, 29 Dec 2024 06:51:30 +0000 Subject: [PATCH 26/31] First pass at making custom layer patches work with weights streamed from the CPU to the GPU. 
--- .../custom_modules/custom_conv1d.py | 23 +++++--- .../custom_modules/custom_conv2d.py | 23 +++++--- .../custom_modules/custom_flux_rms_norm.py | 4 +- .../custom_modules/custom_linear.py | 15 ++++- .../custom_modules/custom_module_mixin.py | 16 +++++- .../custom_modules/test_all_custom_modules.py | 56 ++++++++++--------- 6 files changed, 92 insertions(+), 45 deletions(-) diff --git a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_conv1d.py b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_conv1d.py index ba643574062..b59b5a2aae5 100644 --- a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_conv1d.py +++ b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_conv1d.py @@ -4,17 +4,26 @@ from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import ( CustomModuleMixin, ) -from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.utils import ( - add_nullable_tensors, -) class CustomConv1d(torch.nn.Conv1d, CustomModuleMixin): def _autocast_forward_with_patches(self, input: torch.Tensor) -> torch.Tensor: - aggregated_param_residuals = self._aggregate_patch_parameters(self._patches_and_weights) - weight = add_nullable_tensors(self.weight, aggregated_param_residuals.get("weight", None)) - bias = add_nullable_tensors(self.bias, aggregated_param_residuals.get("bias", None)) - return self._conv_forward(input, weight, bias) + weight = cast_to_device(self.weight, input.device) + bias = cast_to_device(self.bias, input.device) + + # Prepare the original parameters for the patch aggregation. + orig_params = {"weight": weight, "bias": bias} + # Filter out None values. 
+ orig_params = {k: v for k, v in orig_params.items() if v is not None} + + aggregated_param_residuals = self._aggregate_patch_parameters( + patches_and_weights=self._patches_and_weights, + orig_params=orig_params, + device=input.device, + ) + return self._conv_forward( + input, aggregated_param_residuals["weight"], aggregated_param_residuals.get("bias", None) + ) def _autocast_forward(self, input: torch.Tensor) -> torch.Tensor: weight = cast_to_device(self.weight, input.device) diff --git a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_conv2d.py b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_conv2d.py index 98b6c520167..1077b47ed5e 100644 --- a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_conv2d.py +++ b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_conv2d.py @@ -4,17 +4,26 @@ from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import ( CustomModuleMixin, ) -from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.utils import ( - add_nullable_tensors, -) class CustomConv2d(torch.nn.Conv2d, CustomModuleMixin): def _autocast_forward_with_patches(self, input: torch.Tensor) -> torch.Tensor: - aggregated_param_residuals = self._aggregate_patch_parameters(self._patches_and_weights) - weight = add_nullable_tensors(self.weight, aggregated_param_residuals.get("weight", None)) - bias = add_nullable_tensors(self.bias, aggregated_param_residuals.get("bias", None)) - return self._conv_forward(input, weight, bias) + weight = cast_to_device(self.weight, input.device) + bias = cast_to_device(self.bias, input.device) + + # Prepare the original parameters for the patch aggregation. + orig_params = {"weight": weight, "bias": bias} + # Filter out None values. 
+ orig_params = {k: v for k, v in orig_params.items() if v is not None} + + aggregated_param_residuals = self._aggregate_patch_parameters( + patches_and_weights=self._patches_and_weights, + orig_params=orig_params, + device=input.device, + ) + return self._conv_forward( + input, aggregated_param_residuals["weight"], aggregated_param_residuals.get("bias", None) + ) def _autocast_forward(self, input: torch.Tensor) -> torch.Tensor: weight = cast_to_device(self.weight, input.device) diff --git a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_flux_rms_norm.py b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_flux_rms_norm.py index ba894433c90..dccbe4af6c7 100644 --- a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_flux_rms_norm.py +++ b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_flux_rms_norm.py @@ -16,10 +16,12 @@ def _autocast_forward_with_patches(self, x: torch.Tensor) -> torch.Tensor: assert isinstance(patch, SetParameterLayer) assert patch.param_name == "scale" + scale = cast_to_device(patch.weight, x.device) + # Apply the patch. # NOTE(ryand): Currently, we ignore the patch weight when running as a sidecar. It's not clear how this should # be handled. 
- return torch.nn.functional.rms_norm(x, patch.weight.shape, patch.weight, eps=1e-6) + return torch.nn.functional.rms_norm(x, scale.shape, scale, eps=1e-6) def _autocast_forward(self, x: torch.Tensor) -> torch.Tensor: scale = cast_to_device(self.scale, x.device) diff --git a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_linear.py b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_linear.py index e8335911092..7d5784563e3 100644 --- a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_linear.py +++ b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_linear.py @@ -1,3 +1,5 @@ +import copy + import torch from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device @@ -55,6 +57,10 @@ def autocast_linear_forward_sidecar_patches( # Then, apply layers for which we have optimized implementations. unprocessed_patches_and_weights: list[tuple[BaseLayerPatch, float]] = [] for patch, patch_weight in patches_and_weights: + # Shallow copy the patch so that we can cast it to the target device without modifying the original patch. + patch = copy.copy(patch) + patch.to(input.device) + if isinstance(patch, FluxControlLoRALayer): # Note that we use the original input here, not the sliced input. output += linear_lora_forward(orig_input, patch, patch_weight) @@ -67,7 +73,14 @@ def autocast_linear_forward_sidecar_patches( # Finally, apply any remaining patches. if len(unprocessed_patches_and_weights) > 0: - aggregated_param_residuals = orig_module._aggregate_patch_parameters(unprocessed_patches_and_weights) + # Prepare the original parameters for the patch aggregation. + orig_params = {"weight": orig_module.weight, "bias": orig_module.bias} + # Filter out None values. 
+ orig_params = {k: v for k, v in orig_params.items() if v is not None} + + aggregated_param_residuals = orig_module._aggregate_patch_parameters( + unprocessed_patches_and_weights, orig_params=orig_params, device=input.device + ) output += torch.nn.functional.linear( input, aggregated_param_residuals["weight"], aggregated_param_residuals.get("bias", None) ) diff --git a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_module_mixin.py b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_module_mixin.py index 494d6f0dd47..a7312517a48 100644 --- a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_module_mixin.py +++ b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_module_mixin.py @@ -1,3 +1,5 @@ +import copy + import torch from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch @@ -34,15 +36,23 @@ def get_num_patches(self) -> int: return len(self._patches_and_weights) def _aggregate_patch_parameters( - self, patches_and_weights: list[tuple[BaseLayerPatch, float]] - ) -> dict[str, torch.Tensor]: + self, + patches_and_weights: list[tuple[BaseLayerPatch, float]], + orig_params: dict[str, torch.Tensor], + device: torch.device | None = None, + ): """Helper function that aggregates the parameters from all patches into a single dict.""" params: dict[str, torch.Tensor] = {} for patch, patch_weight in patches_and_weights: + if device is not None: + # Shallow copy the patch so that we can cast it to the target device without modifying the original patch. + patch = copy.copy(patch) + patch.to(device) + # TODO(ryand): `self` could be a quantized module. Depending on what the patch is doing with the original # parameters, this might fail or return incorrect results. 
- layer_params = patch.get_parameters(dict(self.named_parameters(recurse=False)), weight=patch_weight) # type: ignore + layer_params = patch.get_parameters(orig_params, weight=patch_weight) for param_name, param_weight in layer_params.items(): if param_name not in params: diff --git a/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_all_custom_modules.py b/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_all_custom_modules.py index 102ea3691f4..273a4bf543c 100644 --- a/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_all_custom_modules.py +++ b/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_all_custom_modules.py @@ -389,40 +389,44 @@ def test_linear_sidecar_patches(device: str, patch_under_test: PatchUnderTest): @parameterize_cuda_and_mps -# def test_linear_sidecar_patches_with_autocast_from_cpu_to_device(device: str, patch_under_test: PatchUnderTest): -# patches, input = patch_under_test +def test_linear_sidecar_patches_with_autocast_from_cpu_to_device(device: str, patch_under_test: PatchUnderTest): + """Test that the output of a linear layer with sidecar patches is the same when the layer is on the device and + when the layer is on the CPU and the patches are autocasted to the device. + """ + patches, input = patch_under_test -# # Build the base layer under test. -# layer = torch.nn.Linear(32, 64) + # Build the base layer under test. + layer = torch.nn.Linear(32, 64) -# # Move the layer and input to the device. -# layer_to_device_via_state_dict(layer, device) -# input = input.to(torch.device(device)) + # Move the layer and input to the device. + layer_to_device_via_state_dict(layer, device) + input = input.to(torch.device(device)) -# # Wrap the original layer in a custom layer and add the patch to it. 
-# custom_layer = wrap_single_custom_layer(layer) -# for patch, weight in patches: -# patch.to(torch.device(device)) -# custom_layer.add_patch(patch, weight) + # Wrap the original layer in a custom layer and add the patch to it. + custom_layer = wrap_single_custom_layer(layer) + for patch, weight in patches: + patch.to(torch.device(device)) + custom_layer.add_patch(patch, weight) -# # Run inference with the custom layer on the device. -# expected_output = custom_layer(input) + # Run inference with the custom layer on the device. + expected_output = custom_layer(input) -# # Move the custom layer to the CPU. -# layer_to_device_via_state_dict(custom_layer, "cpu") + # Move the custom layer to the CPU. + layer_to_device_via_state_dict(custom_layer, "cpu") -# # Move the patches to the CPU. -# custom_layer.clear_patches() -# for patch, weight in patches: -# patch.to(torch.device("cpu")) -# custom_layer.add_patch(patch, weight) + # Move the patches to the CPU. + custom_layer.clear_patches() + for patch, weight in patches: + patch.to(torch.device("cpu")) + custom_layer.add_patch(patch, weight) -# # Run inference with an input on the device, and all layer weights on the CPU. The weights should be autocasted to -# # the device. -# autocast_output = custom_layer(input) -# assert autocast_output.device.type == device + # Run inference with an input on the device, and all layer weights on the CPU. The weights should be autocasted to + # the device. + autocast_output = custom_layer(input) + assert autocast_output.device.type == device -# assert torch.allclose(expected_output, autocast_output, atol=1e-6) + # Assert that the outputs with and without autocasting are the same. 
+ assert torch.allclose(expected_output, autocast_output, atol=1e-6) @pytest.fixture( From 52fc5a64d4c0a9f6d84fc240d1ae338e828e9ce9 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Sun, 29 Dec 2024 17:14:55 +0000 Subject: [PATCH 27/31] Add a unit test for a LoRA patch applied to a quantized linear layer with weights streamed from CPU to GPU. --- .../custom_modules/test_all_custom_modules.py | 44 +++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_all_custom_modules.py b/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_all_custom_modules.py index 273a4bf543c..97062772341 100644 --- a/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_all_custom_modules.py +++ b/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_all_custom_modules.py @@ -484,3 +484,47 @@ def test_quantized_linear_sidecar_patches( output_linear_patched = linear_layer_custom(input) output_quantized_patched = quantized_linear_layer_custom(input) assert torch.allclose(output_linear_patched, output_quantized_patched, rtol=0.2, atol=0.2) + + +@parameterize_cuda_and_mps +def test_quantized_linear_sidecar_patches_with_autocast_from_cpu_to_device( + device: str, + quantized_linear_layer_under_test: tuple[torch.nn.Module, torch.nn.Module], + patch_under_test: PatchUnderTest, +): + """Test that the output of a linear layer with sidecar patches is the same when the layer is on the device and + when the layer is on the CPU and the patches are autocasted to the device. + """ + patches, input = patch_under_test + + _, quantized_linear_layer = quantized_linear_layer_under_test + + # Move everything to the device. + layer_to_device_via_state_dict(quantized_linear_layer, device) + input = input.to(torch.device(device)) + + # Wrap the quantized linear layer in a custom layer and add the patch to it. 
+ quantized_linear_layer_custom = wrap_single_custom_layer(quantized_linear_layer) + for patch, weight in patches: + patch.to(torch.device(device)) + quantized_linear_layer_custom.add_patch(patch, weight) + + # Run inference with the custom layer on the device. + expected_output = quantized_linear_layer_custom(input) + + # Move the custom layer to the CPU. + layer_to_device_via_state_dict(quantized_linear_layer_custom, "cpu") + + # Move the patches to the CPU. + quantized_linear_layer_custom.clear_patches() + for patch, weight in patches: + patch.to(torch.device("cpu")) + quantized_linear_layer_custom.add_patch(patch, weight) + + # Run inference with an input on the device, and all layer weights on the CPU. The weights should be autocasted to + # the device. + autocast_output = quantized_linear_layer_custom(input) + assert autocast_output.device.type == device + + # Assert that the outputs with and without autocasting are the same. + assert torch.allclose(expected_output, autocast_output, atol=1e-6) From 6fd9b0a274aa55ea0a319be7bbc6caa09cabd4d3 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Sun, 29 Dec 2024 17:33:08 +0000 Subject: [PATCH 28/31] Delete old sidecar wrapper implementation. This functionality has moved into the custom layers. 
--- .../patches/sidecar_wrappers/__init__.py | 0 .../sidecar_wrappers/base_sidecar_wrapper.py | 56 ------ .../conv1d_sidecar_wrapper.py | 11 -- .../conv2d_sidecar_wrapper.py | 11 -- .../flux_rms_norm_sidecar_wrapper.py | 24 --- .../linear_sidecar_wrapper.py | 66 ------- .../backend/patches/sidecar_wrappers/utils.py | 20 -- .../test_flux_rms_norm_sidecar_wrapper.py | 23 --- .../test_linear_sidecar_wrapper.py | 182 ------------------ 9 files changed, 393 deletions(-) delete mode 100644 invokeai/backend/patches/sidecar_wrappers/__init__.py delete mode 100644 invokeai/backend/patches/sidecar_wrappers/base_sidecar_wrapper.py delete mode 100644 invokeai/backend/patches/sidecar_wrappers/conv1d_sidecar_wrapper.py delete mode 100644 invokeai/backend/patches/sidecar_wrappers/conv2d_sidecar_wrapper.py delete mode 100644 invokeai/backend/patches/sidecar_wrappers/flux_rms_norm_sidecar_wrapper.py delete mode 100644 invokeai/backend/patches/sidecar_wrappers/linear_sidecar_wrapper.py delete mode 100644 invokeai/backend/patches/sidecar_wrappers/utils.py delete mode 100644 tests/backend/patches/sidecar_wrappers/test_flux_rms_norm_sidecar_wrapper.py delete mode 100644 tests/backend/patches/sidecar_wrappers/test_linear_sidecar_wrapper.py diff --git a/invokeai/backend/patches/sidecar_wrappers/__init__.py b/invokeai/backend/patches/sidecar_wrappers/__init__.py deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/invokeai/backend/patches/sidecar_wrappers/base_sidecar_wrapper.py b/invokeai/backend/patches/sidecar_wrappers/base_sidecar_wrapper.py deleted file mode 100644 index 46d69bbe915..00000000000 --- a/invokeai/backend/patches/sidecar_wrappers/base_sidecar_wrapper.py +++ /dev/null @@ -1,56 +0,0 @@ -import torch - -from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch - - -class BaseSidecarWrapper(torch.nn.Module): - """A base class for sidecar wrappers. 
- - A sidecar wrapper is a wrapper for an existing torch.nn.Module that applies a - list of patches as 'sidecar' patches. I.e. it applies the sidecar patches during forward inference without modifying - the original module. - - Sidecar wrappers are typically used over regular patches when: - - The original module is quantized and so the weights can't be patched in the usual way. - - The original module is on the CPU and modifying the weights would require backing up the original weights and - doubling the CPU memory usage. - """ - - def __init__( - self, orig_module: torch.nn.Module, patches_and_weights: list[tuple[BaseLayerPatch, float]] | None = None - ): - super().__init__() - self._orig_module = orig_module - self._patches_and_weights = [] if patches_and_weights is None else patches_and_weights - - @property - def orig_module(self) -> torch.nn.Module: - return self._orig_module - - def add_patch(self, patch: BaseLayerPatch, patch_weight: float): - """Add a patch to the sidecar wrapper.""" - self._patches_and_weights.append((patch, patch_weight)) - - def _aggregate_patch_parameters( - self, patches_and_weights: list[tuple[BaseLayerPatch, float]] - ) -> dict[str, torch.Tensor]: - """Helper function that aggregates the parameters from all patches into a single dict.""" - params: dict[str, torch.Tensor] = {} - - for patch, patch_weight in patches_and_weights: - # TODO(ryand): self._orig_module could be quantized. Depending on what the patch is doing with the original - # parameters, this might fail or return incorrect results. 
- layer_params = patch.get_parameters( - dict(self._orig_module.named_parameters(recurse=False)), weight=patch_weight - ) - - for param_name, param_weight in layer_params.items(): - if param_name not in params: - params[param_name] = param_weight - else: - params[param_name] += param_weight - - return params - - def forward(self, *args, **kwargs): # type: ignore - raise NotImplementedError() diff --git a/invokeai/backend/patches/sidecar_wrappers/conv1d_sidecar_wrapper.py b/invokeai/backend/patches/sidecar_wrappers/conv1d_sidecar_wrapper.py deleted file mode 100644 index 7877aae8c75..00000000000 --- a/invokeai/backend/patches/sidecar_wrappers/conv1d_sidecar_wrapper.py +++ /dev/null @@ -1,11 +0,0 @@ -import torch - -from invokeai.backend.patches.sidecar_wrappers.base_sidecar_wrapper import BaseSidecarWrapper - - -class Conv1dSidecarWrapper(BaseSidecarWrapper): - def forward(self, input: torch.Tensor) -> torch.Tensor: - aggregated_param_residuals = self._aggregate_patch_parameters(self._patches_and_weights) - return self.orig_module(input) + torch.nn.functional.conv1d( - input, aggregated_param_residuals["weight"], aggregated_param_residuals.get("bias", None) - ) diff --git a/invokeai/backend/patches/sidecar_wrappers/conv2d_sidecar_wrapper.py b/invokeai/backend/patches/sidecar_wrappers/conv2d_sidecar_wrapper.py deleted file mode 100644 index d9bb7135348..00000000000 --- a/invokeai/backend/patches/sidecar_wrappers/conv2d_sidecar_wrapper.py +++ /dev/null @@ -1,11 +0,0 @@ -import torch - -from invokeai.backend.patches.sidecar_wrappers.base_sidecar_wrapper import BaseSidecarWrapper - - -class Conv2dSidecarWrapper(BaseSidecarWrapper): - def forward(self, input: torch.Tensor) -> torch.Tensor: - aggregated_param_residuals = self._aggregate_patch_parameters(self._patches_and_weights) - return self.orig_module(input) + torch.nn.functional.conv1d( - input, aggregated_param_residuals["weight"], aggregated_param_residuals.get("bias", None) - ) diff --git 
a/invokeai/backend/patches/sidecar_wrappers/flux_rms_norm_sidecar_wrapper.py b/invokeai/backend/patches/sidecar_wrappers/flux_rms_norm_sidecar_wrapper.py deleted file mode 100644 index 34c3b9b3699..00000000000 --- a/invokeai/backend/patches/sidecar_wrappers/flux_rms_norm_sidecar_wrapper.py +++ /dev/null @@ -1,24 +0,0 @@ -import torch - -from invokeai.backend.patches.layers.set_parameter_layer import SetParameterLayer -from invokeai.backend.patches.sidecar_wrappers.base_sidecar_wrapper import BaseSidecarWrapper - - -class FluxRMSNormSidecarWrapper(BaseSidecarWrapper): - """A sidecar wrapper for a FLUX RMSNorm layer. - - This wrapper is a special case. It is added specifically to enable FLUX structural control LoRAs, which overwrite - the RMSNorm scale parameters. - """ - - def forward(self, input: torch.Tensor) -> torch.Tensor: - # Given the narrow focus of this wrapper, we only support a very particular patch configuration: - assert len(self._patches_and_weights) == 1 - patch, _patch_weight = self._patches_and_weights[0] - assert isinstance(patch, SetParameterLayer) - assert patch.param_name == "scale" - - # Apply the patch. - # NOTE(ryand): Currently, we ignore the patch weight when running as a sidecar. It's not clear how this should - # be handled. 
- return torch.nn.functional.rms_norm(input, patch.weight.shape, patch.weight, eps=1e-6) diff --git a/invokeai/backend/patches/sidecar_wrappers/linear_sidecar_wrapper.py b/invokeai/backend/patches/sidecar_wrappers/linear_sidecar_wrapper.py deleted file mode 100644 index 98775b9feb8..00000000000 --- a/invokeai/backend/patches/sidecar_wrappers/linear_sidecar_wrapper.py +++ /dev/null @@ -1,66 +0,0 @@ -import torch - -from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch -from invokeai.backend.patches.layers.concatenated_lora_layer import ConcatenatedLoRALayer -from invokeai.backend.patches.layers.flux_control_lora_layer import FluxControlLoRALayer -from invokeai.backend.patches.layers.lora_layer import LoRALayer -from invokeai.backend.patches.sidecar_wrappers.base_sidecar_wrapper import BaseSidecarWrapper - - -class LinearSidecarWrapper(BaseSidecarWrapper): - def _lora_forward(self, input: torch.Tensor, lora_layer: LoRALayer, lora_weight: float) -> torch.Tensor: - """An optimized implementation of the residual calculation for a Linear LoRALayer.""" - x = torch.nn.functional.linear(input, lora_layer.down) - if lora_layer.mid is not None: - x = torch.nn.functional.linear(x, lora_layer.mid) - x = torch.nn.functional.linear(x, lora_layer.up, bias=lora_layer.bias) - x *= lora_weight * lora_layer.scale() - return x - - def _concatenated_lora_forward( - self, input: torch.Tensor, concatenated_lora_layer: ConcatenatedLoRALayer, lora_weight: float - ) -> torch.Tensor: - """An optimized implementation of the residual calculation for a Linear ConcatenatedLoRALayer.""" - x_chunks: list[torch.Tensor] = [] - for lora_layer in concatenated_lora_layer.lora_layers: - x_chunk = torch.nn.functional.linear(input, lora_layer.down) - if lora_layer.mid is not None: - x_chunk = torch.nn.functional.linear(x_chunk, lora_layer.mid) - x_chunk = torch.nn.functional.linear(x_chunk, lora_layer.up, bias=lora_layer.bias) - x_chunk *= lora_weight * lora_layer.scale() - 
x_chunks.append(x_chunk) - - # TODO(ryand): Generalize to support concat_axis != 0. - assert concatenated_lora_layer.concat_axis == 0 - x = torch.cat(x_chunks, dim=-1) - return x - - def forward(self, input: torch.Tensor) -> torch.Tensor: - # First, apply the original linear layer. - # NOTE: We slice the input to match the original weight shape in order to work with FluxControlLoRAs, which - # change the linear layer's in_features. - orig_input = input - input = orig_input[..., : self.orig_module.in_features] - output = self.orig_module(input) - - # Then, apply layers for which we have optimized implementations. - unprocessed_patches_and_weights: list[tuple[BaseLayerPatch, float]] = [] - for patch, patch_weight in self._patches_and_weights: - if isinstance(patch, FluxControlLoRALayer): - # Note that we use the original input here, not the sliced input. - output += self._lora_forward(orig_input, patch, patch_weight) - elif isinstance(patch, LoRALayer): - output += self._lora_forward(input, patch, patch_weight) - elif isinstance(patch, ConcatenatedLoRALayer): - output += self._concatenated_lora_forward(input, patch, patch_weight) - else: - unprocessed_patches_and_weights.append((patch, patch_weight)) - - # Finally, apply any remaining patches. 
- if len(unprocessed_patches_and_weights) > 0: - aggregated_param_residuals = self._aggregate_patch_parameters(unprocessed_patches_and_weights) - output += torch.nn.functional.linear( - input, aggregated_param_residuals["weight"], aggregated_param_residuals.get("bias", None) - ) - - return output diff --git a/invokeai/backend/patches/sidecar_wrappers/utils.py b/invokeai/backend/patches/sidecar_wrappers/utils.py deleted file mode 100644 index 6a71213b09a..00000000000 --- a/invokeai/backend/patches/sidecar_wrappers/utils.py +++ /dev/null @@ -1,20 +0,0 @@ -import torch - -from invokeai.backend.flux.modules.layers import RMSNorm -from invokeai.backend.patches.sidecar_wrappers.conv1d_sidecar_wrapper import Conv1dSidecarWrapper -from invokeai.backend.patches.sidecar_wrappers.conv2d_sidecar_wrapper import Conv2dSidecarWrapper -from invokeai.backend.patches.sidecar_wrappers.flux_rms_norm_sidecar_wrapper import FluxRMSNormSidecarWrapper -from invokeai.backend.patches.sidecar_wrappers.linear_sidecar_wrapper import LinearSidecarWrapper - - -def wrap_module_with_sidecar_wrapper(orig_module: torch.nn.Module) -> torch.nn.Module: - if isinstance(orig_module, torch.nn.Linear): - return LinearSidecarWrapper(orig_module) - elif isinstance(orig_module, torch.nn.Conv1d): - return Conv1dSidecarWrapper(orig_module) - elif isinstance(orig_module, torch.nn.Conv2d): - return Conv2dSidecarWrapper(orig_module) - elif isinstance(orig_module, RMSNorm): - return FluxRMSNormSidecarWrapper(orig_module) - else: - raise ValueError(f"No sidecar wrapper found for module type: {type(orig_module)}") diff --git a/tests/backend/patches/sidecar_wrappers/test_flux_rms_norm_sidecar_wrapper.py b/tests/backend/patches/sidecar_wrappers/test_flux_rms_norm_sidecar_wrapper.py deleted file mode 100644 index ee0dce554f4..00000000000 --- a/tests/backend/patches/sidecar_wrappers/test_flux_rms_norm_sidecar_wrapper.py +++ /dev/null @@ -1,23 +0,0 @@ -import torch - -from 
invokeai.backend.patches.layers.set_parameter_layer import SetParameterLayer -from invokeai.backend.patches.sidecar_wrappers.flux_rms_norm_sidecar_wrapper import FluxRMSNormSidecarWrapper - - -def test_flux_rms_norm_sidecar_wrapper(): - # Create a RMSNorm layer. - dim = 10 - rms_norm = torch.nn.RMSNorm(dim) - - # Create a SetParameterLayer. - new_scale = torch.randn(dim) - set_parameter_layer = SetParameterLayer("scale", new_scale) - - # Create a FluxRMSNormSidecarWrapper. - rms_norm_wrapped = FluxRMSNormSidecarWrapper(rms_norm, [(set_parameter_layer, 1.0)]) - - # Run the FluxRMSNormSidecarWrapper. - input = torch.randn(1, dim) - expected_output = torch.nn.functional.rms_norm(input, new_scale.shape, new_scale, eps=1e-6) - output_wrapped = rms_norm_wrapped(input) - assert torch.allclose(output_wrapped, expected_output, atol=1e-6) diff --git a/tests/backend/patches/sidecar_wrappers/test_linear_sidecar_wrapper.py b/tests/backend/patches/sidecar_wrappers/test_linear_sidecar_wrapper.py deleted file mode 100644 index 607f364dcd6..00000000000 --- a/tests/backend/patches/sidecar_wrappers/test_linear_sidecar_wrapper.py +++ /dev/null @@ -1,182 +0,0 @@ -import copy - -import torch - -from invokeai.backend.patches.layers.concatenated_lora_layer import ConcatenatedLoRALayer -from invokeai.backend.patches.layers.flux_control_lora_layer import FluxControlLoRALayer -from invokeai.backend.patches.layers.full_layer import FullLayer -from invokeai.backend.patches.layers.lora_layer import LoRALayer -from invokeai.backend.patches.pad_with_zeros import pad_with_zeros -from invokeai.backend.patches.sidecar_wrappers.linear_sidecar_wrapper import LinearSidecarWrapper - - -@torch.no_grad() -def test_linear_sidecar_wrapper_lora(): - # Create a linear layer. - in_features = 10 - out_features = 20 - linear = torch.nn.Linear(in_features, out_features) - - # Create a LoRA layer. 
- rank = 4 - down = torch.randn(rank, in_features) - up = torch.randn(out_features, rank) - bias = torch.randn(out_features) - lora_layer = LoRALayer(up=up, mid=None, down=down, alpha=1.0, bias=bias) - - # Patch the LoRA layer into the linear layer. - linear_patched = copy.deepcopy(linear) - linear_patched.weight.data += lora_layer.get_weight(linear_patched.weight) * lora_layer.scale() - linear_patched.bias.data += lora_layer.get_bias(linear_patched.bias) * lora_layer.scale() - - # Create a LinearSidecarWrapper. - lora_wrapped = LinearSidecarWrapper(linear, [(lora_layer, 1.0)]) - - # Run the LoRA-patched linear layer and the LinearSidecarWrapper and assert they are equal. - input = torch.randn(1, in_features) - output_patched = linear_patched(input) - output_wrapped = lora_wrapped(input) - assert torch.allclose(output_patched, output_wrapped, atol=1e-6) - - -@torch.no_grad() -def test_linear_sidecar_wrapper_multiple_loras(): - # Create a linear layer. - in_features = 10 - out_features = 20 - linear = torch.nn.Linear(in_features, out_features) - - # Create two LoRA layers. - rank = 4 - lora_layer = LoRALayer( - up=torch.randn(out_features, rank), - mid=None, - down=torch.randn(rank, in_features), - alpha=1.0, - bias=torch.randn(out_features), - ) - lora_layer_2 = LoRALayer( - up=torch.randn(out_features, rank), - mid=None, - down=torch.randn(rank, in_features), - alpha=1.0, - bias=torch.randn(out_features), - ) - # We use different weights for the two LoRA layers to ensure this is working. - lora_weight = 1.0 - lora_weight_2 = 0.5 - - # Patch the LoRA layers into the linear layer. 
- linear_patched = copy.deepcopy(linear) - linear_patched.weight.data += lora_layer.get_weight(linear_patched.weight) * (lora_layer.scale() * lora_weight) - linear_patched.bias.data += lora_layer.get_bias(linear_patched.bias) * (lora_layer.scale() * lora_weight) - linear_patched.weight.data += lora_layer_2.get_weight(linear_patched.weight) * ( - lora_layer_2.scale() * lora_weight_2 - ) - linear_patched.bias.data += lora_layer_2.get_bias(linear_patched.bias) * (lora_layer_2.scale() * lora_weight_2) - - # Create a LinearSidecarWrapper. - lora_wrapped = LinearSidecarWrapper(linear, [(lora_layer, lora_weight), (lora_layer_2, lora_weight_2)]) - - # Run the LoRA-patched linear layer and the LinearSidecarWrapper and assert they are equal. - input = torch.randn(1, in_features) - output_patched = linear_patched(input) - output_wrapped = lora_wrapped(input) - assert torch.allclose(output_patched, output_wrapped, atol=1e-6) - - -@torch.no_grad() -def test_linear_sidecar_wrapper_concatenated_lora(): - # Create a linear layer. - in_features = 5 - sub_layer_out_features = [5, 10, 15] - linear = torch.nn.Linear(in_features, sum(sub_layer_out_features)) - - # Create a ConcatenatedLoRA layer. - rank = 4 - sub_layers: list[LoRALayer] = [] - for out_features in sub_layer_out_features: - down = torch.randn(rank, in_features) - up = torch.randn(out_features, rank) - bias = torch.randn(out_features) - sub_layers.append(LoRALayer(up=up, mid=None, down=down, alpha=1.0, bias=bias)) - concatenated_lora_layer = ConcatenatedLoRALayer(sub_layers, concat_axis=0) - - # Patch the ConcatenatedLoRA layer into the linear layer. - linear_patched = copy.deepcopy(linear) - linear_patched.weight.data += ( - concatenated_lora_layer.get_weight(linear_patched.weight) * concatenated_lora_layer.scale() - ) - linear_patched.bias.data += concatenated_lora_layer.get_bias(linear_patched.bias) * concatenated_lora_layer.scale() - - # Create a LinearSidecarWrapper. 
- lora_wrapped = LinearSidecarWrapper(linear, [(concatenated_lora_layer, 1.0)]) - - # Run the ConcatenatedLoRA-patched linear layer and the LinearSidecarWrapper and assert they are equal. - input = torch.randn(1, in_features) - output_patched = linear_patched(input) - output_wrapped = lora_wrapped(input) - assert torch.allclose(output_patched, output_wrapped, atol=1e-6) - - -def test_linear_sidecar_wrapper_full_layer(): - # Create a linear layer. - in_features = 10 - out_features = 20 - linear = torch.nn.Linear(in_features, out_features) - - # Create a FullLayer. - full_layer = FullLayer(weight=torch.randn(out_features, in_features), bias=torch.randn(out_features)) - - # Patch the FullLayer into the linear layer. - linear_patched = copy.deepcopy(linear) - linear_patched.weight.data += full_layer.get_weight(linear_patched.weight) - linear_patched.bias.data += full_layer.get_bias(linear_patched.bias) - - # Create a LinearSidecarWrapper. - full_wrapped = LinearSidecarWrapper(linear, [(full_layer, 1.0)]) - - # Run the FullLayer-patched linear layer and the LinearSidecarWrapper and assert they are equal. - input = torch.randn(1, in_features) - output_patched = linear_patched(input) - output_wrapped = full_wrapped(input) - assert torch.allclose(output_patched, output_wrapped, atol=1e-6) - - -def test_linear_sidecar_wrapper_flux_control_lora_layer(): - # Create a linear layer. - orig_in_features = 10 - out_features = 40 - linear = torch.nn.Linear(orig_in_features, out_features) - - # Create a FluxControlLoRALayer. - patched_in_features = 20 - rank = 4 - lora_layer = FluxControlLoRALayer( - up=torch.randn(out_features, rank), - mid=None, - down=torch.randn(rank, patched_in_features), - alpha=1.0, - bias=torch.randn(out_features), - ) - - # Patch the FluxControlLoRALayer into the linear layer. - linear_patched = copy.deepcopy(linear) - # Expand the existing weight. 
- expanded_weight = pad_with_zeros(linear_patched.weight, torch.Size([out_features, patched_in_features])) - linear_patched.weight = torch.nn.Parameter(expanded_weight, requires_grad=linear_patched.weight.requires_grad) - # Expand the existing bias. - expanded_bias = pad_with_zeros(linear_patched.bias, torch.Size([out_features])) - linear_patched.bias = torch.nn.Parameter(expanded_bias, requires_grad=linear_patched.bias.requires_grad) - # Add the residuals. - linear_patched.weight.data += lora_layer.get_weight(linear_patched.weight) * lora_layer.scale() - linear_patched.bias.data += lora_layer.get_bias(linear_patched.bias) * lora_layer.scale() - - # Create a LinearSidecarWrapper. - lora_wrapped = LinearSidecarWrapper(linear, [(lora_layer, 1.0)]) - - # Run the FluxControlLoRA-patched linear layer and the LinearSidecarWrapper and assert they are equal. - input = torch.randn(1, patched_in_features) - output_patched = linear_patched(input) - output_wrapped = lora_wrapped(input) - assert torch.allclose(output_patched, output_wrapped, atol=1e-6) From 8b4b0ff0cfbf3bcdd7193ca651209d7b65b2ead4 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Sun, 29 Dec 2024 19:00:24 +0000 Subject: [PATCH 29/31] Fix bug in CustomConv1d and CustomConv2d patch calculations. 
--- .../custom_modules/custom_conv1d.py | 10 +++++++--- .../custom_modules/custom_conv2d.py | 10 +++++++--- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_conv1d.py b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_conv1d.py index b59b5a2aae5..e65b3259246 100644 --- a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_conv1d.py +++ b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_conv1d.py @@ -4,6 +4,9 @@ from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import ( CustomModuleMixin, ) +from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.utils import ( + add_nullable_tensors, +) class CustomConv1d(torch.nn.Conv1d, CustomModuleMixin): @@ -21,9 +24,10 @@ def _autocast_forward_with_patches(self, input: torch.Tensor) -> torch.Tensor: orig_params=orig_params, device=input.device, ) - return self._conv_forward( - input, aggregated_param_residuals["weight"], aggregated_param_residuals.get("bias", None) - ) + + weight = add_nullable_tensors(weight, aggregated_param_residuals.get("weight", None)) + bias = add_nullable_tensors(bias, aggregated_param_residuals.get("bias", None)) + return self._conv_forward(input, weight, bias) def _autocast_forward(self, input: torch.Tensor) -> torch.Tensor: weight = cast_to_device(self.weight, input.device) diff --git a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_conv2d.py b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_conv2d.py index 1077b47ed5e..91f08fb96be 100644 --- a/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_conv2d.py +++ 
b/invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/custom_conv2d.py @@ -4,6 +4,9 @@ from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_module_mixin import ( CustomModuleMixin, ) +from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.utils import ( + add_nullable_tensors, +) class CustomConv2d(torch.nn.Conv2d, CustomModuleMixin): @@ -21,9 +24,10 @@ def _autocast_forward_with_patches(self, input: torch.Tensor) -> torch.Tensor: orig_params=orig_params, device=input.device, ) - return self._conv_forward( - input, aggregated_param_residuals["weight"], aggregated_param_residuals.get("bias", None) - ) + + weight = add_nullable_tensors(weight, aggregated_param_residuals.get("weight", None)) + bias = add_nullable_tensors(bias, aggregated_param_residuals.get("bias", None)) + return self._conv_forward(input, weight, bias) def _autocast_forward(self, input: torch.Tensor) -> torch.Tensor: weight = cast_to_device(self.weight, input.device) From 477d87ec31b2f8234f28cb9ab52f53b4214c0a10 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Sun, 29 Dec 2024 21:48:51 +0000 Subject: [PATCH 30/31] Fix layer patch dtype selection for CLIP text encoder models. 
--- invokeai/app/invocations/compel.py | 4 ++-- invokeai/app/invocations/flux_text_encoder.py | 3 +-- invokeai/app/invocations/sd3_text_encoder.py | 3 +-- 3 files changed, 4 insertions(+), 6 deletions(-) diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py index 92d7f4638c0..b535254cfd4 100644 --- a/invokeai/app/invocations/compel.py +++ b/invokeai/app/invocations/compel.py @@ -86,7 +86,7 @@ def _lora_loader() -> Iterator[Tuple[ModelPatchRaw, float]]: model=text_encoder, patches=_lora_loader(), prefix="lora_te_", - dtype=TorchDevice.choose_torch_dtype(), + dtype=text_encoder.dtype, cached_weights=cached_weights, ), # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers. @@ -184,7 +184,7 @@ def _lora_loader() -> Iterator[Tuple[ModelPatchRaw, float]]: model=text_encoder, patches=_lora_loader(), prefix=lora_prefix, - dtype=TorchDevice.choose_torch_dtype(), + dtype=text_encoder.dtype, cached_weights=cached_weights, ), # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers. 
diff --git a/invokeai/app/invocations/flux_text_encoder.py b/invokeai/app/invocations/flux_text_encoder.py index c3a752ab30d..3f1f38c4a1f 100644 --- a/invokeai/app/invocations/flux_text_encoder.py +++ b/invokeai/app/invocations/flux_text_encoder.py @@ -22,7 +22,6 @@ from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_CLIP_PREFIX from invokeai.backend.patches.model_patch_raw import ModelPatchRaw from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, FLUXConditioningInfo -from invokeai.backend.util.devices import TorchDevice @invocation( @@ -116,7 +115,7 @@ def _clip_encode(self, context: InvocationContext) -> torch.Tensor: model=clip_text_encoder, patches=self._clip_lora_iterator(context), prefix=FLUX_LORA_CLIP_PREFIX, - dtype=TorchDevice.choose_torch_dtype(), + dtype=clip_text_encoder.dtype, cached_weights=cached_weights, ) ) diff --git a/invokeai/app/invocations/sd3_text_encoder.py b/invokeai/app/invocations/sd3_text_encoder.py index 9103dbbb41f..6569fa0a762 100644 --- a/invokeai/app/invocations/sd3_text_encoder.py +++ b/invokeai/app/invocations/sd3_text_encoder.py @@ -21,7 +21,6 @@ from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_CLIP_PREFIX from invokeai.backend.patches.model_patch_raw import ModelPatchRaw from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, SD3ConditioningInfo -from invokeai.backend.util.devices import TorchDevice # The SD3 T5 Max Sequence Length set based on the default in diffusers. 
SD3_T5_MAX_SEQ_LEN = 256 @@ -155,7 +154,7 @@ def _clip_encode( model=clip_text_encoder, patches=self._clip_lora_iterator(context, clip_model), prefix=FLUX_LORA_CLIP_PREFIX, - dtype=TorchDevice.choose_torch_dtype(), + dtype=clip_text_encoder.dtype, cached_weights=cached_weights, ) ) From 9a0a226ce11015f23c4effb4e48b55e72c32b2f8 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Mon, 30 Dec 2024 10:41:48 -0500 Subject: [PATCH 31/31] Fix bitsandbytes imports in unit tests on MacOS. --- .../custom_modules/test_custom_invoke_linear_nf4.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_custom_invoke_linear_nf4.py b/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_custom_invoke_linear_nf4.py index 3559ddea6cb..f97404fb949 100644 --- a/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_custom_invoke_linear_nf4.py +++ b/tests/backend/model_manager/load/model_cache/torch_module_autocast/custom_modules/test_custom_invoke_linear_nf4.py @@ -1,13 +1,17 @@ import pytest import torch -from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_invoke_linear_nf4 import ( - CustomInvokeLinearNF4, -) from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.torch_module_autocast import ( wrap_custom_layer, ) -from invokeai.backend.quantization.bnb_nf4 import InvokeLinearNF4 + +if not torch.cuda.is_available(): + pytest.skip("CUDA is not available", allow_module_level=True) +else: + from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_modules.custom_invoke_linear_nf4 import ( + CustomInvokeLinearNF4, + ) + from invokeai.backend.quantization.bnb_nf4 import InvokeLinearNF4 def build_linear_nf4_layer(orig_layer: torch.nn.Linear | None = None):