Skip to content

Commit

Permalink
fix cpu offload setup
Browse files Browse the repository at this point in the history
  • Loading branch information
noskill committed Oct 16, 2024
1 parent 8b44048 commit 4892cd6
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 45 deletions.
2 changes: 1 addition & 1 deletion multigen/log.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def setup_logger(path='log_file.log'):

ch.addFilter(thread_id_filter)
fh.addFilter(thread_id_filter)
formatter = logging.Formatter('%(asctime)s - %(thread)d - %(levelname)s - %(message)s')
formatter = logging.Formatter('%(asctime)s - %(thread)d - %(levelname)s - %(funcName)20s() - %(message)s')
fh.setFormatter(formatter)
ch.setFormatter(formatter)

Expand Down
12 changes: 6 additions & 6 deletions multigen/pipes.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,6 @@ def __init__(self, model_id: str,
"""
if device is None:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
pipe_passed = pipe is not None
self.pipe = pipe
self._scheduler = None
self._hypernets = []
Expand All @@ -126,8 +125,7 @@ def __init__(self, model_id: str,
if mt != model_type:
raise RuntimeError(f"passed model type {self.model_type} doesn't match actual type {mt}")

if not pipe_passed:
self._initialize_pipe(device, offload_device)
self._initialize_pipe(device, offload_device)
self.lpw = lpw
self._loras = []

Expand Down Expand Up @@ -164,10 +162,11 @@ def _initialize_pipe(self, device, offload_device):
self.pipe.vae.enable_tiling()
# --- the best one and seems to be enough ---
# self.pipe.enable_sequential_cpu_offload()
if offload_device is not None:
self.pipe.enable_sequential_cpu_offload(offload_device)
logging.debug(f'enable_sequential_cpu_offload for pipe dtype {self.pipe.dtype}')
if self.model_type == ModelType.FLUX:
if offload_device is not None:
self.pipe.enable_sequential_cpu_offload(offload_device)
logging.debug(f'enable_sequential_cpu_offload for pipe dtype {self.pipe.dtype}')
pass
else:
try:
import xformers
Expand Down Expand Up @@ -409,6 +408,7 @@ def gen(self, inputs: dict):
generated image
"""
kwargs = self.prepare_inputs(inputs)
logging.debug("Prompt2ImPipe.gen calling pipe")
image = self.pipe(**kwargs).images[0]
return image

Expand Down
27 changes: 19 additions & 8 deletions multigen/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import time
import concurrent
from queue import Empty
import PIL

from .worker_base import ServiceThreadBase
from .prompting import Cfgen
Expand Down Expand Up @@ -62,7 +63,7 @@ def _get_pipeline(self, pipe_class, model_id, model_type, cnet=None):
cls = pipe_class._classflux
if device.type == 'cuda':
offload_device = device.index
device = torch.device('cpu', 0)
device = torch.device('cpu')
else:
cls = pipe_class._class
pipeline = self._loader.load_pipeline(cls, model_id, torch_dtype=torch.bfloat16,
Expand Down Expand Up @@ -110,8 +111,10 @@ def _update(sess, job, gs):
device = None
# keep the job in the queue until complete
try:
sess = data.get('session', None)
session_id = data["session_id"]
sess = self.sessions[session_id]
if sess is None:
sess = self.sessions[session_id]
sess['status'] ='running'
self.logger.info("GENERATING: " + str(data))
if 'start_callback' in data:
Expand All @@ -132,7 +135,9 @@ def _update(sess, job, gs):
raise RuntimeError(f"unexpected model type {mt}")
pipe = self.get_pipeline(pipe_name, model_id, model_type, cnet=data.get('cnet', None))
device = pipe.pipe.device
offload_device = pipe.offload_gpu_id
offload_device = None
if hasattr(pipe, 'offload_gpu_id'):
offload_device = pipe.offload_gpu_id
self.logger.debug(f'running job on {device} offload {offload_device}')
if device.type in ['cuda', 'meta']:
with self._lock:
Expand All @@ -143,20 +148,26 @@ def _update(sess, job, gs):
class_name = str(pipe.__class__)
self.logger.debug(f'got pipeline {class_name}')

images = data['images']
if 'MaskedIm2ImPipe' in class_name:
images = data.get('images', None)
if images and 'MaskedIm2ImPipe' in class_name:
pipe.setup(**data, original_image=str(images[0]),
image_painted=str(images[1]))
elif any([x in class_name for x in ('Im2ImPipe', 'Cond2ImPipe')]):
pipe.setup(**data, fimage=str(images[0]))
elif images and any([x in class_name for x in ('Im2ImPipe', 'Cond2ImPipe')]):
if isinstance(images[0], PIL.Image.Image):
pipe.setup(**data, fimage=None, image=images[0])
else:
pipe.setup(**data, fimage=str(images[0]))
else:
pipe.setup(**data)
# TODO: add negative prompt to parameters
nprompt_default = "jpeg artifacts, blur, distortion, watermark, signature, extra fingers, fewer fingers, lowres, nude, bad hands, duplicate heads, bad anatomy, bad crop"
nprompt = data.get('nprompt', nprompt_default)
seeds = data.get('seeds', None)
self.logger.debug(f"offload_device {pipe.offload_gpu_id}")
gs = GenSession(self.get_image_pathname(data["session_id"], None),
directory = data.get('gen_dir', None)
if directory is None:
directory = self.get_image_pathname(data["session_id"], None)
gs = GenSession(directory,
pipe, Cfgen(data["prompt"], nprompt, seeds=seeds))
gs.gen_sess(add_count = data["count"],
callback = lambda: _update(sess, data, gs))
Expand Down
7 changes: 1 addition & 6 deletions multigen/worker_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import yaml
from pathlib import Path
import logging
from .pipes import Prompt2ImPipe, MaskedIm2ImPipe, Cond2ImPipe, Im2ImPipe
from .pipes import Prompt2ImPipe, MaskedIm2ImPipe, Cond2ImPipe, Im2ImPipe, InpaintingPipe
from .loader import Loader


Expand Down Expand Up @@ -95,11 +95,6 @@ def close_session(self, session_id):
def queue_gen(self, **args):
self.logger.info("REQUESTED FOR QUEUE: " + str(args))
with self._lock:
if args["session_id"] not in self.sessions:
return { "error": "Session is not open" }
# for q in self.queue:
# if q["session_id"] == args["session_id"]:
# return { "error": "The job for this session already exists" }
a = {**args}
a["count"] = int(a["count"])
if a["count"] <= 0:
Expand Down
24 changes: 0 additions & 24 deletions tests/pipe_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,30 +323,6 @@ def get_model(self):
def test_lpw_turned_off(self):
pass

def est_basic_txt2im(self):
model = self.get_model()
device = torch.device('cpu', 0)
# create pipe
offload = 0 if torch.cuda.is_available() else None
pipe = Prompt2ImPipe(model, pipe=self._pipeline,
model_type=self.model_type(),
device=device, offload_device=offload)
pipe.setup(width=512, height=512, guidance_scale=7, scheduler="FlowMatchEulerDiscreteScheduler", steps=5)
seed = 49045438434843
params = dict(prompt="a cube planet, cube-shaped, space photo, masterpiece",
negative_prompt="spherical",
generator=torch.Generator(device).manual_seed(seed))
image = pipe.gen(params)
image.save("cube_test.png")

# generate with different seed
params['generator'] = torch.Generator(device).manual_seed(seed + 1)
image_ddim = pipe.gen(params)
image_ddim.save("cube_test2_dimm.png")
diff = self.compute_diff(image_ddim, image)
# check that difference is large
self.assertGreater(diff, 1000)


if __name__ == '__main__':
setup_logger('test_pipe.log')
Expand Down

0 comments on commit 4892cd6

Please sign in to comment.