Pass clip_skip through to the pipeline call and let gen() inputs override pipe_params.

setup(): clip_skip becomes an explicit keyword argument (default 0); it is
checked to be within 0..10 and stored in pipe_params, replacing the code that
rebuilt the text encoder with fewer hidden layers.

gen(): the call kwargs start from a copy of pipe_params and are then updated
with the per-call inputs, so inputs take precedence. clip_skip is dropped
when the pipe is a StableDiffusionLongPromptWeightingPipeline.

Review fix: the first gen() hands the merged kwargs to try_set_scheduler,
matching the second gen() and what the pre-patch code passed.

diff --git a/multigen/pipes.py b/multigen/pipes.py
index 0a9e522..09dc986 100755
--- a/multigen/pipes.py
+++ b/multigen/pipes.py
@@ -153,27 +153,11 @@ def get_config(self):
         cfg.update(self.pipe_params)
         return cfg
 
-    def setup(self, steps=50, **args):
+    def setup(self, steps=50, clip_skip=0, **args):
         self.pipe_params = { 'num_inference_steps': steps }
-        if 'clip_skip' in args:
-            # TODO? add clip_skip to config?
-            clip_skip = args['clip_skip']
-            assert clip_skip >= 0
-            assert clip_skip <= 10
-            if clip_skip:
-                prev_encoder = self.pipe.text_encoder
-                prev_config = prev_encoder.config
-                # if we need less or equal number of hidden layers
-                if 12 - clip_skip <= prev_config.num_hidden_layers:
-                    config = copy.copy(prev_config)
-                    config.num_hidden_layers = 12 - clip_skip
-                    self.pipe.text_encoder = CLIPTextModel(config)
-                    self.pipe.text_encoder.load_state_dict(prev_encoder.state_dict(), strict=False)
-                else: # we need more hidden layers
-                    self.pipe.text_encoder = CLIPTextModel.from_pretrained(self.model_id, subfolder="text_encoder",
-                                                                           num_hidden_layers=12 - clip_skip)
-                self.pipe.text_encoder.to(prev_encoder.device)
-                self.pipe.text_encoder.to(prev_encoder.dtype)
+        assert clip_skip >= 0
+        assert clip_skip <= 10
+        self.pipe_params['clip_skip'] = clip_skip
         if 'scheduler' in args:
             # TODO? add scheduler to config?
             self.try_set_scheduler(dict(scheduler=args['scheduler']))
@@ -211,12 +195,17 @@ def setup(self, width=768, height=768, guidance_scale=7.5, **args):
             "guidance_scale": guidance_scale
         })
 
-    def gen(self, inputs):
-        inputs = {**inputs}
-        inputs.update(self.pipe_params)
+    def gen(self, inputs: dict):
+        kwargs = self.pipe_params.copy()
+        # we can override pipe parameters
+        # so we update kwargs with inputs after pipe_params
+        kwargs.update(inputs)
         # allow for scheduler overwrite
-        self.try_set_scheduler(inputs)
-        image = self.pipe(**inputs).images[0]
+        self.try_set_scheduler(kwargs)
+        if 'clip_skip' in kwargs and 'StableDiffusionLongPromptWeightingPipeline' in str(type(self.pipe)):
+            # not supported
+            kwargs.pop('clip_skip')
+        image = self.pipe(**kwargs).images[0]
         return image
 
 
@@ -250,12 +239,14 @@ def get_config(self):
         cfg.update(self.pipe_params)
         return cfg
 
-    def gen(self, inputs):
-        inputs = {**inputs}
-        inputs.update(self.pipe_params)
-        inputs.update({"image": self._input_image})
-        self.try_set_scheduler(inputs)
-        image = self.pipe(**inputs).images[0]
+    def gen(self, inputs: dict):
+        kwargs = self.pipe_params.copy()
+        # we can override pipe parameters
+        # so we update kwargs with inputs after pipe_params
+        kwargs.update({"image": self._input_image})
+        kwargs.update(inputs)
+        self.try_set_scheduler(kwargs)
+        image = self.pipe(**kwargs).images[0]
         return image
 
 