From 057a91c265263c708789f877bf3b6f76ff007bdc Mon Sep 17 00:00:00 2001
From: nono-Sang <1908865287@qq.com>
Date: Mon, 17 Jun 2024 13:25:54 +0800
Subject: [PATCH 1/2] Fix VaeSpeedup node and Support multiple model cache

---
 onediff_comfy_nodes/_nodes.py                |  3 ++
 onediff_comfy_nodes/modules/booster_cache.py | 40 ++++++++++++++++----
 2 files changed, 36 insertions(+), 7 deletions(-)

diff --git a/onediff_comfy_nodes/_nodes.py b/onediff_comfy_nodes/_nodes.py
index ccf695913..dda99a4d3 100644
--- a/onediff_comfy_nodes/_nodes.py
+++ b/onediff_comfy_nodes/_nodes.py
@@ -92,6 +92,9 @@ def INPUT_TYPES(s):
         }
 
     RETURN_TYPES = ("VAE",)
+
+    def speedup(self, vae, inplace=False, custom_booster: BoosterScheduler = None):
+        return super().speedup(vae, inplace, custom_booster)
 
 
 class ControlnetSpeedup:
diff --git a/onediff_comfy_nodes/modules/booster_cache.py b/onediff_comfy_nodes/modules/booster_cache.py
index 392dc5f41..62c095e36 100644
--- a/onediff_comfy_nodes/modules/booster_cache.py
+++ b/onediff_comfy_nodes/modules/booster_cache.py
@@ -1,22 +1,48 @@
 import torch
 import traceback
 from collections import OrderedDict
+from functools import singledispatch
+from comfy.controlnet import ControlLora, ControlNet
 from comfy.model_patcher import ModelPatcher
 from comfy.sd import VAE
 from onediff.torch_utils.module_operations import get_sub_module
 from onediff.utils.import_utils import is_oneflow_available
+from .._config import is_disable_oneflow_backend
 
-if is_oneflow_available():
+
+if not is_disable_oneflow_backend() and is_oneflow_available():
     from .oneflow.utils.booster_utils import is_using_oneflow_backend
 
 
-def switch_to_cached_model(new_model: ModelPatcher, cache_model):
-    assert type(new_model.model) == type(cache_model)
-    for k, v in new_model.model.state_dict().items():
+@singledispatch
+def get_target_model(model):
+    raise NotImplementedError(f"{type(model)=} cache is not supported.")
+
+@get_target_model.register(ModelPatcher)
+def _(model):
+    return model.model
+
+@get_target_model.register(VAE)
+def _(model):
+    return model.first_stage_model
+
+@get_target_model.register(ControlNet)
+def _(model):
+    return model.control_model
+
+@get_target_model.register(ControlLora)
+def _(model):
+    return model
+
+
+def switch_to_cached_model(new_model, cache_model):
+    target_model = get_target_model(new_model)
+    assert type(target_model) == type(cache_model)
+    for k, v in target_model.state_dict().items():
         cached_v: torch.Tensor = get_sub_module(cache_model, k)
         assert v.dtype == cached_v.dtype
         cached_v.copy_(v)
-    new_model.model = cache_model
+    target_model = cache_model
     return new_model
 
 
@@ -27,9 +53,9 @@ def put(self, key, model):
         if key is None:
             return
         # oneflow backends output image error
-        if is_oneflow_available() and is_using_oneflow_backend(model):
+        if not is_disable_oneflow_backend() and is_oneflow_available() and is_using_oneflow_backend(model):
             return
-        self._cache[key] = model.model
+        self._cache[key] = get_target_model(model)
 
     def get(self, key, default=None):
         return self._cache.get(key, default)

From f22dbbe7ab42ff15a6b1fa47be6d96a6a5cd44a5 Mon Sep 17 00:00:00 2001
From: nono-Sang <1908865287@qq.com>
Date: Mon, 17 Jun 2024 14:35:15 +0800
Subject: [PATCH 2/2] fix model switch

---
 onediff_comfy_nodes/modules/booster_cache.py | 11 ++++++++++-
 1 file changed, 10 insertions(+), 1 deletion(-)

diff --git a/onediff_comfy_nodes/modules/booster_cache.py b/onediff_comfy_nodes/modules/booster_cache.py
index 62c095e36..f36fe47d3 100644
--- a/onediff_comfy_nodes/modules/booster_cache.py
+++ b/onediff_comfy_nodes/modules/booster_cache.py
@@ -42,7 +42,16 @@ def switch_to_cached_model(new_model, cache_model):
         cached_v: torch.Tensor = get_sub_module(cache_model, k)
         assert v.dtype == cached_v.dtype
         cached_v.copy_(v)
-    target_model = cache_model
+    if isinstance(new_model, ModelPatcher):
+        new_model.model = cache_model
+    elif isinstance(new_model, VAE):
+        new_model.first_stage_model = cache_model
+    elif isinstance(new_model, ControlNet):
+        new_model.control_model = cache_model
+    elif isinstance(new_model, ControlLora):
+        new_model = cache_model
+    else:
+        raise NotImplementedError(f"{type(new_model)=} cache is not supported.")
     return new_model
 
 