diff --git a/onediff_comfy_nodes/_nodes.py b/onediff_comfy_nodes/_nodes.py
index ccf695913..dda99a4d3 100644
--- a/onediff_comfy_nodes/_nodes.py
+++ b/onediff_comfy_nodes/_nodes.py
@@ -92,6 +92,9 @@ def INPUT_TYPES(s):
         }
 
     RETURN_TYPES = ("VAE",)
+
+    def speedup(self, vae, inplace=False, custom_booster: BoosterScheduler = None):
+        return super().speedup(vae, inplace, custom_booster)
 
 
 class ControlnetSpeedup:
diff --git a/onediff_comfy_nodes/modules/booster_cache.py b/onediff_comfy_nodes/modules/booster_cache.py
index 392dc5f41..f36fe47d3 100644
--- a/onediff_comfy_nodes/modules/booster_cache.py
+++ b/onediff_comfy_nodes/modules/booster_cache.py
@@ -1,22 +1,57 @@
 import torch
 import traceback
 from collections import OrderedDict
+from functools import singledispatch
 
+from comfy.controlnet import ControlLora, ControlNet
 from comfy.model_patcher import ModelPatcher
 from comfy.sd import VAE
 from onediff.torch_utils.module_operations import get_sub_module
 from onediff.utils.import_utils import is_oneflow_available
+from .._config import is_disable_oneflow_backend
 
-if is_oneflow_available():
+
+if not is_disable_oneflow_backend() and is_oneflow_available():
     from .oneflow.utils.booster_utils import is_using_oneflow_backend
 
 
-def switch_to_cached_model(new_model: ModelPatcher, cache_model):
-    assert type(new_model.model) == type(cache_model)
-    for k, v in new_model.model.state_dict().items():
+@singledispatch
+def get_target_model(model):
+    raise NotImplementedError(f"{type(model)=} cache is not supported.")
+
+
+@get_target_model.register(ModelPatcher)
+def _(model):
+    return model.model
+
+
+@get_target_model.register(VAE)
+def _(model):
+    return model.first_stage_model
+
+
+@get_target_model.register(ControlNet)
+def _(model):
+    return model.control_model
+
+
+@get_target_model.register(ControlLora)
+def _(model):
+    return model
+
+
+def switch_to_cached_model(new_model, cache_model):
+    target_model = get_target_model(new_model)
+    assert type(target_model) == type(cache_model)
+    for k, v in target_model.state_dict().items():
         cached_v: torch.Tensor = get_sub_module(cache_model, k)
         assert v.dtype == cached_v.dtype
         cached_v.copy_(v)
-    new_model.model = cache_model
+    if isinstance(new_model, ModelPatcher):
+        new_model.model = cache_model
+    elif isinstance(new_model, VAE):
+        new_model.first_stage_model = cache_model
+    elif isinstance(new_model, ControlNet):
+        new_model.control_model = cache_model
+    elif isinstance(new_model, ControlLora):
+        new_model = cache_model
+    else:
+        raise NotImplementedError(f"{type(new_model)=} cache is not supported.")
     return new_model
 
 
@@ -27,9 +62,9 @@ def put(self, key, model):
         if key is None:
             return
         # oneflow backends output image error
-        if is_oneflow_available() and is_using_oneflow_backend(model):
+        if not is_disable_oneflow_backend() and is_oneflow_available() and is_using_oneflow_backend(model):
            return
-        self._cache[key] = model.model
+        self._cache[key] = get_target_model(model)
 
     def get(self, key, default=None):
         return self._cache.get(key, default)
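
Note on the pattern: `get_target_model` uses `functools.singledispatch` to select the wrapped inner module by the runtime type of its argument, so the wrapper-specific attribute lookups (`model`, `first_stage_model`, `control_model`) stay out of the call sites in `BoosterCacheService`. Below is a minimal, self-contained sketch of the same dispatch pattern; `Patcher` and `Vae` are hypothetical stand-ins for ComfyUI's `ModelPatcher` and `VAE`, used only so the snippet runs without ComfyUI installed.

```python
from functools import singledispatch


class Patcher:  # hypothetical stand-in for comfy.model_patcher.ModelPatcher
    def __init__(self, model):
        self.model = model


class Vae:  # hypothetical stand-in for comfy.sd.VAE
    def __init__(self, first_stage_model):
        self.first_stage_model = first_stage_model


@singledispatch
def get_target_model(model):
    # Fallback for wrapper types with no registered overload.
    raise NotImplementedError(f"{type(model)=} cache is not supported.")


@get_target_model.register(Patcher)
def _(model):
    return model.model


@get_target_model.register(Vae)
def _(model):
    return model.first_stage_model


print(get_target_model(Patcher("unet")))  # -> unet
print(get_target_model(Vae("decoder")))   # -> decoder
```

One caveat visible in the patch itself: the `ControlLora` branch of `switch_to_cached_model` rebinds the local `new_model` rather than mutating an attribute, so callers must use the returned object instead of relying on in-place modification.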