Skip to content

Commit

Permalink
Merge pull request #84 from noskill/cnetref
Browse files Browse the repository at this point in the history
adapt existing tests for flux models
  • Loading branch information
noskill authored Oct 4, 2024
2 parents 6d4b42b + a02b8db commit 8b44048
Show file tree
Hide file tree
Showing 6 changed files with 136 additions and 63 deletions.
2 changes: 2 additions & 0 deletions multigen/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ def weightshare_copy(pipe):
obj = getattr(copy, key)
if hasattr(obj, 'load_state_dict'):
obj.load_state_dict(getattr(pipe, key).state_dict(), assign=True)
# some buffers might not be transferred from pipe to copy
copy.to(pipe.device)
return copy


Expand Down
42 changes: 33 additions & 9 deletions multigen/pipes.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ def __init__(self, model_id: str,
"""
if device is None:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
pipe_passed = pipe is not None
self.pipe = pipe
self._scheduler = None
self._hypernets = []
Expand All @@ -125,7 +126,8 @@ def __init__(self, model_id: str,
if mt != model_type:
raise RuntimeError(f"passed model type {self.model_type} doesn't match actual type {mt}")

self._initialize_pipe(device, offload_device)
if not pipe_passed:
self._initialize_pipe(device, offload_device)
self.lpw = lpw
self._loras = []

Expand Down Expand Up @@ -155,6 +157,7 @@ def _get_model_type(self):
def _initialize_pipe(self, device, offload_device):
# sometimes text encoder is on a different device
# if self.pipe.device != device:
logging.debug(f"initialising pipe to device {device}: offload_device {offload_device}")
self.pipe.to(device)
# self.pipe.enable_attention_slicing()
# self.pipe.enable_vae_slicing()
Expand All @@ -164,6 +167,7 @@ def _initialize_pipe(self, device, offload_device):
if self.model_type == ModelType.FLUX:
if offload_device is not None:
self.pipe.enable_sequential_cpu_offload(offload_device)
logging.debug(f'enable_sequential_cpu_offload for pipe dtype {self.pipe.dtype}')
else:
try:
import xformers
Expand All @@ -172,19 +176,21 @@ def _initialize_pipe(self, device, offload_device):
logging.warning("xformers not found, can't use efficient attention")

def _load_pipeline(self, sd_pipe_class, model_type, args):
    """Instantiate the underlying diffusers pipeline for ``self._model_id``.

    Args:
        sd_pipe_class: explicit pipeline class to load with, or None to let
            this method pick one based on the model id and ``model_type``.
        model_type: ModelType enum value; required only for ``.safetensors``
            paths when ``sd_pipe_class`` is None.
        args: keyword arguments forwarded to the ``from_single_file`` /
            ``from_pretrained`` loader call.

    Returns:
        The constructed pipeline instance.

    Raises:
        RuntimeError: if a ``.safetensors`` file is given without an explicit
            ``model_type`` and no ``sd_pipe_class`` was supplied.
    """
    logging.debug(f"loading pipeline from {self._model_id} with {args}")
    if sd_pipe_class is None:
        if self._model_id.endswith('.safetensors'):
            # A bare checkpoint file carries no pipeline config, so the class
            # cannot be auto-detected — the caller must state the model type.
            if model_type is None:
                raise RuntimeError(f"model_type is not specified for safetensors file {self._model_id}")
            # NOTE(review): every non-SD type falls through to the XL class
            # here — presumably FLUX single-file loading is unsupported; confirm.
            pipe_class = self._class if model_type == ModelType.SD else self._classxl
            result = pipe_class.from_single_file(self._model_id, **args)
        else:
            # Pretrained repos include a config, so the auto-pipeline can
            # resolve the concrete pipeline class itself.
            result = self._autopipeline.from_pretrained(self._model_id, **args)
    else:
        # Explicit class requested: choose the loader by the path form only.
        if self._model_id.endswith('.safetensors'):
            result = sd_pipe_class.from_single_file(self._model_id, **args)
        else:
            result = sd_pipe_class.from_pretrained(self._model_id, **args)
    return result

@property
def scheduler(self):
Expand Down Expand Up @@ -724,7 +730,7 @@ def __init__(self, model_id, pipe: Optional[StableDiffusionControlNetPipeline] =
if model_id.endswith('.safetensors'):
if self.model_type is None:
raise RuntimeError(f"model type is not specified for safetensors file {model_id}")
cnets = self._load_cnets(cnets, cnet_ids, args.get('offload_device', None))
cnets = self._load_cnets(cnets, cnet_ids, args.get('offload_device', None), args.get('torch_dtype', None))
super().__init__(model_id=model_id, pipe=pipe, controlnet=cnets, model_type=model_type, **args)
else:
super().__init__(model_id=model_id, pipe=pipe, controlnet=cnets, model_type=model_type, **args)
Expand All @@ -738,22 +744,26 @@ def __init__(self, model_id, pipe: Optional[StableDiffusionControlNetPipeline] =
else:
raise RuntimeError(f"Unexpected model type {type(self.pipe)}")
self.model_type = t_model_type
cnets = self._load_cnets(cnets, cnet_ids, args.get('offload_device', None))
logging.debug(f"from_pipe source dtype {self.pipe.dtype}")
cnets = self._load_cnets(cnets, cnet_ids, args.get('offload_device', None), self.pipe.dtype)
prev_dtype = self.pipe.dtype
if self.model_type == ModelType.SDXL:
self.pipe = self._classxl.from_pipe(self.pipe, controlnet=cnets)
elif self.model_type == ModelType.FLUX:
self.pipe = self._classflux.from_pipe(self.pipe, controlnet=cnets[0])
else:
self.pipe = self._class.from_pipe(self.pipe, controlnet=cnets)
logging.debug(f"after from_pipe result dtype {self.pipe.dtype}")
for cnet in cnets:
cnet.to(self.pipe.dtype)
cnet.to(prev_dtype)
logging.debug(f'moving cnet {id(cnet)} to self.pipe.dtype {prev_dtype}')
if 'offload_device' not in args:
cnet.to(self.pipe.device)
else:
# don't load anything, just reuse pipe
super().__init__(model_id=model_id, pipe=pipe, **args)

def _load_cnets(self, cnets, cnet_ids, offload_device=None):
def _load_cnets(self, cnets, cnet_ids, offload_device=None, dtype=None):
if self.model_type == ModelType.FLUX:
ControlNet = FluxControlNetModel
else:
Expand All @@ -773,9 +783,18 @@ def _load_cnets(self, cnets, cnet_ids, offload_device=None):
else:
cnets.append(ControlNet.from_pretrained(c, torch_dtype=torch_dtype))
if offload_device is not None:
# controlnet should be on the same device as the main model
dev = torch.device('cuda', offload_device)
logging.debug(f'moving cnets to offload device {dev}')
for cnet in cnets:
cnet.to(dev)
else:
logging.debug('offload device is None')
for cnet in cnets:
logging.debug(f"cnet dtype {cnet.dtype}")
if dtype is not None:
logging.debug(f"changing to {dtype}")
cnet.to(dtype)
return cnets

def get_cmodels(self):
Expand Down Expand Up @@ -832,6 +851,8 @@ def setup(self, fimage, width=None, height=None,
self._input_image = [image]
if cscales is None:
cscales = [self.get_default_cond_scales()[c] for c in self.ctypes]
if self.model_type == ModelType.FLUX and hasattr(cscales, '__len__'):
cscales = cscales[0] # multiple controlnets are not yet supported
self.pipe_params.update({
"width": image.size[0] if width is None else width,
"height": image.size[1] if height is None else height,
Expand Down Expand Up @@ -905,6 +926,9 @@ def __init__(self, model_id, pipe: Optional[StableDiffusionControlNetPipeline] =
Additional arguments passed to the Cond2ImPipe constructor.
"""
super().__init__(model_id=model_id, pipe=pipe, ctypes=ctypes, model_type=model_type, **args)
logging.debug("CIm2Im backend pipe was constructed")
logging.debug(f"self.pipe.dtype = {self.pipe.dtype}")
logging.debug(f"self.pipe.controlnet.dtype = {self.pipe.controlnet.dtype}")
self.processor = None
self.body_estimation = None
self.draw_bodypose = None
Expand Down
8 changes: 7 additions & 1 deletion multigen/sessions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import json
from . import util
from .prompting import Cfgen
import logging


class GenSession:
Expand Down Expand Up @@ -84,12 +85,14 @@ def gen_sess(self, add_count = 0, save_img=True,
# collecting images to return if requested or images are not saved
if not save_img or force_collect:
images = []
logging.info(f"add count = {add_count}")
jk = 0
for inputs in self.confg:
self.last_index = self.confg.count - 1
self.last_conf = {**inputs}
# TODO: multiple inputs?
inputs['generator'] = torch.Generator().manual_seed(inputs['generator'])

logging.debug("start generation")
image = self.pipe.gen(inputs)
if save_img:
self.last_img_name = self.get_last_file_prefix() + ".png"
Expand All @@ -103,5 +106,8 @@ def gen_sess(self, add_count = 0, save_img=True,
if save_img and not drop_cfg:
self.save_last_conf()
if callback is not None:
logging.debug("call callback after generation")
callback()
jk += 1
logging.debug(f"done iteration {jk}")
return images
9 changes: 5 additions & 4 deletions multigen/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,16 +58,16 @@ def _get_pipeline(self, pipe_class, model_id, model_type, cnet=None):
if model_type == ModelType.SDXL:
cls = pipe_class._classxl
elif model_type == ModelType.FLUX:
cls = pipe_class._flux
# use offload by default for now
cls = pipe_class._classflux
if device.type == 'cuda':
offload_device = device.index
device = torch.device('cpu')
device = torch.device('cpu', 0)
else:
cls = pipe_class._class
pipeline = self._loader.load_pipeline(cls, model_id, torch_dtype=torch.bfloat16,
device=device)
self.logger.debug(f'requested {cls} {model_id} on device {device}, got {pipeline.device}')
assert pipeline.device == device
pipe = pipe_class(model_id, pipe=pipeline, device=device, offload_device=offload_device)
if offload_device is None:
assert pipeline.device == device
Expand Down Expand Up @@ -164,7 +164,8 @@ def _update(sess, job, gs):
data['finish_callback']()
except (RuntimeError, TypeError, NotImplementedError) as e:
self.logger.error("error in generation", exc_info=e)
self.logger.error(f"offload_device {pipe.pipe._offload_gpu_id}")
if hasattr(pipe.pipe, '_offload_gpu_id'):
self.logger.error(f"offload_device {pipe.pipe._offload_gpu_id}")
if 'finish_callback' in data:
data['finish_callback']("Can't generate image due to error")
except Exception as e:
Expand Down
Loading

0 comments on commit 8b44048

Please sign in to comment.