diff --git a/.github/workflows/python-app.yml b/.github/workflows/python-app.yml
index dc988ce..8d7e83d 100644
--- a/.github/workflows/python-app.yml
+++ b/.github/workflows/python-app.yml
@@ -27,7 +27,11 @@ jobs:
       run: |
         python -m pip install --upgrade pip
         if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
+    - name: symlink models directory
+      run: ln -s ../../../../models tests/models-full
     - name: Test with pytest
+      env:
+        METAFUSION_MODELS_DIR: models-full
       run: |
         cd tests && python pipe_test.py
     - name: Test worker
diff --git a/multigen/loader.py b/multigen/loader.py
index 766ebc0..bbcdaa8 100644
--- a/multigen/loader.py
+++ b/multigen/loader.py
@@ -63,7 +63,7 @@ def load_pipeline(self, cls: Type[DiffusionPipeline], path, torch_dtype=torch.fl
         logger.debug(f'looking for pipeline {cls} from {path} on {device}')
         result = None
         if device is None:
-            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu', 0)
         if device.type == 'cuda':
             idx = device.index
             gpu_pipes = self._gpu_pipes.get(idx, [])
@@ -89,7 +89,9 @@ def load_pipeline(self, cls: Type[DiffusionPipeline], path, torch_dtype=torch.fl
             result = result.to(dtype=torch_dtype, device=device)
             self.cache_pipeline(result, path)
         result = copy_pipe(result)
-        assert result.device == device
+        assert result.device.type == device.type
+        if device.type == 'cuda':
+            assert result.device.index == device.index
         logger.debug(f'returning {type(result)} from {path} on {result.device}')
         return result
 
@@ -131,6 +133,7 @@ def clear_cache(self, device):
 
     def _store_gpu_pipe(self, pipe, model_id):
         idx = pipe.device.index
+        assert idx is not None
         # for now just clear all other pipelines
         self._gpu_pipes[idx] = [(model_id, pipe)]
 
diff --git a/multigen/lpw_stable_diffusion.py b/multigen/lpw_stable_diffusion.py
index 05102bf..30bd6d8 100644
--- a/multigen/lpw_stable_diffusion.py
+++ b/multigen/lpw_stable_diffusion.py
@@ -22,6 +22,14 @@
     logging,
 )
 from diffusers.utils.torch_utils import randn_tensor
+from diffusers.utils import (
+    USE_PEFT_BACKEND,
+    deprecate,
+    logging,
+    replace_example_docstring,
+    scale_lora_layers,
+    unscale_lora_layers,
+)
 
 
 # ------------------------------------------------------------------------------
@@ -261,6 +269,7 @@ def get_weighted_text_embeddings(
     skip_parsing: Optional[bool] = False,
    skip_weighting: Optional[bool] = False,
     clip_skip=None,
+    lora_scale=None,
 ):
     r"""
     Prompts can be assigned with local weights using brackets. For example,
@@ -287,6 +296,16 @@ def get_weighted_text_embeddings(
         skip_weighting (`bool`, *optional*, defaults to `False`):
             Skip the weighting. When the parsing is skipped, it is forced True.
     """
+    # set lora scale so that monkey patched LoRA
+    # function of text encoder can correctly access it
+    if lora_scale is not None and isinstance(pipe, LoraLoaderMixin):
+        pipe._lora_scale = lora_scale
+
+        # dynamically adjust the LoRA scale
+        if not USE_PEFT_BACKEND:
+            adjust_lora_scale_text_encoder(pipe.text_encoder, lora_scale)
+        else:
+            scale_lora_layers(pipe.text_encoder, lora_scale)
     max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
     if isinstance(prompt, str):
         prompt = [prompt]
@@ -383,6 +402,11 @@ def get_weighted_text_embeddings(
             current_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
             uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
 
+    if pipe.text_encoder is not None:
+        if isinstance(pipe, LoraLoaderMixin) and USE_PEFT_BACKEND:
+            # Retrieve the original scale by scaling back the LoRA layers
+            unscale_lora_layers(pipe.text_encoder, lora_scale)
+
     if uncond_prompt is not None:
         return text_embeddings, uncond_embeddings
     return text_embeddings, None
diff --git a/multigen/lpw_stable_diffusion_xl.py b/multigen/lpw_stable_diffusion_xl.py
index 0fb4952..a2a6c1a 100644
--- a/multigen/lpw_stable_diffusion_xl.py
+++ b/multigen/lpw_stable_diffusion_xl.py
@@ -22,7 +22,7 @@
 from diffusers import DiffusionPipeline, StableDiffusionXLPipeline
 from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
-from diffusers.loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
+from diffusers.loaders import StableDiffusionXLLoraLoaderMixin, FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
 from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel
 from diffusers.models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor
 from diffusers.pipelines.pipeline_utils import StableDiffusionMixin
@@ -37,7 +37,14 @@
     replace_example_docstring,
 )
 from diffusers.utils.torch_utils import randn_tensor
-
+from diffusers.utils import (
+    USE_PEFT_BACKEND,
+    deprecate,
+    logging,
+    replace_example_docstring,
+    scale_lora_layers,
+    unscale_lora_layers,
+)
 if is_invisible_watermark_available():
     from diffusers.pipelines.stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
 
@@ -256,6 +263,7 @@ def get_weighted_text_embeddings_sdxl(
     num_images_per_prompt: int = 1,
     device: Optional[torch.device] = None,
     clip_skip: Optional[int] = None,
+    lora_scale: Optional[int] = None
 ):
     """
     This function can process long prompt with weights, no length limitation
@@ -276,6 +284,24 @@ def get_weighted_text_embeddings_sdxl(
     """
     device = device or pipe._execution_device
 
+    # set lora scale so that monkey patched LoRA
+    # function of text encoder can correctly access it
+    if lora_scale is not None and isinstance(pipe, StableDiffusionXLLoraLoaderMixin):
+        pipe._lora_scale = lora_scale
+
+        # dynamically adjust the LoRA scale
+        if pipe.text_encoder is not None:
+            if not USE_PEFT_BACKEND:
+                adjust_lora_scale_text_encoder(pipe.text_encoder, lora_scale)
+            else:
+                scale_lora_layers(pipe.text_encoder, lora_scale)
+
+        if pipe.text_encoder_2 is not None:
+            if not USE_PEFT_BACKEND:
+                adjust_lora_scale_text_encoder(pipe.text_encoder_2, lora_scale)
+            else:
+                scale_lora_layers(pipe.text_encoder_2, lora_scale)
+
     if prompt_2:
         prompt = f"{prompt} {prompt_2}"
 
@@ -424,6 +450,16 @@ def get_weighted_text_embeddings_sdxl(
         bs_embed * num_images_per_prompt, -1
     )
 
+    if pipe.text_encoder is not None:
+        if isinstance(pipe, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
+            # Retrieve the original scale by scaling back the LoRA layers
+            unscale_lora_layers(pipe.text_encoder, lora_scale)
+
+    if pipe.text_encoder_2 is not None:
+        if isinstance(pipe, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
+            # Retrieve the original scale by scaling back the LoRA layers
+            unscale_lora_layers(pipe.text_encoder_2, lora_scale)
+
     return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
 
diff --git a/multigen/pipes.py b/multigen/pipes.py
index fec7d37..645b6bc 100755
--- a/multigen/pipes.py
+++ b/multigen/pipes.py
@@ -105,6 +105,7 @@ def __init__(self, model_id: str,
         self.pipe = self._load_pipeline(sd_pipe_class, model_type, args)
         self._initialize_pipe(device)
         self.lpw = lpw
+        self._loras = []
         mt = self._get_model_type()
         if self.model_type is None:
             self.model_type = mt
@@ -183,6 +184,7 @@ def load_lora(self, path, multiplier=1.0):
         if 'cross_attention_kwargs' not in self.pipe_params:
             self.pipe_params['cross_attention_kwargs'] = {}
         self.pipe_params['cross_attention_kwargs']["scale"] = multiplier
+        self._loras.append(path)
 
     def add_hypernet(self, path, multiplier=None):
         from . hypernet import add_hypernet, Hypernetwork
@@ -209,6 +211,7 @@ def get_config(self):
         cfg.update({"model_id": self.model_id })
         cfg['scheduler'] = dict(self.pipe.scheduler.config)
         cfg['scheduler']['class_name'] = self.pipe.scheduler.__class__.__name__
+        cfg['loras'] = self._loras
         cfg.update(self.pipe_params)
         return cfg
 
@@ -230,7 +233,7 @@ def setup(self, steps=50, clip_skip=0, loras=[], **args):
            Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
         :return: None
         """
-        self.pipe_params = { 'num_inference_steps': steps }
+        self.pipe_params.update({ 'num_inference_steps': steps })
         assert clip_skip >= 0
         assert clip_skip <= 10
         self.pipe_params['clip_skip'] = clip_skip
@@ -242,7 +245,7 @@ def setup(self, steps=50, clip_skip=0, loras=[], **args):
         for lora in loras:
             self.load_lora(lora)
 
-    def get_prompt_embeds(self, prompt, negative_prompt, clip_skip: Optional[int] = None):
+    def get_prompt_embeds(self, prompt, negative_prompt, clip_skip: Optional[int] = None, lora_scale: Optional[int] = None):
         if self.lpw:
             # convert to lpw
             if isinstance(self.pipe, self._classxl):
@@ -255,6 +258,7 @@ def get_prompt_embeds(self, prompt, negative_prompt, clip_skip: Optional[int] =
                     neg_prompt=negative_prompt,
                     num_images_per_prompt=1,
                     clip_skip=clip_skip,
+                    lora_scale=lora_scale
                 )
             elif isinstance(self.pipe, self._class):
                 from . import lpw_stable_diffusion
@@ -264,12 +268,14 @@ def get_prompt_embeds(self, prompt, negative_prompt, clip_skip: Optional[int] =
                     uncond_prompt=negative_prompt,
                     max_embeddings_multiples=3,
                     clip_skip=clip_skip,
+                    lora_scale=lora_scale
                 )
 
     def prepare_inputs(self, inputs):
         kwargs = self.pipe_params.copy()
         kwargs.update(inputs)
         if self.lpw:
+            lora_scale = kwargs.get('cross_attention_kwargs', dict()).get("scale", None)
             if self.model_type == ModelType.SDXL:
                 if 'negative_prompt' not in kwargs:
                     kwargs['negative_prompt'] = None
@@ -283,7 +289,7 @@ def prepare_inputs(self, inputs):
                     negative_prompt_embeds,
                     pooled_prompt_embeds,
                     negative_pooled_prompt_embeds,
-                ) = self.get_prompt_embeds(kwargs.pop('prompt'), kwargs.pop('negative_prompt'), kwargs.pop('clip_skip'))
+                ) = self.get_prompt_embeds(kwargs.pop('prompt'), kwargs.pop('negative_prompt'), kwargs.pop('clip_skip'), lora_scale=lora_scale)
                 kwargs['prompt_embeds'] = prompt_embeds
                 kwargs['negative_prompt_embeds'] = negative_prompt_embeds
 
@@ -294,7 +300,7 @@ def prepare_inputs(self, inputs):
                     prompt=kwargs.pop('prompt'),
                     negative_prompt=kwargs.pop('negative_prompt'),
                     clip_skip=kwargs.pop('clip_skip'),
-                )
+                    lora_scale=lora_scale)
                 kwargs['prompt_embeds'] = prompt_embeds
                 kwargs['negative_prompt_embeds'] = negative_prompt_embeds
             else:
diff --git a/multigen/prompting.py b/multigen/prompting.py
index c3d9e12..90e9610 100755
--- a/multigen/prompting.py
+++ b/multigen/prompting.py
@@ -57,10 +57,9 @@ def __next__(self):
             raise StopIteration
         thread_data.random = random.Random()
         seed = self.seeds[self.count % nseeds] if nseeds > 0 else \
-            thread_data.random.randint(1, 1024*1024*1024*4-1)
-        self.count += 1
-
+            random.randint(1, 1024*1024*1024*4-1)
         thread_data.random.seed(seed)
+        self.count += 1
         result = {'prompt': get_prompt(self.prompt),
                   'generator': seed,
                   'negative_prompt': get_prompt(self.negative_prompt)}
diff --git a/tests/pipe_test.py b/tests/pipe_test.py
index 3556fb6..e397fcc 100644
--- a/tests/pipe_test.py
+++ b/tests/pipe_test.py
@@ -22,12 +22,21 @@ def compute_diff(self, im1: PIL.Image.Image, im2: PIL.Image.Image) -> float:
 
 
+def can_run_lpw():
+    if os.environ.get('METAFUSION_MODELS_DIR'):
+        return True
+    return False
+
+
 class MyTestCase(TestCase):
 
     def setUp(self):
         self._pipeline = None
 
     def get_model(self):
+        models_dir = os.environ.get('METAFUSION_MODELS_DIR', None)
+        if models_dir is not None:
+            return models_dir + '/icb_diffusers'
         return "hf-internal-testing/tiny-stable-diffusion-torch"
 
     def get_ref_image(self):
@@ -116,11 +125,59 @@ def test_img2img_basic(self):
         result = pipe.gen(dict(prompt="cube planet cartoon style"))
         result.save('test_img2img_basic.png')
 
+    @unittest.skipIf(not can_run_lpw(), "can't run on tiny version of SD")
+    def test_lpw(self):
+        """
+        Check that last part of long prompt affect the generation
+        """
+        pipe = Prompt2ImPipe(self.get_model(), model_type=self.model_type(), lpw=True)
+        prompt = ' a cubic planet with atmoshere as seen from low orbit, each side of the cubic planet is ocuppied by an ocean, oceans have islands, but no continents, atmoshere of the planet has usual sperical shape, corners of the cube are above the atmoshere, but edges largely are covered by the atomosphere, there are cyclones in the atmoshere, the photo is made from low-orbit, famous sci-fi illustration'
+        pipe.setup(width=512, height=512, guidance_scale=7, scheduler="DPMSolverMultistepScheduler", steps=5)
+        seed = 49045438434843
+        params = dict(prompt=prompt,
+                      negative_prompt="spherical",
+                      generator=torch.cuda.manual_seed(seed))
+        image = pipe.gen(params)
+        image.save("cube_test_lpw.png")
+        params = dict(prompt=prompt + " , best quality, famous photo",
+                      negative_prompt="spherical",
+                      generator=torch.cuda.manual_seed(seed))
+        image1 = pipe.gen(params)
+        image.save("cube_test_lpw1.png")
+        diff = self.compute_diff(image1, image)
+        # check that difference is large
+        self.assertGreater(diff, 1000)
+
+    @unittest.skipIf(not can_run_lpw(), "can't run on tiny version of SD")
+    def test_lpw_turned_off(self):
+        """
+        Check that last part of long prompt don't affect the generation with lpw turned off
+        """
+        pipe = Prompt2ImPipe(self.get_model(), model_type=self.model_type(), lpw=False)
+        prompt = ' a cubic planet with atmoshere as seen from low orbit, each side of the cubic planet is ocuppied by an ocean, oceans have islands, but no continents, atmoshere of the planet has usual sperical shape, corners of the cube are above the atmoshere, but edges largely are covered by the atomosphere, there are cyclones in the atmoshere, the photo is made from low-orbit, famous sci-fi illustration'
+        pipe.setup(width=512, height=512, guidance_scale=7, scheduler="DPMSolverMultistepScheduler", steps=5)
+        seed = 49045438434843
+        params = dict(prompt=prompt,
+                      negative_prompt="spherical",
+                      generator=torch.cuda.manual_seed(seed))
+        image = pipe.gen(params)
+        image.save("cube_test_no_lpw.png")
+        params = dict(prompt=prompt + " , best quality, famous photo",
+                      negative_prompt="spherical",
+                      generator=torch.cuda.manual_seed(seed))
+        image1 = pipe.gen(params)
+        image.save("cube_test_no_lpw1.png")
+        diff = self.compute_diff(image1, image)
+        # check that difference is large
+        self.assertLess(diff, 1)
 
 class TestSDXL(MyTestCase):
 
     def get_model(self):
+        models_dir = os.environ.get('METAFUSION_MODELS_DIR', None)
+        if models_dir is not None:
+            return models_dir + '/SDXL/stable-diffusion-xl-base-1.0'
         return "hf-internal-testing/tiny-stable-diffusion-xl-pipe"