diff --git a/modules/stage_apply_loras.py b/modules/stage_apply_loras.py index 6f7352e..b009d62 100644 --- a/modules/stage_apply_loras.py +++ b/modules/stage_apply_loras.py @@ -57,15 +57,21 @@ def process(self, data, stage_input): base_model = access.get_from_pipeline(Names.P_BASE_MODEL) base_clip = access.get_from_pipeline(Names.P_BASE_CLIP) + refiner_model = access.get_from_pipeline(Names.P_REFINER_MODEL) + refiner_clip = access.get_from_pipeline(Names.P_REFINER_CLIP) base_model_changed = access.changed_in_pipeline(Names.P_BASE_MODEL) base_clip_changed = access.changed_in_pipeline(Names.P_BASE_CLIP) + refiner_model_changed = access.changed_in_pipeline(Names.P_REFINER_MODEL) + refiner_clip_changed = access.changed_in_pipeline(Names.P_REFINER_CLIP) lora_stack = access.get_active_setting(UI.S_LORAS, UI.F_LORA_STACK, []) any_changes = ( base_model_changed or - base_clip_changed + base_clip_changed or + refiner_model_changed or + refiner_clip_changed ) applied_loras = [] @@ -82,15 +88,22 @@ def process(self, data, stage_input): if lora_name is not None and lora_name != UI.NONE and lora_strength != 0.0: (base_model, base_clip) = NodeWrapper.lora_loader.load_lora(base_model, base_clip, lora_name, lora_strength, lora_strength) + if refiner_model is not None and refiner_clip is not None: + (refiner_model, refiner_clip) = NodeWrapper.lora_loader.load_lora(refiner_model, refiner_clip, lora_name, + lora_strength, lora_strength) applied_loras.append(lora_name) - access.update_in_cache(Names.C_APPLIED_LORAS, lora_stack, (base_model, base_clip)) + access.update_in_cache(Names.C_APPLIED_LORAS, lora_stack, (base_model, base_clip, refiner_model, refiner_clip)) access.update_in_pipeline(Names.P_BASE_MODEL, base_model) access.update_in_pipeline(Names.P_BASE_CLIP, base_clip) + access.update_in_pipeline(Names.P_REFINER_MODEL, refiner_model) + access.update_in_pipeline(Names.P_REFINER_CLIP, refiner_clip) else: - (base_model, base_clip) = access.get_from_cache(Names.C_APPLIED_LORAS) + (base_model, base_clip, refiner_model, refiner_clip) = access.get_from_cache(Names.C_APPLIED_LORAS) access.restore_in_pipeline(Names.P_BASE_MODEL, base_model) access.restore_in_pipeline(Names.P_BASE_CLIP, base_clip) + access.restore_in_pipeline(Names.P_REFINER_MODEL, refiner_model) + access.restore_in_pipeline(Names.P_REFINER_CLIP, refiner_clip) loaded_loras = { Names.F_LORA_NAMES: applied_loras,