Skip to content

Commit

Permalink
Merge pull request #48 from noskill/imim2
Browse files Browse the repository at this point in the history
support im2im pipeline in worker
  • Loading branch information
Necr0x0Der authored Mar 31, 2024
2 parents 986f2b2 + 583a5e6 commit 549617b
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 4 deletions.
5 changes: 3 additions & 2 deletions multigen/loader.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Type
import torch
import logging
from diffusers import DiffusionPipeline, StableDiffusionControlNetPipeline, StableDiffusionXLControlNetPipeline

Expand All @@ -12,7 +13,7 @@ class for loading diffusion pipelines from files.
def __init__(self):
self._pipes = dict()

def load_pipeline(self, cls: Type[DiffusionPipeline], path, **additional_args):
def load_pipeline(self, cls: Type[DiffusionPipeline], path, torch_dtype=torch.float16, **additional_args):
for key, pipe in self._pipes.items():
if key == path:
components = pipe.components
Expand All @@ -25,7 +26,7 @@ def load_pipeline(self, cls: Type[DiffusionPipeline], path, **additional_args):
# but we don't need it
if 'controlnet' in components:
components.pop('controlnet')
return cls(**components, **additional_args)
return cls(**components, **additional_args).to(torch_dtype)


if path.endswith('safetensors'):
Expand Down
2 changes: 1 addition & 1 deletion multigen/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def _update(sess, job, gs):
if 'MaskedIm2ImPipe' in class_name:
pipe.setup(**data, original_image=str(images[0]),
image_painted=str(images[1]))
elif 'Cond2ImPipe' in class_name:
elif any([x in class_name for x in ('Im2ImPipe', 'Cond2ImPipe')]):
pipe.setup(**data, fimage=str(images[0]))
else:
pipe.setup(**data)
Expand Down
2 changes: 1 addition & 1 deletion multigen/worker_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import yaml
from pathlib import Path
import logging
from .pipes import Prompt2ImPipe, MaskedIm2ImPipe, Cond2ImPipe
from .pipes import Prompt2ImPipe, MaskedIm2ImPipe, Cond2ImPipe, Im2ImPipe
from .loader import Loader


Expand Down

0 comments on commit 549617b

Please sign in to comment.