Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

set height and width for img2img pipeline #92

Merged
merged 5 commits into from
Jan 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion multigen/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .pipes import Prompt2ImPipe, Im2ImPipe, MaskedIm2ImPipe, CIm2ImPipe
from .pipes import Prompt2ImPipe, Im2ImPipe, MaskedIm2ImPipe, CIm2ImPipe, Cond2ImPipe
from .sessions import GenSession
from .prompting import Cfgen
from .loader import Loader
Expand Down
10 changes: 8 additions & 2 deletions multigen/pipes.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,7 +426,7 @@ def __init__(self, model_id, pipe: Optional[StableDiffusionImg2ImgPipeline] = No
self._input_image = None

def setup(self, fimage, image=None, strength=0.75,
guidance_scale=7.5, scale=None, timestep_spacing='linspace', **args):
guidance_scale=7.5, scale=None, timestep_spacing='linspace', width=None, height=None, **args):
"""
Setup pipeline for generation.

Expand All @@ -445,8 +445,11 @@ def setup(self, fimage, image=None, strength=0.75,
self._input_image = Image.open(fimage).convert("RGB") if image is None else image
self._input_image = self.scale_image(self._input_image, scale)
self._original_size = self._input_image.size
logging.debug("origin image size {self._original_size}")
self._input_image = util.pad_image_to_multiple_of_8(self._input_image)
self.pipe_params.update({
"width": self._input_image.width if width is None else width,
"height": self._input_image.height if height is None else height,
"strength": strength,
"guidance_scale": guidance_scale
})
Expand Down Expand Up @@ -499,6 +502,7 @@ def gen(self, inputs: dict):
kwargs.update({"image": self._input_image})
self.try_set_scheduler(kwargs)
image = self.pipe(**kwargs).images[0]
logging.debug(f'generated image {image}')
result = image.crop((0, 0, self._original_size[0], self._original_size[1]))
return result

Expand Down Expand Up @@ -858,6 +862,7 @@ def setup(self, fimage, width=None, height=None,
self.fname = fimage
image = Image.open(fimage).convert("RGB") if image is None else image
self._original_size = image.size
self._use_input_size = width is None or height is None
image = util.pad_image_to_multiple_of_8(image)
self._condition_image = [image]
self._input_image = [image]
Expand Down Expand Up @@ -910,7 +915,8 @@ def gen(self, inputs):
inputs.update({"image": self._input_image,
"control_image": self._condition_image})
image = self.pipe(**inputs).images[0]
result = image.crop((0, 0, self._original_size[0], self._original_size[1]))
result = image.crop((0, 0, self._original_size[0] if self._use_input_size else inputs.get('height'),
self._original_size[1] if self._use_input_size else inputs.get('width') ))
return result


Expand Down
21 changes: 20 additions & 1 deletion tests/pipe_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import numpy

from PIL import Image
from multigen import Prompt2ImPipe, Im2ImPipe, Cfgen, GenSession, Loader, MaskedIm2ImPipe, CIm2ImPipe
from multigen import Prompt2ImPipe, Im2ImPipe, Cond2ImPipe, Cfgen, GenSession, Loader, MaskedIm2ImPipe, CIm2ImPipe
from multigen.log import setup_logger
from multigen.pipes import ModelType
from dummy import DummyDiffusionPipeline
Expand Down Expand Up @@ -210,6 +210,21 @@ def test_controlnet(self):
diff = self.compute_diff(image_ddim, image)
# check that difference is large
self.assertGreater(diff, 1000)

def test_cond2im(self):
    """
    Check the output size of Cond2ImPipe with and without explicit dimensions.

    The pose-conditioned pipeline is first set up with width=768 and
    height=768 and must return a 768x768 image. It is then set up again
    without dimensions and must return a 450x450 image (presumably the
    native size of pose6.jpeg -- confirm against the test asset).
    """
    pose_file = "./pose6.jpeg"
    pipe = Cond2ImPipe(self.get_model(), ctypes=["pose"],
                       model_type=self.model_type())

    # Run 1: explicit target size.
    pipe.setup(pose_file, width=768, height=768)
    # The same params dict (and the same generator object) is reused
    # for both generation calls below.
    gen_params = {
        "prompt": "child in the coat playing in sandbox",
        "negative_prompt": "spherical",
        "generator": torch.Generator().manual_seed(49045438434843),
    }
    explicit_size_image = pipe.gen(gen_params)
    self.assertEqual(explicit_size_image.size, (768, 768))

    # Run 2: no width/height given to setup.
    pipe.setup(pose_file)
    input_size_image = pipe.gen(gen_params)
    self.assertEqual(input_size_image.size, (450, 450))


class TestSDXL(MyTestCase):
Expand Down Expand Up @@ -245,6 +260,10 @@ def get_model(self):
def test_lpw_turned_off(self):
# Body is intentionally a no-op.
# NOTE(review): presumably this overrides a same-named test inherited from a
# base test class so it does not run for this model type -- confirm.
pass

# Marked as skipped with the reason 'not implemented yet'; the body is a no-op.
# NOTE(review): presumably this overrides the inherited test_cond2im for a
# model type that does not support Cond2ImPipe yet -- confirm.
@unittest.skip('not implemented yet')
def test_cond2im(self):
pass


if __name__ == '__main__':
setup_logger('test_pipe.log')
Expand Down
1 change: 1 addition & 0 deletions tests/pose6.jpeg