Export Cond2ImPipe from the package, let the img2img setup() take optional
width/height, crop the conditioned gen() result to the requested size when
both width and height were passed to setup(), and add test_cond2im
(skipped in the SDXL test class).

Review fixes applied to the added lines of this patch:
  * the "origin image size" debug message is now an f-string, so the size is
    interpolated instead of the placeholder being logged literally;
  * the crop box in the gen() hunk at line 910 now uses width as the right
    edge and height as the lower edge (PIL boxes are
    (left, upper, right, lower)); the two were swapped.
NOTE(review): the blob hashes on the pipes.py "index" line predate these
two fixes; regenerate the patch if exact hashes matter.

diff --git a/multigen/__init__.py b/multigen/__init__.py
index 6c397eb..4b93c5b 100755
--- a/multigen/__init__.py
+++ b/multigen/__init__.py
@@ -1,4 +1,4 @@
-from .pipes import Prompt2ImPipe, Im2ImPipe, MaskedIm2ImPipe, CIm2ImPipe
+from .pipes import Prompt2ImPipe, Im2ImPipe, MaskedIm2ImPipe, CIm2ImPipe, Cond2ImPipe
 from .sessions import GenSession
 from .prompting import Cfgen
 from .loader import Loader
diff --git a/multigen/pipes.py b/multigen/pipes.py
index 396e561..a8a0145 100755
--- a/multigen/pipes.py
+++ b/multigen/pipes.py
@@ -426,7 +426,7 @@ def __init__(self, model_id, pipe: Optional[StableDiffusionImg2ImgPipeline] = No
         self._input_image = None
 
     def setup(self, fimage, image=None, strength=0.75,
-              guidance_scale=7.5, scale=None, timestep_spacing='linspace', **args):
+              guidance_scale=7.5, scale=None, timestep_spacing='linspace', width=None, height=None, **args):
         """
         Setup pipeline for generation.
 
@@ -445,8 +445,11 @@ def setup(self, fimage, image=None, strength=0.75,
         self._input_image = Image.open(fimage).convert("RGB") if image is None else image
         self._input_image = self.scale_image(self._input_image, scale)
         self._original_size = self._input_image.size
+        logging.debug(f"origin image size {self._original_size}")
         self._input_image = util.pad_image_to_multiple_of_8(self._input_image)
         self.pipe_params.update({
+            "width": self._input_image.width if width is None else width,
+            "height": self._input_image.height if height is None else height,
             "strength": strength,
             "guidance_scale": guidance_scale
         })
@@ -499,6 +502,7 @@ def gen(self, inputs: dict):
         kwargs.update({"image": self._input_image})
         self.try_set_scheduler(kwargs)
         image = self.pipe(**kwargs).images[0]
+        logging.debug(f'generated image {image}')
         result = image.crop((0, 0, self._original_size[0], self._original_size[1]))
         return result
 
@@ -858,6 +862,7 @@ def setup(self, fimage, width=None, height=None,
         self.fname = fimage
         image = Image.open(fimage).convert("RGB") if image is None else image
         self._original_size = image.size
+        self._use_input_size = width is None or height is None
         image = util.pad_image_to_multiple_of_8(image)
         self._condition_image = [image]
         self._input_image = [image]
@@ -910,7 +915,8 @@ def gen(self, inputs):
         inputs.update({"image": self._input_image,
                        "control_image": self._condition_image})
         image = self.pipe(**inputs).images[0]
-        result = image.crop((0, 0, self._original_size[0], self._original_size[1]))
+        result = image.crop((0, 0, self._original_size[0] if self._use_input_size else inputs.get('width'),
+                             self._original_size[1] if self._use_input_size else inputs.get('height')))  # box: (left, upper, right, lower)
         return result
 
 
diff --git a/tests/pipe_test.py b/tests/pipe_test.py
index d4f6bd2..71af63d 100644
--- a/tests/pipe_test.py
+++ b/tests/pipe_test.py
@@ -7,7 +7,7 @@ import numpy
 
 from PIL import Image
 
-from multigen import Prompt2ImPipe, Im2ImPipe, Cfgen, GenSession, Loader, MaskedIm2ImPipe, CIm2ImPipe
+from multigen import Prompt2ImPipe, Im2ImPipe, Cond2ImPipe, Cfgen, GenSession, Loader, MaskedIm2ImPipe, CIm2ImPipe
 from multigen.log import setup_logger
 from multigen.pipes import ModelType
 from dummy import DummyDiffusionPipeline
@@ -210,6 +210,21 @@ def test_controlnet(self):
         diff = self.compute_diff(image_ddim, image)
         # check that difference is large
         self.assertGreater(diff, 1000)
+
+    def test_cond2im(self):
+        model = self.get_model()
+        model_type = self.model_type()
+        pipe = Cond2ImPipe(model, ctypes=["pose"], model_type=model_type)
+        pipe.setup("./pose6.jpeg", width=768, height=768)
+        seed = 49045438434843
+        params = dict(prompt="child in the coat playing in sandbox",
+                      negative_prompt="spherical",
+                      generator=torch.Generator().manual_seed(seed))
+        img = pipe.gen(params)
+        self.assertEqual(img.size, (768, 768))
+        pipe.setup("./pose6.jpeg")
+        img1 = pipe.gen(params)
+        self.assertEqual(img1.size, (450, 450))
 
 
 class TestSDXL(MyTestCase):
@@ -245,6 +260,10 @@ def get_model(self):
     def test_lpw_turned_off(self):
         pass
 
+    @unittest.skip('not implemented yet')
+    def test_cond2im(self):
+        pass
+
 
 if __name__ == '__main__':
     setup_logger('test_pipe.log')
diff --git a/tests/pose6.jpeg b/tests/pose6.jpeg
new file mode 120000
index 0000000..4220702
--- /dev/null
+++ b/tests/pose6.jpeg
@@ -0,0 +1 @@
+../examples/pose6.jpeg
\ No newline at end of file