additional preprocessing for im2im pipes
noskill committed Sep 11, 2024
1 parent b3c2b06 commit 4d4b792
Showing 3 changed files with 85 additions and 10 deletions.
30 changes: 26 additions & 4 deletions multigen/pipes.py
@@ -22,6 +22,7 @@
from .pipelines.masked_stable_diffusion_img2img import MaskedStableDiffusionImg2ImgPipeline
from .pipelines.masked_stable_diffusion_xl_img2img import MaskedStableDiffusionXLImg2ImgPipeline
from transformers import CLIPProcessor, CLIPTextModel
from . import util
#from xformers.ops import MemoryEfficientAttentionFlashAttentionOp
# from diffusers import StableDiffusionKDiffusionPipeline

@@ -311,7 +312,7 @@ def prepare_inputs(self, inputs):
if 'negative_prompt' in kwargs:
kwargs.pop('negative_prompt')
logging.warning('negative prompt is not supported by flux!')

if self.lpw:
kwargs.setdefault('negative_prompt', None)
kwargs.setdefault('clip_skip', None)
@@ -428,6 +429,8 @@ def setup(self, fimage, image=None, strength=0.75,
self.fname = fimage
self._input_image = Image.open(fimage).convert("RGB") if image is None else image
self._input_image = self.scale_image(self._input_image, scale)
self._original_size = self._input_image.size
self._input_image = util.pad_image_to_multiple_of_8(self._input_image)
self.pipe_params.update({
"strength": strength,
"guidance_scale": guidance_scale
@@ -481,7 +484,8 @@ def gen(self, inputs: dict):
kwargs.update({"image": self._input_image})
self.try_set_scheduler(kwargs)
image = self.pipe(**kwargs).images[0]
return image
result = image.crop((0, 0, self._original_size[0], self._original_size[1]))
return result


class MaskedIm2ImPipe(Im2ImPipe):
@@ -550,12 +554,24 @@ def setup(self, original_image=None, image_painted=None, mask=None, blur=4,
self._original_image = Image.open(original_image) if isinstance(original_image, str) else original_image
self._image_painted = Image.open(image_painted) if isinstance(image_painted, str) else image_painted

if self._original_image.mode == 'RGBA':
self._original_image = self._original_image.convert("RGB")

if self._image_painted is not None:
if not isinstance(self._image_painted, Image.Image):
self._image_painted = Image.fromarray(self._image_painted)
self._image_painted = self._image_painted.convert("RGB")

input_image = self._image_painted if self._image_painted is not None else self._original_image

super().setup(fimage=None, image=input_image, scale=scale, **kwargs)
if self._original_image is not None:
self._original_image = self.scale_image(self._original_image, scale)
self._original_image = util.pad_image_to_multiple_of_8(self._original_image)
if self._image_painted is not None:
self._image_painted = self.scale_image(self._image_painted, scale)
self._image_painted = util.pad_image_to_multiple_of_8(self._image_painted)

# there are two options:
# 1. mask is provided
# 2. mask is computed from difference between original_image and image_painted
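For the second option, a sketch of how such a mask could be derived from the pixel difference; this illustrates the idea only and is not the commit's actual implementation:

import numpy as np

def mask_from_difference(original: np.ndarray, painted: np.ndarray) -> np.ndarray:
    # mark every pixel where the painted image deviates from the original
    diff = np.abs(original.astype(np.int16) - painted.astype(np.int16)).sum(axis=-1)
    return (diff > 0).astype(np.uint8) * 255  # uint8 mask: 255 marks the painted region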
@@ -571,8 +587,11 @@ def setup(self, original_image=None, image_painted=None, mask=None, blur=4,
pil_mask = Image.fromarray(mask)
if pil_mask.mode != "L":
pil_mask = pil_mask.convert("L")
pil_mask = util.pad_image_to_multiple_of_8(pil_mask)
self._mask = pil_mask
self._mask_blur = self.blur_mask(pil_mask, blur)
self._mask_compose = self.blur_mask(pil_mask, blur_compose)
self._mask_compose = self.blur_mask(pil_mask.crop((0, 0, self._original_size[0], self._original_size[1])), blur_compose)
self._sample_mode = sample_mode

def blur_mask(self, pil_mask, blur):
@@ -592,7 +611,7 @@ def gen(self, inputs):
img_gen = super().gen(inputs)

# compose with original using mask
img_compose = self._mask_compose * img_gen + (1 - self._mask_compose) * self._original_image
img_compose = self._mask_compose * img_gen + (1 - self._mask_compose) * self._original_image.crop((0, 0, self._original_size[0], self._original_size[1]))
# convert to PIL image
img_compose = Image.fromarray(img_compose.astype(np.uint8))
return img_compose
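The compose step above blends the generated image back into the original through the blurred mask. A minimal numpy sketch of the same blend, assuming the mask is a float array in [0, 1] with the same height and width as both images (the names are illustrative, not the pipeline's API):

import numpy as np
from PIL import Image

def blend_with_mask(mask, generated: Image.Image, original: Image.Image) -> Image.Image:
    # mask: float array in [0, 1]; 1.0 keeps the generated pixel, 0.0 keeps the original
    gen = np.asarray(generated, dtype=np.float32)
    orig = np.asarray(original, dtype=np.float32)
    blended = mask[..., None] * gen + (1.0 - mask[..., None]) * orig
    return Image.fromarray(blended.astype(np.uint8))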
@@ -753,6 +772,8 @@ def setup(self, fimage, width=None, height=None,
# TODO: allow multiple input images for multiple control nets
self.fname = fimage
image = Image.open(fimage).convert("RGB") if image is None else image
self._original_size = image.size
image = util.pad_image_to_multiple_of_8(image)
self._condition_image = [image]
self._input_image = [image]
if cscales is None:
@@ -798,6 +819,7 @@ def gen(self, inputs):
inputs.update({"image": self._input_image,
"control_image": self._condition_image})
image = self.pipe(**inputs).images[0]
result = image.crop((0, 0, self._original_size[0], self._original_size[1]))
return result


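Stable Diffusion's VAE downsamples width and height by a factor of 8, which is why the pipelines above pad their inputs before generation and crop the result back afterwards. A sketch of that round trip, assuming a checkout of multigen (the input path is hypothetical):

from PIL import Image
from multigen import util

image = Image.open("input.png").convert("RGB")   # hypothetical 511x767 input
original_size = image.size
padded = util.pad_image_to_multiple_of_8(image)  # 511x767 -> 512x768
# ... run the img2img pipeline on `padded` ...
result = padded.crop((0, 0, original_size[0], original_size[1]))
assert result.size == original_size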
36 changes: 36 additions & 0 deletions multigen/util.py
@@ -11,3 +11,39 @@ def create_exif_metadata(im: Image, custom_metadata):
exif[0x9286] = custom_metadata_bytes
# there seems to be no API for setting EXIF on an already-loaded image
return exif


def pad_image_to_multiple_of_8(image: Image.Image) -> Image.Image:
"""
Pads the input image by repeating its right-most column and bottom-most row of pixels
so that both the width and the height are divisible by 8.
Args:
image (Image.Image): The input PIL image.
Returns:
Image.Image: The padded PIL image.
"""

# Calculate the new dimensions
new_width = (image.width + 7) // 8 * 8
new_height = (image.height + 7) // 8 * 8

# Create a new image with the new dimensions and paste the original image onto it
padded_image = Image.new(image.mode, (new_width, new_height))
padded_image.paste(image, (0, 0))

# Repeat the right-most column of pixels to fill the horizontal padding
for x in range(new_width - image.width):
box = (image.width + x, 0, image.width + x + 1, image.height)
region = image.crop((image.width - 1, 0, image.width, image.height))
padded_image.paste(region, box)

# Repeat the bottom-most row of pixels to fill the vertical padding,
# cropping from the padded image so the bottom-right corner is filled too
for y in range(new_height - image.height):
box = (0, image.height + y, new_width, image.height + y + 1)
region = padded_image.crop((0, image.height - 1, new_width, image.height))
padded_image.paste(region, box)

return padded_image
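
A quick sanity check of the helper's behavior, assuming only PIL (a sketch, not part of the commit):

from PIL import Image
from multigen.util import pad_image_to_multiple_of_8

img = Image.new("RGB", (9, 9), "gray")  # 9x9 pads up to 16x16
padded = pad_image_to_multiple_of_8(img)
assert padded.size == (16, 16)
# the new border repeats the edge pixels instead of filling with black
assert padded.getpixel((15, 0)) == img.getpixel((8, 0))
assert padded.getpixel((0, 15)) == img.getpixel((0, 8))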

29 changes: 23 additions & 6 deletions tests/pipe_test.py
@@ -1,10 +1,12 @@
import unittest
import os
import logging
import shutil
import PIL
import torch
import numpy

from PIL import Image
from multigen import Prompt2ImPipe, Im2ImPipe, Cfgen, GenSession, Loader, MaskedIm2ImPipe
from multigen.log import setup_logger
from multigen.pipes import ModelType
@@ -40,8 +42,12 @@ def get_model(self):
return models_dir + '/icb_diffusers'
return "hf-internal-testing/tiny-stable-diffusion-torch"

def get_ref_image(self):
return "cube_planet_dms.png"
def get_ref_image(self, dw, dh):
img = Image.open("cube_planet_dms.png")
img = img.resize((img.width + dw, img.height + dh))
pth = './cube_planet_dms1.png'
img.save(pth)
return pth

def model_type(self):
return ModelType.SDXL if 'TestSDXL' in str(self.__class__) else ModelType.SD
@@ -121,7 +127,8 @@ def test_loader(self):

def test_img2img_basic(self):
pipe = Im2ImPipe(self.get_model(), model_type=self.model_type())
im = self.get_ref_image()
dw, dh = -1, 1
im = self.get_ref_image(dw, dh)
seed = 49045438434843
pipe.setup(im, strength=0.7, steps=5, guidance_scale=3.3)
self.assertEqual(3.3, pipe.pipe_params['guidance_scale'])
@@ -141,8 +148,14 @@ def test_img2img_basic(self):
def test_maskedimg2img_basic(self):
pipe = MaskedIm2ImPipe(self.get_model(), model_type=self.model_type())
img = PIL.Image.open("./mech_beard_sigm.png")
dw, dh = -1, -1
img = img.crop((0, 0, img.width + dw, img.height + dh))
logging.info(f'testing on image {img.size}')

# read image with mask painted over
img_paint = numpy.array(PIL.Image.open("./mech_beard_sigm_mask.png"))
img_paint = PIL.Image.open("./mech_beard_sigm_mask.png")
img_paint = img_paint.crop((0, 0, img_paint.width + dw, img_paint.height + dh))
img_paint = numpy.asarray(img_paint)

scheduler = "EulerAncestralDiscreteScheduler"
seed = 49045438434843
@@ -154,6 +167,8 @@ def test_maskedimg2img_basic(self):
pipe.setup(**param_3_3)
self.assertEqual(3.3, pipe.pipe_params['guidance_scale'])
image = pipe.gen(dict(prompt="cube planet cartoon style", generator=torch.Generator(pipe.pipe.device).manual_seed(seed)))
self.assertEqual(image.width, img.width)
self.assertEqual(image.height, img.height)
image.save('test_img2img_basic.png')
pipe.setup(**param_7_6)
image1 = pipe.gen(dict(prompt="cube planet cartoon style", generator=torch.Generator(pipe.pipe.device).manual_seed(seed)))
@@ -165,6 +180,8 @@ def test_maskedimg2img_basic(self):
diff = self.compute_diff(image2, image)
# check that difference is small
self.assertLess(diff, 1)
self.assertEqual(image.width, img.width)
self.assertEqual(image.height, img.height)

@unittest.skipIf(not found_models(), "can't run on tiny version of SD")
def test_lpw(self):
@@ -263,8 +280,8 @@ def test_basic_txt2im(self):
device = torch.device('cpu', 0)
# create pipe
offload = 0 if torch.cuda.is_available() else None
pipe = Prompt2ImPipe(model, pipe=self._pipeline,
model_type=self.model_type(),
pipe = Prompt2ImPipe(model, pipe=self._pipeline,
model_type=self.model_type(),
device=device, offload_device=offload)
pipe.setup(width=512, height=512, guidance_scale=7, scheduler="FlowMatchEulerDiscreteScheduler", steps=5)
seed = 49045438434843
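To exercise the new padding and cropping paths end to end, the suite can be run from the repository root, assuming the reference images (cube_planet_dms.png, mech_beard_sigm.png and its mask) are present in the working directory:

import unittest

# discover and run the pipeline tests (run from the repository root)
suite = unittest.defaultTestLoader.discover("tests", pattern="pipe_test.py")
unittest.TextTestRunner(verbosity=2).run(suite)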
