From 0c8044070a7de4f73b11c6cd143b28282aa3e939 Mon Sep 17 00:00:00 2001 From: Vladimir Mandic Date: Mon, 13 Jan 2025 13:00:30 -0500 Subject: [PATCH] refactor: split legacy loaders Signed-off-by: Vladimir Mandic --- CHANGELOG.md | 7 +- extensions-builtin/sd-webui-agent-scheduler | 2 +- installer.py | 5 + modules/cmd_args.py | 1 - modules/control/run.py | 75 ++-- modules/control/units/controlnet.py | 3 +- modules/processing_original.py | 13 + modules/sd_detect.py | 6 - modules/sd_models.py | 399 ++------------------ modules/sd_models_legacy.py | 207 ++++++++++ modules/sd_models_utils.py | 151 ++++++++ modules/sd_offload.py | 25 +- modules/shared.py | 15 +- modules/ui_control_helpers.py | 12 +- scripts/pulid_ext.py | 15 +- webui.py | 127 +++---- wiki | 2 +- 17 files changed, 543 insertions(+), 522 deletions(-) create mode 100644 modules/sd_models_legacy.py create mode 100644 modules/sd_models_utils.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 60e21ff6d..216f06d3f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,7 +22,8 @@ - refactored progress monitoring, job updates and live preview - improved metadata save and restore - startup tracing and optimizations - - threading load locks on model loads + - threading load locks on model loads + - refactor native vs legacy model loader - **Schedulers**: - [TDD](https://github.com/RedAIGC/Target-Driven-Distillation) new super-fast scheduler that can generate images in 4-8 steps recommended to use with [TDD LoRA](https://huggingface.co/RED-AIGC/TDD/tree/main) @@ -40,7 +41,7 @@ - **XYZ Grid**: add prompt search&replace options: *primary, refine, detailer, all* - **SysInfo**: update to collected data and benchmarks - [Wiki/Docs](https://vladmandic.github.io/sdnext-docs/): - - updated: Detailer, Install, Debug, Control-HowTo, ZLUDA + - updated: Detailer, Install, Update, Debug, Control-HowTo, ZLUDA - **Fixes**: - explict clear caches on model load - lock adetailer commit: `#a89c01d` @@ -61,6 +62,8 @@ - restore args after batch run - flux controlnet - zluda installer + - control inherit parent pipe settings + - control logging ## Update for 2024-12-31 diff --git a/extensions-builtin/sd-webui-agent-scheduler b/extensions-builtin/sd-webui-agent-scheduler index cd878626f..a33753321 160000 --- a/extensions-builtin/sd-webui-agent-scheduler +++ b/extensions-builtin/sd-webui-agent-scheduler @@ -1 +1 @@ -Subproject commit cd878626f3b4f9a0c7c45c7d70b73a6168f612a4 +Subproject commit a33753321b914c6122df96d1dc0b5117d38af680 diff --git a/installer.py b/installer.py index 9ad03a862..49bcf8b77 100644 --- a/installer.py +++ b/installer.py @@ -150,6 +150,11 @@ def get(self): log.addHandler(rb) log.buffer = rb.buffer + def quiet_log(quiet: bool=False, *args, **kwargs): # pylint: disable=redefined-outer-name,keyword-arg-before-vararg + if not quiet: + log.debug(*args, **kwargs) + log.quiet = quiet_log + # overrides logging.getLogger("urllib3").setLevel(logging.ERROR) logging.getLogger("httpx").setLevel(logging.ERROR) diff --git a/modules/cmd_args.py b/modules/cmd_args.py index 54b888a17..755c31c97 100644 --- a/modules/cmd_args.py +++ b/modules/cmd_args.py @@ -78,7 +78,6 @@ def compatibility_args(): group_compat.add_argument("--disable-queue", default=os.environ.get("SD_DISABLEQUEUE", False), action='store_true', help=argparse.SUPPRESS) - def settings_args(opts, args): # removed args are added here as hidden in fixed format for compatbility reasons group_compat = parser.add_argument_group('Compatibility options') diff --git a/modules/control/run.py 
b/modules/control/run.py index 03c13240e..21fc76a4b 100644 --- a/modules/control/run.py +++ b/modules/control/run.py @@ -17,10 +17,11 @@ from modules.processing_class import StableDiffusionProcessingControl from modules.ui_common import infotext_to_html from modules.api import script +from modules.timer import process as process_timer -debug = shared.log.trace if os.environ.get('SD_CONTROL_DEBUG', None) is not None else lambda *args, **kwargs: None -debug('Trace: CONTROL') +debug = os.environ.get('SD_CONTROL_DEBUG', None) is not None +debug_log = shared.log.trace if debug else lambda *args, **kwargs: None pipe = None instance = None original_pipeline = None @@ -32,7 +33,7 @@ def restore_pipeline(): if instance is not None and hasattr(instance, 'restore'): instance.restore() if original_pipeline is not None and (original_pipeline.__class__.__name__ != shared.sd_model.__class__.__name__): - debug(f'Control restored pipeline: class={shared.sd_model.__class__.__name__} to={original_pipeline.__class__.__name__}') + debug_log(f'Control restored pipeline: class={shared.sd_model.__class__.__name__} to={original_pipeline.__class__.__name__}') shared.sd_model = original_pipeline pipe = None instance = None @@ -109,7 +110,7 @@ def set_pipe(p, has_models, unit_type, selected_models, active_model, active_str p.strength = active_strength[0] pipe = shared.sd_model instance = None - debug(f'Control: run type={unit_type} models={has_models} pipe={pipe.__class__.__name__ if pipe is not None else None}') + debug_log(f'Control: run type={unit_type} models={has_models} pipe={pipe.__class__.__name__ if pipe is not None else None}') return pipe @@ -124,14 +125,14 @@ def check_active(p, unit_type, units): if u.type != unit_type: continue num_units += 1 - debug(f'Control unit: i={num_units} type={u.type} enabled={u.enabled}') + debug_log(f'Control unit: i={num_units} type={u.type} enabled={u.enabled}') if not u.enabled: if u.controlnet is not None and u.controlnet.model is not None: - debug(f'Control unit offload: model="{u.controlnet.model_id}" device={devices.cpu}') + debug_log(f'Control unit offload: model="{u.controlnet.model_id}" device={devices.cpu}') sd_models.move_model(u.controlnet.model, devices.cpu) continue if u.controlnet is not None and u.controlnet.model is not None: - debug(f'Control unit offload: model="{u.controlnet.model_id}" device={devices.device}') + debug_log(f'Control unit offload: model="{u.controlnet.model_id}" device={devices.device}') sd_models.move_model(u.controlnet.model, devices.device) if unit_type == 't2i adapter' and u.adapter.model is not None: active_process.append(u.process) @@ -176,7 +177,7 @@ def check_active(p, unit_type, units): active_process.append(u.process) shared.log.debug(f'Control process unit: i={num_units} process={u.process.processor_id}') active_strength.append(float(u.strength)) - debug(f'Control active: process={len(active_process)} model={len(active_model)}') + debug_log(f'Control active: process={len(active_process)} model={len(active_model)}') return active_process, active_model, active_strength, active_start, active_end @@ -213,7 +214,7 @@ def control_set(kwargs): if kwargs: global p_extra_args # pylint: disable=global-statement p_extra_args = {} - debug(f'Control extra args: {kwargs}') + debug_log(f'Control extra args: {kwargs}') for k, v in kwargs.items(): p_extra_args[k] = v @@ -254,7 +255,7 @@ def control_run(state: str = '', u.process.override = u.override global pipe, original_pipeline # pylint: disable=global-statement - debug(f'Control: 
type={unit_type} input={inputs} init={inits} type={input_type}') + debug_log(f'Control: type={unit_type} input={inputs} init={inits} type={input_type}') if inputs is None or (type(inputs) is list and len(inputs) == 0): inputs = [None] output_images: List[Image.Image] = [] # output images @@ -402,7 +403,7 @@ def control_run(state: str = '', p.is_tile = p.is_tile and has_models pipe = set_pipe(p, has_models, unit_type, selected_models, active_model, active_strength, control_conditioning, control_guidance_start, control_guidance_end, inits) - debug(f'Control pipeline: class={pipe.__class__.__name__} args={vars(p)}') + debug_log(f'Control pipeline: class={pipe.__class__.__name__} args={vars(p)}') t1, t2, t3 = time.time(), 0, 0 status = True frame = None @@ -420,7 +421,7 @@ def control_run(state: str = '', shared.sd_model = pipe sd_models.move_model(shared.sd_model, shared.device) shared.sd_model.to(dtype=devices.dtype) - debug(f'Control device={devices.device} dtype={devices.dtype}') + debug_log(f'Control device={devices.device} dtype={devices.dtype}') sd_models.copy_diffuser_options(shared.sd_model, original_pipeline) # copy options from original pipeline sd_models.set_diffuser_options(shared.sd_model) else: @@ -458,12 +459,12 @@ def control_run(state: str = '', while status: if pipe is None: # pipe may have been reset externally pipe = set_pipe(p, has_models, unit_type, selected_models, active_model, active_strength, control_conditioning, control_guidance_start, control_guidance_end, inits) - debug(f'Control pipeline reinit: class={pipe.__class__.__name__}') + debug_log(f'Control pipeline reinit: class={pipe.__class__.__name__}') processed_image = None if frame is not None: inputs = [Image.fromarray(frame)] # cv2 to pil for i, input_image in enumerate(inputs): - debug(f'Control Control image: {i + 1} of {len(inputs)}') + debug_log(f'Control Control image: {i + 1} of {len(inputs)}') if shared.state.skipped: shared.state.skipped = False continue @@ -481,20 +482,20 @@ def control_run(state: str = '', continue # match init input if input_type == 1: - debug('Control Init image: same as control') + debug_log('Control Init image: same as control') init_image = input_image elif inits is None: - debug('Control Init image: none') + debug_log('Control Init image: none') init_image = None elif isinstance(inits[i], str): - debug(f'Control: init image: {inits[i]}') + debug_log(f'Control: init image: {inits[i]}') try: init_image = Image.open(inits[i]) except Exception as e: shared.log.error(f'Control: image open failed: path={inits[i]} type=init error={e}') continue else: - debug(f'Control Init image: {i % len(inits) + 1} of {len(inits)}') + debug_log(f'Control Init image: {i % len(inits) + 1} of {len(inits)}') init_image = inits[i % len(inits)] if video is not None and index % (video_skip_frames + 1) != 0: index += 1 @@ -507,18 +508,18 @@ def control_run(state: str = '', width_before, height_before = int(input_image.width * scale_by_before), int(input_image.height * scale_by_before) if input_image is not None: p.extra_generation_params["Control resize"] = f'{resize_name_before}' - debug(f'Control resize: op=before image={input_image} width={width_before} height={height_before} mode={resize_mode_before} name={resize_name_before} context="{resize_context_before}"') + debug_log(f'Control resize: op=before image={input_image} width={width_before} height={height_before} mode={resize_mode_before} name={resize_name_before} context="{resize_context_before}"') input_image = images.resize_image(resize_mode_before, 
input_image, width_before, height_before, resize_name_before, context=resize_context_before) if input_image is not None and init_image is not None and init_image.size != input_image.size: - debug(f'Control resize init: image={init_image} target={input_image}') + debug_log(f'Control resize init: image={init_image} target={input_image}') init_image = images.resize_image(resize_mode=1, im=init_image, width=input_image.width, height=input_image.height) if input_image is not None and p.override is not None and p.override.size != input_image.size: - debug(f'Control resize override: image={p.override} target={input_image}') + debug_log(f'Control resize override: image={p.override} target={input_image}') p.override = images.resize_image(resize_mode=1, im=p.override, width=input_image.width, height=input_image.height) if input_image is not None: p.width = input_image.width p.height = input_image.height - debug(f'Control: input image={input_image}') + debug_log(f'Control: input image={input_image}') processed_images = [] if mask is not None: @@ -533,7 +534,7 @@ def control_run(state: str = '', else: masked_image = input_image for i, process in enumerate(active_process): # list[image] - debug(f'Control: i={i+1} process="{process.processor_id}" input={masked_image} override={process.override}') + debug_log(f'Control: i={i+1} process="{process.processor_id}" input={masked_image} override={process.override}') processed_image = process( image_input=masked_image, mode='RGB', @@ -548,7 +549,7 @@ def control_run(state: str = '', processors.config[process.processor_id]['dirty'] = True # to force reload process.model = None - debug(f'Control processed: {len(processed_images)}') + debug_log(f'Control processed: {len(processed_images)}') if len(processed_images) > 0: try: if len(p.extra_generation_params["Control process"]) == 0: @@ -574,7 +575,7 @@ def control_run(state: str = '', blended_image = util.blend(blended_image) # blend all processed images into one blended_image = Image.fromarray(blended_image) if isinstance(selected_models, list) and len(processed_images) == len(selected_models): - debug(f'Control: inputs match: input={len(processed_images)} models={len(selected_models)}') + debug_log(f'Control: inputs match: input={len(processed_images)} models={len(selected_models)}') p.init_images = processed_images elif isinstance(selected_models, list) and len(processed_images) != len(selected_models): if is_generator: @@ -583,14 +584,14 @@ def control_run(state: str = '', elif selected_models is not None: p.init_images = processed_image else: - debug('Control processed: using input direct') + debug_log('Control processed: using input direct') processed_image = input_image if unit_type == 'reference' and has_models: p.ref_image = p.override or input_image p.task_args.pop('image', None) p.task_args['ref_image'] = p.ref_image - debug(f'Control: process=None image={p.ref_image}') + debug_log(f'Control: process=None image={p.ref_image}') if p.ref_image is None: if is_generator: yield terminate('Attempting reference mode but image is none') @@ -625,7 +626,7 @@ def control_run(state: str = '', if is_generator: image_txt = f'{blended_image.width}x{blended_image.height}' if blended_image is not None else 'None' msg = f'process | {index} of {frames if video is not None else len(inputs)} | {"Image" if video is None else "Frame"} {image_txt}' - debug(f'Control yield: {msg}') + debug_log(f'Control yield: {msg}') if is_generator: yield (None, blended_image, f'Control {msg}') t2 += time.time() - t2 @@ -684,7 +685,7 @@ def 
control_run(state: str = '', if selected_scale_tab_mask == 1: width_mask, height_mask = int(input_image.width * scale_by_mask), int(input_image.height * scale_by_mask) p.width, p.height = width_mask, height_mask - debug(f'Control resize: op=mask image={mask} width={width_mask} height={height_mask} mode={resize_mode_mask} name={resize_name_mask} context="{resize_context_mask}"') + debug_log(f'Control resize: op=mask image={mask} width={width_mask} height={height_mask} mode={resize_mode_mask} name={resize_name_mask} context="{resize_context_mask}"') # pipeline output = None @@ -693,9 +694,9 @@ def control_run(state: str = '', if not hasattr(pipe, 'restore_pipeline') and video is None: pipe.restore_pipeline = restore_pipeline shared.sd_model.restore_pipeline = restore_pipeline - debug(f'Control exec pipeline: task={sd_models.get_diffusers_task(pipe)} class={pipe.__class__}') - # debug(f'Control exec pipeline: p={vars(p)}') - # debug(f'Control exec pipeline: args={p.task_args} image={p.task_args.get("image", None)} control={p.task_args.get("control_image", None)} mask={p.task_args.get("mask_image", None) or p.image_mask} ref={p.task_args.get("ref_image", None)}') + debug_log(f'Control exec pipeline: task={sd_models.get_diffusers_task(pipe)} class={pipe.__class__}') + # debug_log(f'Control exec pipeline: p={vars(p)}') + # debug_log(f'Control exec pipeline: args={p.task_args} image={p.task_args.get("image", None)} control={p.task_args.get("control_image", None)} mask={p.task_args.get("mask_image", None) or p.image_mask} ref={p.task_args.get("ref_image", None)}') if sd_models.get_diffusers_task(pipe) != sd_models.DiffusersTaskType.TEXT_2_IMAGE: # force vae back to gpu if not in txt2img mode sd_models.move_model(pipe.vae, devices.device) @@ -741,7 +742,7 @@ def control_run(state: str = '', width_after = int(output_image.width * scale_by_after) height_after = int(output_image.height * scale_by_after) if resize_mode_after != 0 and resize_name_after != 'None' and not is_grid: - debug(f'Control resize: op=after image={output_image} width={width_after} height={height_after} mode={resize_mode_after} name={resize_name_after} context="{resize_context_after}"') + debug_log(f'Control resize: op=after image={output_image} width={width_after} height={height_after} mode={resize_mode_after} name={resize_name_after} context="{resize_context_after}"') output_image = images.resize_image(resize_mode_after, output_image, width_after, height_after, resize_name_after, context=resize_context_after) output_images.append(output_image) @@ -761,14 +762,16 @@ def control_run(state: str = '', status, frame = video.read() if status: frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) - debug(f'Control: video frame={index} frames={frames} status={status} skip={index % (video_skip_frames + 1)} progress={index/frames:.2f}') + debug_log(f'Control: video frame={index} frames={frames} status={status} skip={index % (video_skip_frames + 1)} progress={index/frames:.2f}') else: status = False if video is not None: video.release() - shared.log.info(f'Control: pipeline units={len(active_model)} process={len(active_process)} time={t3-t0:.2f} init={t1-t0:.2f} proc={t2-t1:.2f} ctrl={t3-t2:.2f} outputs={len(output_images)}') + debug_log(f'Control: pipeline units={len(active_model)} process={len(active_process)} time={t3-t0:.2f} init={t1-t0:.2f} proc={t2-t1:.2f} ctrl={t3-t2:.2f} outputs={len(output_images)}') + process_timer.add('init', t1-t0) + process_timer.add('proc', t2-t1) except Exception as e: shared.log.error(f'Control pipeline failed: 
type={unit_type} units={len(active_model)} error={e}') errors.display(e, 'Control') @@ -789,7 +792,7 @@ def control_run(state: str = '', p.close() restore_pipeline() - debug(f'Ready: {image_txt}') + debug_log(f'Ready: {image_txt}') html_txt = f'
<p>Ready {image_txt}</p>
' if image_txt != '' else '' if len(info_txt) > 0: diff --git a/modules/control/units/controlnet.py b/modules/control/units/controlnet.py index 4390461ee..4837577fe 100644 --- a/modules/control/units/controlnet.py +++ b/modules/control/units/controlnet.py @@ -411,13 +411,14 @@ def __init__(self, if dtype is not None: self.pipeline = self.pipeline.to(dtype) + sd_models.copy_diffuser_options(self.pipeline, pipeline) if opts.diffusers_offload_mode == 'none': sd_models.move_model(self.pipeline, devices.device) from modules.sd_models import set_diffuser_offload set_diffuser_offload(self.pipeline, 'model') t1 = time.time() - log.debug(f'Control {what} pipeline: class={self.pipeline.__class__.__name__} time={t1-t0:.2f}') + debug_log(f'Control {what} pipeline: class={self.pipeline.__class__.__name__} time={t1-t0:.2f}') def restore(self): self.pipeline.unload_lora_weights() diff --git a/modules/processing_original.py b/modules/processing_original.py index 5ea255d72..7a0af1b04 100644 --- a/modules/processing_original.py +++ b/modules/processing_original.py @@ -27,6 +27,18 @@ def get_conds_with_caching(function, required_prompts, steps, cache): cache[0] = (required_prompts, steps) return cache[1] +def check_rollback_vae(): + if shared.cmd_opts.rollback_vae: + if not torch.cuda.is_available(): + shared.log.error("Rollback VAE functionality requires compatible GPU") + shared.cmd_opts.rollback_vae = False + elif torch.__version__.startswith('1.') or torch.__version__.startswith('2.0'): + shared.log.error("Rollback VAE functionality requires Torch 2.1 or higher") + shared.cmd_opts.rollback_vae = False + elif 0 < torch.cuda.get_device_capability()[0] < 8: + shared.log.error('Rollback VAE functionality device capabilities not met') + shared.cmd_opts.rollback_vae = False + def process_original(p: processing.StableDiffusionProcessing): cached_uc = [None, None] @@ -42,6 +54,7 @@ def process_original(p: processing.StableDiffusionProcessing): for x in x_samples_ddim: devices.test_for_nans(x, "vae") except devices.NansException as e: + check_rollback_vae() if not shared.opts.no_half and not shared.opts.no_half_vae and shared.cmd_opts.rollback_vae: shared.log.warning('Tensor with all NaNs was produced in VAE') devices.dtype_vae = torch.bfloat16 diff --git a/modules/sd_detect.py b/modules/sd_detect.py index 15b22c69c..514517c8d 100644 --- a/modules/sd_detect.py +++ b/modules/sd_detect.py @@ -49,12 +49,6 @@ def detect_pipeline(f: str, op: str = 'model', warning=True, quiet=False): elif (size > 20000 and size < 40000): guess = 'FLUX' # guess by name - """ - if 'LCM_' in f.upper() or 'LCM-' in f.upper() or '_LCM' in f.upper() or '-LCM' in f.upper(): - if shared.backend == shared.Backend.ORIGINAL: - warn(f'Model detected as LCM model, but attempting to load using backend=original: {op}={f} size={size} MB') - guess = 'Latent Consistency Model' - """ if 'instaflow' in f.lower(): guess = 'InstaFlow' if 'segmoe' in f.lower(): diff --git a/modules/sd_models.py b/modules/sd_models.py index 1629f2dd7..9d0349a3f 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -1,25 +1,22 @@ -import io import sys import time -import json import copy import inspect import logging -import contextlib import os.path from enum import Enum import diffusers import diffusers.loaders.single_file_utils -from rich import progress # pylint: disable=redefined-builtin import torch -import safetensors.torch -from omegaconf import OmegaConf + from modules import paths, shared, shared_state, modelloader, devices, script_callbacks, sd_vae, 
sd_unet, errors, sd_models_config, sd_models_compile, sd_hijack_accelerate, sd_detect from modules.timer import Timer, process as process_timer from modules.memstats import memory_stats from modules.modeldata import model_data from modules.sd_checkpoint import CheckpointInfo, select_checkpoint, list_models, checkpoints_list, checkpoint_titles, get_closet_checkpoint_match, model_hash, update_model_hashes, setup_model, write_metadata, read_metadata_from_safetensors # pylint: disable=unused-import -from modules.sd_offload import set_diffuser_offload, apply_balanced_offload, set_accelerate # pylint: disable=unused-import +from modules.sd_offload import disable_offload, set_diffuser_offload, apply_balanced_offload, set_accelerate # pylint: disable=unused-import +from modules.sd_models_legacy import get_checkpoint_state_dict, load_model_weights, load_model, repair_config # pylint: disable=unused-import +from modules.sd_models_utils import NoWatermark, get_signature, get_call, path_to_repo, patch_diffuser_config, convert_to_faketensors, read_state_dict, get_state_dict_from_checkpoint # pylint: disable=unused-import model_dir = "Stable-diffusion" @@ -35,165 +32,6 @@ checkpoint_tiles = checkpoint_titles # legacy compatibility -class NoWatermark: - def apply_watermark(self, img): - return img - - -def read_state_dict(checkpoint_file, map_location=None, what:str='model'): # pylint: disable=unused-argument - if not os.path.isfile(checkpoint_file): - shared.log.error(f'Load dict: path="{checkpoint_file}" not a file') - return None - try: - pl_sd = None - with progress.open(checkpoint_file, 'rb', description=f'[cyan]Load {what}: [yellow]{checkpoint_file}', auto_refresh=True, console=shared.console) as f: - _, extension = os.path.splitext(checkpoint_file) - if extension.lower() == ".ckpt" and shared.opts.sd_disable_ckpt: - shared.log.warning(f"Checkpoint loading disabled: {checkpoint_file}") - return None - if shared.opts.stream_load: - if extension.lower() == ".safetensors": - # shared.log.debug('Model weights loading: type=safetensors mode=buffered') - buffer = f.read() - pl_sd = safetensors.torch.load(buffer) - else: - # shared.log.debug('Model weights loading: type=checkpoint mode=buffered') - buffer = io.BytesIO(f.read()) - pl_sd = torch.load(buffer, map_location='cpu') - else: - if extension.lower() == ".safetensors": - # shared.log.debug('Model weights loading: type=safetensors mode=mmap') - pl_sd = safetensors.torch.load_file(checkpoint_file, device='cpu') - else: - # shared.log.debug('Model weights loading: type=checkpoint mode=direct') - pl_sd = torch.load(f, map_location='cpu') - sd = get_state_dict_from_checkpoint(pl_sd) - del pl_sd - except Exception as e: - errors.display(e, f'Load model: {checkpoint_file}') - sd = None - return sd - - -def get_state_dict_from_checkpoint(pl_sd): - checkpoint_dict_replacements = { - 'cond_stage_model.transformer.embeddings.': 'cond_stage_model.transformer.text_model.embeddings.', - 'cond_stage_model.transformer.encoder.': 'cond_stage_model.transformer.text_model.encoder.', - 'cond_stage_model.transformer.final_layer_norm.': 'cond_stage_model.transformer.text_model.final_layer_norm.', - } - - def transform_checkpoint_dict_key(k): - for text, replacement in checkpoint_dict_replacements.items(): - if k.startswith(text): - k = replacement + k[len(text):] - return k - - pl_sd = pl_sd.pop("state_dict", pl_sd) - pl_sd.pop("state_dict", None) - sd = {} - for k, v in pl_sd.items(): - new_key = transform_checkpoint_dict_key(k) - if new_key is not None: - sd[new_key] = 
v - pl_sd.clear() - pl_sd.update(sd) - return pl_sd - - -def get_checkpoint_state_dict(checkpoint_info: CheckpointInfo, timer): - if not os.path.isfile(checkpoint_info.filename): - return None - """ - if checkpoint_info in checkpoints_loaded: - shared.log.info("Load model: cache") - checkpoints_loaded.move_to_end(checkpoint_info, last=True) # FIFO -> LRU cache - return checkpoints_loaded[checkpoint_info] - """ - res = read_state_dict(checkpoint_info.filename, what='model') - """ - if shared.opts.sd_checkpoint_cache > 0 and not shared.native: - # cache newly loaded model - checkpoints_loaded[checkpoint_info] = res - # clean up cache if limit is reached - while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache: - checkpoints_loaded.popitem(last=False) - """ - timer.record("load") - return res - - -def load_model_weights(model: torch.nn.Module, checkpoint_info: CheckpointInfo, state_dict, timer): - _pipeline, _model_type = sd_detect.detect_pipeline(checkpoint_info.path, 'model') - shared.log.debug(f'Load model: memory={memory_stats()}') - timer.record("hash") - if model_data.sd_dict == 'None': - shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title - if state_dict is None: - state_dict = get_checkpoint_state_dict(checkpoint_info, timer) - try: - model.load_state_dict(state_dict, strict=False) - except Exception as e: - shared.log.error(f'Load model: path="{checkpoint_info.filename}"') - shared.log.error(' '.join(str(e).splitlines()[:2])) - return False - del state_dict - timer.record("apply") - if shared.opts.opt_channelslast: - model.to(memory_format=torch.channels_last) - timer.record("channels") - if not shared.opts.no_half: - vae = model.first_stage_model - depth_model = getattr(model, 'depth_model', None) - # with --no-half-vae, remove VAE from model when doing half() to prevent its weights from being converted to float16 - if shared.opts.no_half_vae: - model.first_stage_model = None - # with --upcast-sampling, don't convert the depth model weights to float16 - if shared.opts.upcast_sampling and depth_model: - model.depth_model = None - model.half() - model.first_stage_model = vae - if depth_model: - model.depth_model = depth_model - if shared.opts.cuda_cast_unet: - devices.dtype_unet = model.model.diffusion_model.dtype - else: - model.model.diffusion_model.to(devices.dtype_unet) - model.first_stage_model.to(devices.dtype_vae) - model.sd_model_hash = checkpoint_info.calculate_shorthash() - model.sd_model_checkpoint = checkpoint_info.filename - model.sd_checkpoint_info = checkpoint_info - model.is_sdxl = False # a1111 compatibility item - model.is_sd2 = hasattr(model.cond_stage_model, 'model') # a1111 compatibility item - model.is_sd1 = not hasattr(model.cond_stage_model, 'model') # a1111 compatibility item - model.logvar = model.logvar.to(devices.device) if hasattr(model, 'logvar') else None # fix for training - shared.opts.data["sd_checkpoint_hash"] = checkpoint_info.sha256 - sd_vae.delete_base_vae() - sd_vae.clear_loaded_vae() - vae_file, vae_source = sd_vae.resolve_vae(checkpoint_info.filename) - sd_vae.load_vae(model, vae_file, vae_source) - timer.record("vae") - return True - - -def repair_config(sd_config): - if "use_ema" not in sd_config.model.params: - sd_config.model.params.use_ema = False - if shared.opts.no_half: - sd_config.model.params.unet_config.params.use_fp16 = False - elif shared.opts.upcast_sampling: - sd_config.model.params.unet_config.params.use_fp16 = True if sys.platform != 'darwin' else False - if 
getattr(sd_config.model.params.first_stage_config.params.ddconfig, "attn_type", None) == "vanilla-xformers" and not shared.xformers_available: - sd_config.model.params.first_stage_config.params.ddconfig.attn_type = "vanilla" - # For UnCLIP-L, override the hardcoded karlo directory - if "noise_aug_config" in sd_config.model.params and "clip_stats_path" in sd_config.model.params.noise_aug_config.params: - karlo_path = os.path.join(paths.models_path, 'karlo') - sd_config.model.params.noise_aug_config.params.clip_stats_path = sd_config.model.params.noise_aug_config.params.clip_stats_path.replace("checkpoints/karlo_models", karlo_path) - - -sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight' -sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight' - - def change_backend(): shared.log.info(f'Backend changed: from={shared.backend} to={shared.opts.sd_backend}') shared.log.warning('Full server restart required to apply all changes') @@ -223,21 +61,21 @@ def copy_diffuser_options(new_pipe, orig_pipe): set_accelerate(new_pipe) -def set_vae_options(sd_model, vae = None, op: str = 'model'): +def set_vae_options(sd_model, vae=None, op:str='model', quiet:bool=False): if hasattr(sd_model, "vae"): if vae is not None: sd_model.vae = vae - shared.log.debug(f'Setting {op}: component=VAE name="{sd_vae.loaded_vae_file}"') + shared.log.quiet(quiet, f'Setting {op}: component=VAE name="{sd_vae.loaded_vae_file}"') if shared.opts.diffusers_vae_upcast != 'default': sd_model.vae.config.force_upcast = True if shared.opts.diffusers_vae_upcast == 'true' else False - shared.log.debug(f'Setting {op}: component=VAE upcast={sd_model.vae.config.force_upcast}') + shared.log.quiet(quiet, f'Setting {op}: component=VAE upcast={sd_model.vae.config.force_upcast}') if shared.opts.no_half_vae: devices.dtype_vae = torch.float32 sd_model.vae.to(devices.dtype_vae) - shared.log.debug(f'Setting {op}: component=VAE no-half=True') + shared.log.quiet(quiet, f'Setting {op}: component=VAE no-half=True') if hasattr(sd_model, "enable_vae_slicing"): if shared.opts.diffusers_vae_slicing: - shared.log.debug(f'Setting {op}: component=VAE slicing=True') + shared.log.quiet(quiet, f'Setting {op}: component=VAE slicing=True') sd_model.enable_vae_slicing() else: sd_model.disable_vae_slicing() @@ -249,18 +87,18 @@ def set_vae_options(sd_model, vae = None, op: str = 'model'): sd_model.vae.tile_latent_min_size = int(sd_model.vae.config.sample_size / (2 ** (len(sd_model.vae.config.block_out_channels) - 1))) if shared.opts.diffusers_vae_tile_overlap != 0.25: sd_model.vae.tile_overlap_factor = float(shared.opts.diffusers_vae_tile_overlap) - shared.log.debug(f'Setting {op}: component=VAE tiling=True tile={sd_model.vae.tile_sample_min_size} overlap={sd_model.vae.tile_overlap_factor}') + shared.log.quiet(quiet, f'Setting {op}: component=VAE tiling=True tile={sd_model.vae.tile_sample_min_size} overlap={sd_model.vae.tile_overlap_factor}') else: - shared.log.debug(f'Setting {op}: component=VAE tiling=True') + shared.log.quiet(quiet, f'Setting {op}: component=VAE tiling=True') sd_model.enable_vae_tiling() else: sd_model.disable_vae_tiling() if hasattr(sd_model, "vqvae"): - shared.log.debug(f'Setting {op}: component=VQVAE upcast=True') + shared.log.quiet(quiet, f'Setting {op}: component=VQVAE upcast=True') sd_model.vqvae.to(torch.float32) # vqvae is producing nans in fp16 -def set_diffuser_options(sd_model, vae = None, op: str = 'model', offload=True): +def set_diffuser_options(sd_model, 
vae=None, op:str='model', offload:bool=True, quiet:bool=False): if sd_model is None: shared.log.warning(f'{op} is not loaded') return @@ -271,19 +109,19 @@ def set_diffuser_options(sd_model, vae = None, op: str = 'model', offload=True): sd_model.has_accelerate = False clear_caches() - set_vae_options(sd_model, vae, op) - set_diffusers_attention(sd_model) + set_vae_options(sd_model, vae, op, quiet) + set_diffusers_attention(sd_model, quiet) if shared.opts.diffusers_fuse_projections and hasattr(sd_model, 'fuse_qkv_projections'): try: sd_model.fuse_qkv_projections() - shared.log.debug(f'Setting {op}: fused-qkv=True') + shared.log.quiet(quiet, f'Setting {op}: fused-qkv=True') except Exception as e: shared.log.error(f'Setting {op}: fused-qkv=True {e}') if shared.opts.diffusers_fuse_projections and hasattr(sd_model, 'transformer') and hasattr(sd_model.transformer, 'fuse_qkv_projections'): try: sd_model.transformer.fuse_qkv_projections() - shared.log.debug(f'Setting {op}: fused-qkv=True') + shared.log.quiet(quiet, f'Setting {op}: fused-qkv=True') except Exception as e: shared.log.error(f'Setting {op}: fused-qkv=True {e}') if shared.opts.diffusers_eval: @@ -297,11 +135,11 @@ def eval_model(model, op=None, sd_model=None): # pylint: disable=unused-argument sd_model = sd_models_compile.torchao_quantization(sd_model) if shared.opts.opt_channelslast and hasattr(sd_model, 'unet'): - shared.log.debug(f'Setting {op}: channels-last=True') + shared.log.quiet(quiet, f'Setting {op}: channels-last=True') sd_model.unet.to(memory_format=torch.channels_last) if offload: - set_diffuser_offload(sd_model, op) + set_diffuser_offload(sd_model, op, quiet) def move_model(model, device=None, force=False): @@ -401,50 +239,6 @@ def move_base(model, device): return R -def patch_diffuser_config(sd_model, model_file): - def load_config(fn, k): - model_file = os.path.splitext(fn)[0] - cfg_file = f'{model_file}_{k}.json' - try: - if os.path.exists(cfg_file): - with open(cfg_file, 'r', encoding='utf-8') as f: - return json.load(f) - cfg_file = f'{os.path.join(paths.sd_configs_path, os.path.basename(model_file))}_{k}.json' - if os.path.exists(cfg_file): - with open(cfg_file, 'r', encoding='utf-8') as f: - return json.load(f) - except Exception: - pass - return {} - - if sd_model is None: - return sd_model - if hasattr(sd_model, 'unet') and hasattr(sd_model.unet, 'config') and 'inpaint' in model_file.lower(): - if debug_load: - shared.log.debug('Model config patch: type=inpaint') - sd_model.unet.config.in_channels = 9 - if not hasattr(sd_model, '_internal_dict'): - return sd_model - for c in sd_model._internal_dict.keys(): # pylint: disable=protected-access - component = getattr(sd_model, c, None) - if hasattr(component, 'config'): - if debug_load: - shared.log.debug(f'Model config: component={c} config={component.config}') - override = load_config(model_file, c) - updated = {} - for k, v in override.items(): - if k.startswith('_'): - continue - if v != component.config.get(k, None): - if hasattr(component.config, '__frozen'): - component.config.__frozen = False # pylint: disable=protected-access - component.config[k] = v - updated[k] = v - if updated and debug_load: - shared.log.debug(f'Model config: component={c} override={updated}') - return sd_model - - def load_diffuser_initial(diffusers_load_config, op='model'): sd_model = None checkpoint_info = None @@ -833,18 +627,6 @@ def get_diffusers_task(pipe: diffusers.DiffusionPipeline) -> DiffusersTaskType: return DiffusersTaskType.TEXT_2_IMAGE -def get_signature(cls): - signature = 
inspect.signature(cls.__init__, follow_wrapped=True, eval_str=True) - return signature.parameters - - -def get_call(cls): - if cls is None: - return [] - signature = inspect.signature(cls.__call__, follow_wrapped=True, eval_str=True) - return signature.parameters - - def switch_pipe(cls: diffusers.DiffusionPipeline, pipeline: diffusers.DiffusionPipeline = None, force = False, args = {}): """ args: @@ -1071,7 +853,7 @@ def set_diffuser_pipe(pipe, new_pipe_type): return pipe -def set_diffusers_attention(pipe): +def set_diffusers_attention(pipe, quiet:bool=False): import diffusers.models.attention_processor as p def set_attn(pipe, attention): @@ -1102,7 +884,7 @@ def set_attn(pipe, attention): if 'ControlNet' in pipe.__class__.__name__: # do not replace attention in ControlNet pipelines return - shared.log.debug(f'Setting model: attention="{shared.opts.cross_attention_optimization}"') + shared.log.quiet(quiet, f'Setting model: attention="{shared.opts.cross_attention_optimization}"') if shared.opts.cross_attention_optimization == "Disabled": pass # do nothing elif shared.opts.cross_attention_optimization == "Scaled-Dot-Product": # The default set by Diffusers @@ -1146,103 +928,6 @@ def get_native(pipe: diffusers.DiffusionPipeline): return size -def load_model(checkpoint_info=None, already_loaded_state_dict=None, timer=None, op='model'): - from ldm.util import instantiate_from_config - from modules import lowvram, sd_hijack - checkpoint_info = checkpoint_info or select_checkpoint(op=op) - if checkpoint_info is None: - return - if op == 'model' or op == 'dict': - if (model_data.sd_model is not None) and (getattr(model_data.sd_model, 'sd_checkpoint_info', None) is not None) and (checkpoint_info.hash == model_data.sd_model.sd_checkpoint_info.hash): # trying to load the same model - return - else: - if (model_data.sd_refiner is not None) and (getattr(model_data.sd_refiner, 'sd_checkpoint_info', None) is not None) and (checkpoint_info.hash == model_data.sd_refiner.sd_checkpoint_info.hash): # trying to load the same model - return - shared.log.debug(f'Load {op}: name={checkpoint_info.filename} dict={already_loaded_state_dict is not None}') - if timer is None: - timer = Timer() - current_checkpoint_info = None - if op == 'model' or op == 'dict': - if model_data.sd_model is not None: - sd_hijack.model_hijack.undo_hijack(model_data.sd_model) - current_checkpoint_info = getattr(model_data.sd_model, 'sd_checkpoint_info', None) - unload_model_weights(op=op) - else: - if model_data.sd_refiner is not None: - sd_hijack.model_hijack.undo_hijack(model_data.sd_refiner) - current_checkpoint_info = getattr(model_data.sd_refiner, 'sd_checkpoint_info', None) - unload_model_weights(op=op) - - if not shared.native: - from modules import sd_hijack_inpainting - sd_hijack_inpainting.do_inpainting_hijack() - - if already_loaded_state_dict is not None: - state_dict = already_loaded_state_dict - else: - state_dict = get_checkpoint_state_dict(checkpoint_info, timer) - checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info) - if state_dict is None or checkpoint_config is None: - shared.log.error(f'Load {op}: path="{checkpoint_info.filename}"') - if current_checkpoint_info is not None: - shared.log.info(f'Load {op}: previous="{current_checkpoint_info.filename}" restore') - load_model(current_checkpoint_info, None) - return - shared.log.debug(f'Model dict loaded: {memory_stats()}') - sd_config = OmegaConf.load(checkpoint_config) - repair_config(sd_config) - timer.record("config") - 
shared.log.debug(f'Model config loaded: {memory_stats()}') - sd_model = None - stdout = io.StringIO() - if os.environ.get('SD_LDM_DEBUG', None) is not None: - sd_model = instantiate_from_config(sd_config.model) - else: - with contextlib.redirect_stdout(stdout): - sd_model = instantiate_from_config(sd_config.model) - for line in stdout.getvalue().splitlines(): - if len(line) > 0: - shared.log.info(f'LDM: {line.strip()}') - shared.log.debug(f"Model created from config: {checkpoint_config}") - sd_model.used_config = checkpoint_config - sd_model.has_accelerate = False - timer.record("create") - ok = load_model_weights(sd_model, checkpoint_info, state_dict, timer) - if not ok: - model_data.sd_model = sd_model - current_checkpoint_info = None - unload_model_weights(op=op) - shared.log.debug(f'Model weights unloaded: {memory_stats()} op={op}') - if op == 'refiner': - # shared.opts.data['sd_model_refiner'] = 'None' - shared.opts.sd_model_refiner = 'None' - return - else: - shared.log.debug(f'Model weights loaded: {memory_stats()}') - timer.record("load") - if not shared.native and (shared.cmd_opts.lowvram or shared.cmd_opts.medvram): - lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram) - else: - move_model(sd_model, devices.device) - timer.record("move") - shared.log.debug(f'Model weights moved: {memory_stats()}') - sd_hijack.model_hijack.hijack(sd_model) - timer.record("hijack") - sd_model.eval() - if op == 'refiner': - model_data.sd_refiner = sd_model - else: - model_data.sd_model = sd_model - sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True) # Reload embeddings after model load as they may or may not fit the model - timer.record("embeddings") - script_callbacks.model_loaded_callback(sd_model) - timer.record("callbacks") - shared.log.info(f"Model loaded in {timer.summary()}") - current_checkpoint_info = None - devices.torch_gc(force=True) - shared.log.info(f'Model load finished: {memory_stats()}') - - def reload_text_encoder(initial=False): if initial and (shared.opts.sd_text_encoder is None or shared.opts.sd_text_encoder == 'None'): return # dont unload @@ -1342,35 +1027,6 @@ def reload_model_weights(sd_model=None, info=None, reuse_dict=False, op='model', return sd_model -def convert_to_faketensors(tensor): - try: - fake_module = torch._subclasses.fake_tensor.FakeTensorMode(allow_non_fake_inputs=True) # pylint: disable=protected-access - if hasattr(tensor, "weight"): - tensor.weight = torch.nn.Parameter(fake_module.from_tensor(tensor.weight)) - return tensor - except Exception: - pass - return tensor - - -def disable_offload(sd_model): - from accelerate.hooks import remove_hook_from_module - if not getattr(sd_model, 'has_accelerate', False): - return - if hasattr(sd_model, "_internal_dict"): - keys = sd_model._internal_dict.keys() # pylint: disable=protected-access - else: - keys = get_signature(sd_model).keys() - for module_name in keys: # pylint: disable=protected-access - module = getattr(sd_model, module_name, None) - if isinstance(module, torch.nn.Module): - network_layer_name = getattr(module, "network_layer_name", None) - module = remove_hook_from_module(module, recurse=True) - if network_layer_name: - module.network_layer_name = network_layer_name - sd_model.has_accelerate = False - - def clear_caches(): # shared.log.debug('Cache clear') if not shared.opts.lora_legacy: @@ -1411,16 +1067,3 @@ def unload_model_weights(op='model'): model_data.sd_refiner = None devices.torch_gc(force=True) shared.log.debug(f'Unload weights {op}: 
{memory_stats()}') - - -def path_to_repo(fn: str = ''): - if isinstance(fn, CheckpointInfo): - fn = fn.name - repo_id = fn.replace('\\', '/') - if 'models--' in repo_id: - repo_id = repo_id.split('models--')[-1] - repo_id = repo_id.split('/')[0] - repo_id = repo_id.split('/') - repo_id = '/'.join(repo_id[-2:] if len(repo_id) > 1 else repo_id) - repo_id = repo_id.replace('models--', '').replace('--', '/') - return repo_id diff --git a/modules/sd_models_legacy.py b/modules/sd_models_legacy.py new file mode 100644 index 000000000..ec21da7b7 --- /dev/null +++ b/modules/sd_models_legacy.py @@ -0,0 +1,207 @@ +import io +import os +import sys +import contextlib + +from modules import shared + + +sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight' +sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight' + + +def get_checkpoint_state_dict(checkpoint_info, timer): + from modules.sd_models_utils import read_state_dict + if not os.path.isfile(checkpoint_info.filename): + return None + """ + if checkpoint_info in checkpoints_loaded: + shared.log.info("Load model: cache") + checkpoints_loaded.move_to_end(checkpoint_info, last=True) # FIFO -> LRU cache + return checkpoints_loaded[checkpoint_info] + """ + res = read_state_dict(checkpoint_info.filename, what='model') + """ + if shared.opts.sd_checkpoint_cache > 0 and not shared.native: + # cache newly loaded model + checkpoints_loaded[checkpoint_info] = res + # clean up cache if limit is reached + while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache: + checkpoints_loaded.popitem(last=False) + """ + timer.record("load") + return res + + +def repair_config(sd_config): + from modules import paths + if "use_ema" not in sd_config.model.params: + sd_config.model.params.use_ema = False + if shared.opts.no_half: + sd_config.model.params.unet_config.params.use_fp16 = False + elif shared.opts.upcast_sampling: + sd_config.model.params.unet_config.params.use_fp16 = True if sys.platform != 'darwin' else False + if getattr(sd_config.model.params.first_stage_config.params.ddconfig, "attn_type", None) == "vanilla-xformers" and not shared.xformers_available: + sd_config.model.params.first_stage_config.params.ddconfig.attn_type = "vanilla" + # For UnCLIP-L, override the hardcoded karlo directory + if "noise_aug_config" in sd_config.model.params and "clip_stats_path" in sd_config.model.params.noise_aug_config.params: + karlo_path = os.path.join(paths.models_path, 'karlo') + sd_config.model.params.noise_aug_config.params.clip_stats_path = sd_config.model.params.noise_aug_config.params.clip_stats_path.replace("checkpoints/karlo_models", karlo_path) + + +def load_model_weights(model, checkpoint_info, state_dict, timer): + # _pipeline, _model_type = sd_detect.detect_pipeline(checkpoint_info.path, 'model') + from modules.modeldata import model_data + from modules.memstats import memory_stats + from modules import devices, sd_vae + shared.log.debug(f'Load model: memory={memory_stats()}') + timer.record("hash") + if model_data.sd_dict == 'None': + shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title + if state_dict is None: + state_dict = get_checkpoint_state_dict(checkpoint_info, timer) + try: + model.load_state_dict(state_dict, strict=False) + except Exception as e: + shared.log.error(f'Load model: path="{checkpoint_info.filename}"') + shared.log.error(' '.join(str(e).splitlines()[:2])) + return False + del state_dict + timer.record("apply") + if shared.opts.opt_channelslast: + import 
torch + model.to(memory_format=torch.channels_last) + timer.record("channels") + if not shared.opts.no_half: + vae = model.first_stage_model + depth_model = getattr(model, 'depth_model', None) + if shared.opts.no_half_vae: # remove VAE from model when doing half() to prevent its weights from being converted to float16 + model.first_stage_model = None + if shared.opts.upcast_sampling and depth_model: # with don't convert the depth model weights to float16 + model.depth_model = None + model.half() + model.first_stage_model = vae + if depth_model: + model.depth_model = depth_model + if shared.opts.cuda_cast_unet: + devices.dtype_unet = model.model.diffusion_model.dtype + else: + model.model.diffusion_model.to(devices.dtype_unet) + model.first_stage_model.to(devices.dtype_vae) + model.sd_model_hash = checkpoint_info.calculate_shorthash() + model.sd_model_checkpoint = checkpoint_info.filename + model.sd_checkpoint_info = checkpoint_info + model.is_sdxl = False # a1111 compatibility item + model.is_sd2 = hasattr(model.cond_stage_model, 'model') # a1111 compatibility item + model.is_sd1 = not hasattr(model.cond_stage_model, 'model') # a1111 compatibility item + model.logvar = model.logvar.to(devices.device) if hasattr(model, 'logvar') else None # fix for training + shared.opts.data["sd_checkpoint_hash"] = checkpoint_info.sha256 + sd_vae.delete_base_vae() + sd_vae.clear_loaded_vae() + vae_file, vae_source = sd_vae.resolve_vae(checkpoint_info.filename) + sd_vae.load_vae(model, vae_file, vae_source) + timer.record("vae") + return True + + +def load_model(checkpoint_info=None, already_loaded_state_dict=None, timer=None, op='model'): + from ldm.util import instantiate_from_config + from omegaconf import OmegaConf + from modules import devices, lowvram, sd_hijack, sd_models_config, script_callbacks + from modules.timer import Timer + from modules.memstats import memory_stats + from modules.modeldata import model_data + from modules.sd_models import unload_model_weights, move_model + from modules.sd_checkpoint import select_checkpoint + checkpoint_info = checkpoint_info or select_checkpoint(op=op) + if checkpoint_info is None: + return + if op == 'model' or op == 'dict': + if (model_data.sd_model is not None) and (getattr(model_data.sd_model, 'sd_checkpoint_info', None) is not None) and (checkpoint_info.hash == model_data.sd_model.sd_checkpoint_info.hash): # trying to load the same model + return + else: + if (model_data.sd_refiner is not None) and (getattr(model_data.sd_refiner, 'sd_checkpoint_info', None) is not None) and (checkpoint_info.hash == model_data.sd_refiner.sd_checkpoint_info.hash): # trying to load the same model + return + shared.log.debug(f'Load {op}: name={checkpoint_info.filename} dict={already_loaded_state_dict is not None}') + if timer is None: + timer = Timer() + current_checkpoint_info = None + if op == 'model' or op == 'dict': + if model_data.sd_model is not None: + sd_hijack.model_hijack.undo_hijack(model_data.sd_model) + current_checkpoint_info = getattr(model_data.sd_model, 'sd_checkpoint_info', None) + unload_model_weights(op=op) + else: + if model_data.sd_refiner is not None: + sd_hijack.model_hijack.undo_hijack(model_data.sd_refiner) + current_checkpoint_info = getattr(model_data.sd_refiner, 'sd_checkpoint_info', None) + unload_model_weights(op=op) + + if not shared.native: + from modules import sd_hijack_inpainting + sd_hijack_inpainting.do_inpainting_hijack() + + if already_loaded_state_dict is not None: + state_dict = already_loaded_state_dict + else: + state_dict = 
get_checkpoint_state_dict(checkpoint_info, timer) + checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info) + if state_dict is None or checkpoint_config is None: + shared.log.error(f'Load {op}: path="{checkpoint_info.filename}"') + if current_checkpoint_info is not None: + shared.log.info(f'Load {op}: previous="{current_checkpoint_info.filename}" restore') + load_model(current_checkpoint_info, None) + return + shared.log.debug(f'Model dict loaded: {memory_stats()}') + sd_config = OmegaConf.load(checkpoint_config) + repair_config(sd_config) + timer.record("config") + shared.log.debug(f'Model config loaded: {memory_stats()}') + sd_model = None + stdout = io.StringIO() + if os.environ.get('SD_LDM_DEBUG', None) is not None: + sd_model = instantiate_from_config(sd_config.model) + else: + with contextlib.redirect_stdout(stdout): + sd_model = instantiate_from_config(sd_config.model) + for line in stdout.getvalue().splitlines(): + if len(line) > 0: + shared.log.info(f'LDM: {line.strip()}') + shared.log.debug(f"Model created from config: {checkpoint_config}") + sd_model.used_config = checkpoint_config + sd_model.has_accelerate = False + timer.record("create") + ok = load_model_weights(sd_model, checkpoint_info, state_dict, timer) + if not ok: + model_data.sd_model = sd_model + current_checkpoint_info = None + unload_model_weights(op=op) + shared.log.debug(f'Model weights unloaded: {memory_stats()} op={op}') + if op == 'refiner': + # shared.opts.data['sd_model_refiner'] = 'None' + shared.opts.sd_model_refiner = 'None' + return + else: + shared.log.debug(f'Model weights loaded: {memory_stats()}') + timer.record("load") + if not shared.native and (shared.cmd_opts.lowvram or shared.cmd_opts.medvram): + lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram) + else: + move_model(sd_model, devices.device) + timer.record("move") + shared.log.debug(f'Model weights moved: {memory_stats()}') + sd_hijack.model_hijack.hijack(sd_model) + timer.record("hijack") + sd_model.eval() + if op == 'refiner': + model_data.sd_refiner = sd_model + else: + model_data.sd_model = sd_model + sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True) # Reload embeddings after model load as they may or may not fit the model + timer.record("embeddings") + script_callbacks.model_loaded_callback(sd_model) + timer.record("callbacks") + shared.log.info(f"Model loaded in {timer.summary()}") + current_checkpoint_info = None + devices.torch_gc(force=True) + shared.log.info(f'Model load finished: {memory_stats()}') diff --git a/modules/sd_models_utils.py b/modules/sd_models_utils.py new file mode 100644 index 000000000..0ff903483 --- /dev/null +++ b/modules/sd_models_utils.py @@ -0,0 +1,151 @@ +import io +import json +import inspect +import os.path +from rich import progress # pylint: disable=redefined-builtin +import torch +import safetensors.torch + +from modules import paths, shared, errors +from modules.sd_checkpoint import CheckpointInfo, select_checkpoint, list_models, checkpoints_list, checkpoint_titles, get_closet_checkpoint_match, model_hash, update_model_hashes, setup_model, write_metadata, read_metadata_from_safetensors # pylint: disable=unused-import +from modules.sd_offload import disable_offload, set_diffuser_offload, apply_balanced_offload, set_accelerate # pylint: disable=unused-import +from modules.sd_models_legacy import get_checkpoint_state_dict, load_model_weights, load_model, repair_config # pylint: disable=unused-import + + +class NoWatermark: + 
def apply_watermark(self, img): + return img + + +def get_signature(cls): + signature = inspect.signature(cls.__init__, follow_wrapped=True, eval_str=True) + return signature.parameters + + +def get_call(cls): + if cls is None: + return [] + signature = inspect.signature(cls.__call__, follow_wrapped=True, eval_str=True) + return signature.parameters + + +def path_to_repo(fn: str = ''): + if isinstance(fn, CheckpointInfo): + fn = fn.name + repo_id = fn.replace('\\', '/') + if 'models--' in repo_id: + repo_id = repo_id.split('models--')[-1] + repo_id = repo_id.split('/')[0] + repo_id = repo_id.split('/') + repo_id = '/'.join(repo_id[-2:] if len(repo_id) > 1 else repo_id) + repo_id = repo_id.replace('models--', '').replace('--', '/') + return repo_id + + +def convert_to_faketensors(tensor): + try: + fake_module = torch._subclasses.fake_tensor.FakeTensorMode(allow_non_fake_inputs=True) # pylint: disable=protected-access + if hasattr(tensor, "weight"): + tensor.weight = torch.nn.Parameter(fake_module.from_tensor(tensor.weight)) + return tensor + except Exception: + pass + return tensor + + +def read_state_dict(checkpoint_file, map_location=None, what:str='model'): # pylint: disable=unused-argument + if not os.path.isfile(checkpoint_file): + shared.log.error(f'Load dict: path="{checkpoint_file}" not a file') + return None + try: + pl_sd = None + with progress.open(checkpoint_file, 'rb', description=f'[cyan]Load {what}: [yellow]{checkpoint_file}', auto_refresh=True, console=shared.console) as f: + _, extension = os.path.splitext(checkpoint_file) + if extension.lower() == ".ckpt" and shared.opts.sd_disable_ckpt: + shared.log.warning(f"Checkpoint loading disabled: {checkpoint_file}") + return None + if shared.opts.stream_load: + if extension.lower() == ".safetensors": + # shared.log.debug('Model weights loading: type=safetensors mode=buffered') + buffer = f.read() + pl_sd = safetensors.torch.load(buffer) + else: + # shared.log.debug('Model weights loading: type=checkpoint mode=buffered') + buffer = io.BytesIO(f.read()) + pl_sd = torch.load(buffer, map_location='cpu') + else: + if extension.lower() == ".safetensors": + # shared.log.debug('Model weights loading: type=safetensors mode=mmap') + pl_sd = safetensors.torch.load_file(checkpoint_file, device='cpu') + else: + # shared.log.debug('Model weights loading: type=checkpoint mode=direct') + pl_sd = torch.load(f, map_location='cpu') + sd = get_state_dict_from_checkpoint(pl_sd) + del pl_sd + except Exception as e: + errors.display(e, f'Load model: {checkpoint_file}') + sd = None + return sd + + +def get_state_dict_from_checkpoint(pl_sd): + checkpoint_dict_replacements = { + 'cond_stage_model.transformer.embeddings.': 'cond_stage_model.transformer.text_model.embeddings.', + 'cond_stage_model.transformer.encoder.': 'cond_stage_model.transformer.text_model.encoder.', + 'cond_stage_model.transformer.final_layer_norm.': 'cond_stage_model.transformer.text_model.final_layer_norm.', + } + + def transform_checkpoint_dict_key(k): + for text, replacement in checkpoint_dict_replacements.items(): + if k.startswith(text): + k = replacement + k[len(text):] + return k + + pl_sd = pl_sd.pop("state_dict", pl_sd) + pl_sd.pop("state_dict", None) + sd = {} + for k, v in pl_sd.items(): + new_key = transform_checkpoint_dict_key(k) + if new_key is not None: + sd[new_key] = v + pl_sd.clear() + pl_sd.update(sd) + return pl_sd + + +def patch_diffuser_config(sd_model, model_file): + def load_config(fn, k): + model_file = os.path.splitext(fn)[0] + cfg_file = 
f'{model_file}_{k}.json' + try: + if os.path.exists(cfg_file): + with open(cfg_file, 'r', encoding='utf-8') as f: + return json.load(f) + cfg_file = f'{os.path.join(paths.sd_configs_path, os.path.basename(model_file))}_{k}.json' + if os.path.exists(cfg_file): + with open(cfg_file, 'r', encoding='utf-8') as f: + return json.load(f) + except Exception: + pass + return {} + + if sd_model is None: + return sd_model + if hasattr(sd_model, 'unet') and hasattr(sd_model.unet, 'config') and 'inpaint' in model_file.lower(): + sd_model.unet.config.in_channels = 9 + if not hasattr(sd_model, '_internal_dict'): + return sd_model + for c in sd_model._internal_dict.keys(): # pylint: disable=protected-access + component = getattr(sd_model, c, None) + if hasattr(component, 'config'): + override = load_config(model_file, c) + updated = {} + for k, v in override.items(): + if k.startswith('_'): + continue + if v != component.config.get(k, None): + if hasattr(component.config, '__frozen'): + component.config.__frozen = False # pylint: disable=protected-access + component.config[k] = v + updated[k] = v + return sd_model diff --git a/modules/sd_offload.py b/modules/sd_offload.py index 33aea13a3..f9d01528c 100644 --- a/modules/sd_offload.py +++ b/modules/sd_offload.py @@ -4,6 +4,7 @@ import inspect import torch import accelerate + from modules import shared, devices, errors from modules.timer import process as process_timer @@ -18,6 +19,24 @@ def get_signature(cls): return signature.parameters +def disable_offload(sd_model): + from accelerate.hooks import remove_hook_from_module + if not getattr(sd_model, 'has_accelerate', False): + return + if hasattr(sd_model, "_internal_dict"): + keys = sd_model._internal_dict.keys() # pylint: disable=protected-access + else: + keys = get_signature(sd_model).keys() + for module_name in keys: # pylint: disable=protected-access + module = getattr(sd_model, module_name, None) + if isinstance(module, torch.nn.Module): + network_layer_name = getattr(module, "network_layer_name", None) + module = remove_hook_from_module(module, recurse=True) + if network_layer_name: + module.network_layer_name = network_layer_name + sd_model.has_accelerate = False + + def set_accelerate(sd_model): def set_accelerate_to_module(model): if hasattr(model, "pipe"): @@ -36,7 +55,7 @@ def set_accelerate_to_module(model): set_accelerate_to_module(sd_model.decoder_pipe) -def set_diffuser_offload(sd_model, op: str = 'model'): +def set_diffuser_offload(sd_model, op:str='model', quiet:bool=False): t0 = time.time() if not shared.native: shared.log.warning('Attempting to use offload with backend=original') @@ -50,13 +69,13 @@ def set_diffuser_offload(sd_model, op: str = 'model'): if shared.sd_model_type in should_offload: shared.log.warning(f'Setting {op}: offload={shared.opts.diffusers_offload_mode} type={shared.sd_model.__class__.__name__} large model') else: - shared.log.debug(f'Setting {op}: offload={shared.opts.diffusers_offload_mode} limit={shared.opts.cuda_mem_fraction}') + shared.log.quiet(quiet, f'Setting {op}: offload={shared.opts.diffusers_offload_mode} limit={shared.opts.cuda_mem_fraction}') if hasattr(sd_model, 'maybe_free_model_hooks'): sd_model.maybe_free_model_hooks() sd_model.has_accelerate = False if shared.opts.diffusers_offload_mode == "model" and hasattr(sd_model, "enable_model_cpu_offload"): try: - shared.log.debug(f'Setting {op}: offload={shared.opts.diffusers_offload_mode} limit={shared.opts.cuda_mem_fraction}') + shared.log.quiet(quiet, f'Setting {op}: 
offload={shared.opts.diffusers_offload_mode} limit={shared.opts.cuda_mem_fraction}') if shared.opts.diffusers_move_base or shared.opts.diffusers_move_unet or shared.opts.diffusers_move_refiner: shared.opts.diffusers_move_base = False shared.opts.diffusers_move_unet = False diff --git a/modules/shared.py b/modules/shared.py index f1a9483c2..635ca86dd 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -211,14 +211,12 @@ def default(obj): # early select backend -default_backend = 'diffusers' early_opts = readfile(cmd_opts.config, silent=True) -early_backend = early_opts.get('sd_backend', default_backend) -backend = Backend.DIFFUSERS if early_backend.lower() == 'diffusers' else Backend.ORIGINAL +early_backend = early_opts.get('sd_backend', 'diffusers') +backend = Backend.ORIGINAL if early_backend.lower() == 'original' else Backend.DIFFUSERS if cmd_opts.backend is not None: # override with args - backend = Backend.DIFFUSERS if cmd_opts.backend.lower() == 'diffusers' else Backend.ORIGINAL + backend = Backend.ORIGINAL if cmd_opts.backend.lower() == 'original' else Backend.DIFFUSERS if cmd_opts.use_openvino: # override for openvino - backend = Backend.DIFFUSERS from modules.intel.openvino import get_device_list as get_openvino_device_list # pylint: disable=ungrouped-imports elif cmd_opts.use_ipex or devices.has_xpu(): from modules.intel.ipex import ipex_init @@ -226,15 +224,14 @@ def default(obj): if not ok: log.error(f'IPEX initialization failed: {e}') elif cmd_opts.use_directml: - name = 'directml' from modules.dml import directml_init ok, e = directml_init() if not ok: log.error(f'DirectML initialization failed: {e}') devices.backend = devices.get_backend(cmd_opts) devices.device = devices.get_optimal_device() -cpu_memory = round(psutil.virtual_memory().total / 1024 / 1024 / 1024, 2) mem_stat = memory_stats() +cpu_memory = round(psutil.virtual_memory().total / 1024 / 1024 / 1024, 2) gpu_memory = mem_stat['gpu']['total'] if "gpu" in mem_stat else 0 native = backend == Backend.DIFFUSERS if not files_cache.do_cache_folders: @@ -475,7 +472,7 @@ def get_default_modes(): startup_offload_mode, startup_cross_attention, startup_sdp_options = get_default_modes() options_templates.update(options_section(('sd', "Models & Loading"), { - "sd_backend": OptionInfo(default_backend, "Execution backend", gr.Radio, {"choices": ["diffusers", "original"] }), + "sd_backend": OptionInfo('diffusers', "Execution backend", gr.Radio, {"choices": ['diffusers', 'original'] }), "diffusers_pipeline": OptionInfo('Autodetect', 'Model pipeline', gr.Dropdown, lambda: {"choices": list(shared_items.get_pipelines()), "visible": native}), "sd_model_checkpoint": OptionInfo(default_checkpoint, "Base model", DropdownEditable, lambda: {"choices": list_checkpoint_titles()}, refresh=refresh_checkpoints), "sd_model_refiner": OptionInfo('None', "Refiner model", gr.Dropdown, lambda: {"choices": ['None'] + list_checkpoint_titles()}, refresh=refresh_checkpoints), @@ -513,7 +510,7 @@ def get_default_modes(): "diffusers_vae_tile_overlap": OptionInfo(0.25, "VAE tile overlap", gr.Slider, {"minimum": 0, "maximum": 0.95, "step": 0.05 }), "sd_vae_sliced_encode": OptionInfo(False, "VAE sliced encode", gr.Checkbox, {"visible": not native}), "nan_skip": OptionInfo(False, "Skip Generation if NaN found in latents", gr.Checkbox), - "rollback_vae": OptionInfo(False, "Attempt VAE roll back for NaN values"), + "rollback_vae": OptionInfo(False, "Attempt VAE roll back for NaN values", gr.Checkbox, {"visible": not native}), })) 
options_templates.update(options_section(('text_encoder', "Text Encoder"), { diff --git a/modules/ui_control_helpers.py b/modules/ui_control_helpers.py index 1eeae23c2..cb25582c6 100644 --- a/modules/ui_control_helpers.py +++ b/modules/ui_control_helpers.py @@ -6,8 +6,8 @@ gr_height = None max_units = shared.opts.control_max_units -debug = shared.log.trace if os.environ.get('SD_CONTROL_DEBUG', None) is not None else lambda *args, **kwargs: None -debug('Trace: CONTROL') +debug = os.environ.get('SD_CONTROL_DEBUG', None) is not None +debug_log = shared.log.trace if debug else lambda *args, **kwargs: None # state variables busy = False # used to synchronize select_input and generate_click @@ -127,7 +127,7 @@ def select_input(input_mode, input_image, init_image, init_type, input_resize, i busy = False # debug('Control input: none') return [gr.Tabs.update(), None, ''] - debug(f'Control select input: source={selected_input} init={init_image} type={init_type} mode={input_mode}') + debug_log(f'Control select input: source={selected_input} init={init_image} type={init_type} mode={input_mode}') input_type = type(selected_input) input_mask = None status = 'Control input | Unknown' @@ -168,7 +168,7 @@ def select_input(input_mode, input_image, init_image, init_type, input_resize, i res = [gr.Tabs.update(selected='out-gallery'), input_mask, status] else: # unknown input_source = None - shared.log.debug(f'Control input: type={input_type} input={input_source}') + debug_log(f'Control input: type={input_type} input={input_source}') # init inputs: optional if init_type == 0: # Control only input_init = None @@ -176,7 +176,7 @@ def select_input(input_mode, input_image, init_image, init_type, input_resize, i input_init = None elif init_type == 2: # Separate init image input_init = [init_image] - debug(f'Control select input: source={input_source} init={input_init} mask={input_mask} mode={input_mode}') + debug_log(f'Control select input: source={input_source} init={input_init} mask={input_mask} mode={input_mode}') busy = False return res @@ -191,7 +191,7 @@ def video_type_change(video_type): def copy_input(mode_from, mode_to, input_image, input_resize, input_inpaint): - debug(f'Control transfter input: from={mode_from} to={mode_to} image={input_image} resize={input_resize} inpaint={input_inpaint}') + debug_log(f'Control transfer input: from={mode_from} to={mode_to} image={input_image} resize={input_resize} inpaint={input_inpaint}') def getimg(ctrl): if ctrl is None: return None diff --git a/scripts/pulid_ext.py b/scripts/pulid_ext.py index ee08e348b..4f2890e52 100644 --- a/scripts/pulid_ext.py +++ b/scripts/pulid_ext.py @@ -3,10 +3,8 @@ import time import contextlib import gradio as gr -import numpy as np from PIL import Image from modules import shared, devices, errors, scripts, processing, processing_helpers, sd_models -from modules.api.api import decode_base64_to_image debug = os.environ.get('SD_PULID_DEBUG', None) is not None @@ -59,12 +57,16 @@ def fun(p, x, xs): # pylint: disable=unused-argument xyz_classes.axis_options.append(option) + def decode_image(self, b64): + from modules.api.api import decode_base64_to_image + return decode_base64_to_image(b64) + def load_images(self, files): uploaded_images.clear() for file in files or []: try: if isinstance(file, str): - image = decode_base64_to_image(file) + image = self.decode_image(file) elif isinstance(file, Image.Image): image = file elif isinstance(file, dict) and 'name' in file: @@ -113,16 +115,17 @@ def run( version: str = 'v1.1' ): # pylint: 
disable=arguments-differ, unused-argument images = [] + import numpy as np try: if gallery is None or (isinstance(gallery, list) and len(gallery) == 0): images = getattr(p, 'pulid_images', uploaded_images) - images = [decode_base64_to_image(image) if isinstance(image, str) else image for image in images] + images = [self.decode_image(image) if isinstance(image, str) else image for image in images] elif isinstance(gallery[0], dict): images = [Image.open(f['name']) for f in gallery] elif isinstance(gallery, str): - images = [decode_base64_to_image(gallery)] + images = [self.decode_image(gallery)] elif isinstance(gallery[0], str): - images = [decode_base64_to_image(f) for f in gallery] + images = [self.decode_image(f) for f in gallery] else: images = gallery images = [np.array(image) for image in images] diff --git a/webui.py b/webui.py index 2c70192f1..f104bbee8 100644 --- a/webui.py +++ b/webui.py @@ -11,13 +11,9 @@ from threading import Thread import modules.hashes import modules.loader -import torch # pylint: disable=wrong-import-order -from modules import timer, errors, paths # pylint: disable=unused-import + from installer import log, git_commit, custom_excepthook -# import ldm.modules.encoders.modules # pylint: disable=unused-import, wrong-import-order -from modules import shared, extensions, gr_tempdir, modelloader # pylint: disable=ungrouped-imports -from modules import extra_networks, ui_extra_networks # pylint: disable=ungrouped-imports -from modules.paths import create_paths +from modules import timer, paths, shared, extensions, gr_tempdir, modelloader from modules.call_queue import queue_lock, wrap_queued_call, wrap_gradio_gpu_call # pylint: disable=unused-import import modules.devices import modules.sd_checkpoint @@ -33,23 +29,28 @@ import modules.txt2img import modules.img2img import modules.upscaler +import modules.extra_networks +import modules.ui_extra_networks import modules.textual_inversion.textual_inversion import modules.hypernetworks.hypernetwork import modules.script_callbacks -from modules.api.middleware import setup_middleware -from modules.shared import cmd_opts, opts # pylint: disable=unused-import +import modules.api.middleware + +if not modules.loader.initialized: + timer.startup.record("libraries") + import modules.sd_hijack # runs conditional load of ldm if not shared.native + timer.startup.record("ldm") +modules.loader.initialized = True sys.excepthook = custom_excepthook local_url = None state = shared.state backend = shared.backend -if not modules.loader.initialized: - timer.startup.record("libraries") -if cmd_opts.server_name: - server_name = cmd_opts.server_name +if shared.cmd_opts.server_name: + server_name = shared.cmd_opts.server_name else: - server_name = "0.0.0.0" if cmd_opts.listen else None + server_name = "0.0.0.0" if shared.cmd_opts.listen else None fastapi_args = { "version": f'0.0.{git_commit}', "title": "SD.Next", @@ -60,30 +61,12 @@ # "redoc_url": "/redocs" if cmd_opts.docs else None, } -import modules.sd_hijack -timer.startup.record("ldm") -modules.loader.initialized = True - - -def check_rollback_vae(): - if shared.cmd_opts.rollback_vae: - if not torch.cuda.is_available(): - log.error("Rollback VAE functionality requires compatible GPU") - shared.cmd_opts.rollback_vae = False - elif torch.__version__.startswith('1.') or torch.__version__.startswith('2.0'): - log.error("Rollback VAE functionality requires Torch 2.1 or higher") - shared.cmd_opts.rollback_vae = False - elif 0 < torch.cuda.get_device_capability()[0] < 8: - log.error('Rollback 
VAE functionality device capabilities not met') - shared.cmd_opts.rollback_vae = False - def initialize(): log.debug('Initializing') modules.sd_checkpoint.init_metadata() modules.hashes.init_cache() - check_rollback_vae() log.debug(f'Huggingface cache: path="{shared.opts.hfcache_dir}"') @@ -136,20 +119,20 @@ def initialize(): shared.reload_hypernetworks() timer.startup.record("hypernetworks") - ui_extra_networks.initialize() - ui_extra_networks.register_pages() - extra_networks.initialize() - extra_networks.register_default_extra_networks() + modules.ui_extra_networks.initialize() + modules.ui_extra_networks.register_pages() + modules.extra_networks.initialize() + modules.extra_networks.register_default_extra_networks() timer.startup.record("networks") - if cmd_opts.tls_keyfile is not None and cmd_opts.tls_certfile is not None: + if shared.cmd_opts.tls_keyfile is not None and shared.cmd_opts.tls_certfile is not None: try: - if not os.path.exists(cmd_opts.tls_keyfile): + if not os.path.exists(shared.cmd_opts.tls_keyfile): log.error("Invalid path to TLS keyfile given") - if not os.path.exists(cmd_opts.tls_certfile): - log.error(f"Invalid path to TLS certfile: '{cmd_opts.tls_certfile}'") + if not os.path.exists(shared.cmd_opts.tls_certfile): + log.error(f"Invalid path to TLS certfile: '{shared.cmd_opts.tls_certfile}'") except TypeError: - cmd_opts.tls_keyfile = cmd_opts.tls_certfile = None + shared.cmd_opts.tls_keyfile = shared.cmd_opts.tls_certfile = None log.error("TLS setup invalid, running webui without TLS") else: log.info("Running with TLS") @@ -231,7 +214,7 @@ def start_common(): log.info(f'Using data path: {shared.cmd_opts.data_dir}') if shared.cmd_opts.models_dir is not None and len(shared.cmd_opts.models_dir) > 0 and shared.cmd_opts.models_dir != 'models': log.info(f'Models path: {shared.cmd_opts.models_dir}') - create_paths(shared.opts) + paths.create_paths(shared.opts) async_policy() initialize() try: @@ -251,20 +234,20 @@ def start_ui(): timer.startup.record("before-ui") shared.demo = modules.ui.create_ui(timer.startup) timer.startup.record("ui") - if cmd_opts.disable_queue: + if shared.cmd_opts.disable_queue: log.info('Server queues disabled') shared.demo.progress_tracking = False else: shared.demo.queue(concurrency_count=64) gradio_auth_creds = [] - if cmd_opts.auth: - gradio_auth_creds += [x.strip() for x in cmd_opts.auth.strip('"').replace('\n', '').split(',') if x.strip()] - if cmd_opts.auth_file: - if not os.path.exists(cmd_opts.auth_file): - log.error(f"Invalid path to auth file: '{cmd_opts.auth_file}'") + if shared.cmd_opts.auth: + gradio_auth_creds += [x.strip() for x in shared.cmd_opts.auth.strip('"').replace('\n', '').split(',') if x.strip()] + if shared.cmd_opts.auth_file: + if not os.path.exists(shared.cmd_opts.auth_file): + log.error(f"Invalid path to auth file: '{shared.cmd_opts.auth_file}'") else: - with open(cmd_opts.auth_file, 'r', encoding="utf8") as file: + with open(shared.cmd_opts.auth_file, 'r', encoding="utf8") as file: for line in file.readlines(): gradio_auth_creds += [x.strip() for x in line.split(',') if x.strip()] if len(gradio_auth_creds) > 0: @@ -273,19 +256,19 @@ def start_ui(): global local_url # pylint: disable=global-statement stdout = io.StringIO() allowed_paths = [os.path.dirname(__file__)] - if cmd_opts.data_dir is not None and os.path.isdir(cmd_opts.data_dir): - allowed_paths.append(cmd_opts.data_dir) - if cmd_opts.allowed_paths is not None: - allowed_paths += [p for p in cmd_opts.allowed_paths if os.path.isdir(p)] + if 
shared.cmd_opts.data_dir is not None and os.path.isdir(shared.cmd_opts.data_dir): + allowed_paths.append(shared.cmd_opts.data_dir) + if shared.cmd_opts.allowed_paths is not None: + allowed_paths += [p for p in shared.cmd_opts.allowed_paths if os.path.isdir(p)] shared.log.debug(f'Root paths: {allowed_paths}') with contextlib.redirect_stdout(stdout): app, local_url, share_url = shared.demo.launch( # app is FastAPI(Starlette) instance - share=cmd_opts.share, + share=shared.cmd_opts.share, server_name=server_name, - server_port=cmd_opts.port if cmd_opts.port != 7860 else None, - ssl_keyfile=cmd_opts.tls_keyfile, - ssl_certfile=cmd_opts.tls_certfile, - ssl_verify=not cmd_opts.tls_selfsign, + server_port=shared.cmd_opts.port if shared.cmd_opts.port != 7860 else None, + ssl_keyfile=shared.cmd_opts.tls_keyfile, + ssl_certfile=shared.cmd_opts.tls_certfile, + ssl_verify=not shared.cmd_opts.tls_selfsign, debug=False, auth=[tuple(cred.split(':')) for cred in gradio_auth_creds] if gradio_auth_creds else None, prevent_thread_lock=True, @@ -295,24 +278,24 @@ def start_ui(): favicon_path='html/favicon.svg', allowed_paths=allowed_paths, app_kwargs=fastapi_args, - _frontend=True and cmd_opts.share, + _frontend=True and shared.cmd_opts.share, ) - if cmd_opts.data_dir is not None: - gr_tempdir.register_tmp_file(shared.demo, os.path.join(cmd_opts.data_dir, 'x')) + if shared.cmd_opts.data_dir is not None: + gr_tempdir.register_tmp_file(shared.demo, os.path.join(shared.cmd_opts.data_dir, 'x')) shared.log.info(f'Local URL: {local_url}') - if cmd_opts.docs: + if shared.cmd_opts.docs: shared.log.info(f'API Docs: {local_url[:-1]}/docs') # pylint: disable=unsubscriptable-object shared.log.info(f'API ReDocs: {local_url[:-1]}/redocs') # pylint: disable=unsubscriptable-object if share_url is not None: shared.log.info(f'Share URL: {share_url}') # shared.log.debug(f'Gradio functions: registered={len(shared.demo.fns)}') shared.demo.server.wants_restart = False - setup_middleware(app, cmd_opts) + modules.api.middleware.setup_middleware(app, shared.cmd_opts) - if cmd_opts.subpath: + if shared.cmd_opts.subpath: import gradio - gradio.mount_gradio_app(app, shared.demo, path=f"/{cmd_opts.subpath}") - shared.log.info(f'Redirector mounted: /{cmd_opts.subpath}') + gradio.mount_gradio_app(app, shared.demo, path=f"/{shared.cmd_opts.subpath}") + shared.log.info(f'Redirector mounted: /{shared.cmd_opts.subpath}') timer.startup.record("launch") @@ -320,7 +303,7 @@ def start_ui(): shared.api = create_api(app) timer.startup.record("api") - ui_extra_networks.init_api(app) + modules.ui_extra_networks.init_api(app) modules.script_callbacks.app_started_callback(shared.demo, app) timer.startup.record("app-started") @@ -345,7 +328,7 @@ def webui(restart=False): modules.sd_models.write_metadata() load_model() shared.opts.save(shared.config_filename) - if cmd_opts.profile: + if shared.cmd_opts.profile: for k, v in modules.script_callbacks.callback_map.items(): shared.log.debug(f'Registered callbacks: {k}={len(v)} {[c.script for c in v]}') debug = log.trace if os.environ.get('SD_SCRIPT_DEBUG', None) is not None else lambda *args, **kwargs: None @@ -357,7 +340,7 @@ def webui(restart=False): debug(f' {m}') modules.script_callbacks.print_timers() - if cmd_opts.profile: + if shared.cmd_opts.profile: log.info(f"Launch time: {timer.launch.summary(min_time=0)}") log.info(f"Installer time: {timer.init.summary(min_time=0)}") log.info(f"Startup time: {timer.startup.summary(min_time=0)}") @@ -374,8 +357,8 @@ def webui(restart=False): continue logger.handlers 
= log.handlers # autolaunch only on initial start - if (shared.opts.autolaunch or cmd_opts.autolaunch) and local_url is not None: - cmd_opts.autolaunch = False + if (shared.opts.autolaunch or shared.cmd_opts.autolaunch) and local_url is not None: + shared.cmd_opts.autolaunch = False shared.log.info('Launching browser') import webbrowser webbrowser.open(local_url, new=2, autoraise=True) @@ -390,7 +373,7 @@ def api_only(): start_common() from fastapi import FastAPI app = FastAPI(**fastapi_args) - setup_middleware(app, cmd_opts) + modules.api.middleware.setup_middleware(app, shared.cmd_opts) shared.api = create_api(app) shared.api.wants_restart = False modules.script_callbacks.app_started_callback(None, app) @@ -401,7 +384,7 @@ def api_only(): if __name__ == "__main__": - if cmd_opts.api_only: + if shared.cmd_opts.api_only: api_only() else: webui() diff --git a/wiki b/wiki index 7c2400e9d..29e37ad76 160000 --- a/wiki +++ b/wiki @@ -1 +1 @@ -Subproject commit 7c2400e9dc5dee3c52eac6bbfa88352f7815454a +Subproject commit 29e37ad766904bc04f9e9701c2503a3f0898964a
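
Illustrative usage (not part of the patch itself): a minimal sketch of how the helpers relocated into modules/sd_models_utils.py by this refactor might be exercised. Both paths below are placeholders, and a standalone invocation assumes an initialized SD.Next environment (shared.opts, shared.log and shared.console already configured); in normal operation these helpers are reached through the model-load entry points rather than called directly.

    # hypothetical usage sketch; both paths below are placeholders
    from modules import sd_models_utils

    # path_to_repo() turns a huggingface cache-style path ('models--org--name/...')
    # into the canonical 'org/name' repo id
    repo_id = sd_models_utils.path_to_repo('models--stabilityai--stable-diffusion-xl-base-1.0/snapshots/abc123')

    # read_state_dict() loads a legacy checkpoint into a plain state dict,
    # respecting shared.opts.stream_load (and shared.opts.sd_disable_ckpt for .ckpt files)
    # and normalizing keys via get_state_dict_from_checkpoint()
    state_dict = sd_models_utils.read_state_dict('models/Stable-diffusion/example.safetensors', what='model')
    if state_dict is not None:
        print(f'repo={repo_id} tensors={len(state_dict)}')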