Skip to content

Commit

Permalink
Merge pull request #56 from noskill/main
Browse files Browse the repository at this point in the history
fix random number generator in session
  • Loading branch information
Necr0x0Der authored Apr 27, 2024
2 parents 7837a30 + 3bd0acb commit 277aa89
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 32 deletions.
51 changes: 21 additions & 30 deletions multigen/pipes.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,27 +153,11 @@ def get_config(self):
cfg.update(self.pipe_params)
return cfg

def setup(self, steps=50, **args):
def setup(self, steps=50, clip_skip=0, **args):
self.pipe_params = { 'num_inference_steps': steps }
if 'clip_skip' in args:
# TODO? add clip_skip to config?
clip_skip = args['clip_skip']
assert clip_skip >= 0
assert clip_skip <= 10
if clip_skip:
prev_encoder = self.pipe.text_encoder
prev_config = prev_encoder.config
# if we need less or equal number of hidden layers
if 12 - clip_skip <= prev_config.num_hidden_layers:
config = copy.copy(prev_config)
config.num_hidden_layers = 12 - clip_skip
self.pipe.text_encoder = CLIPTextModel(config)
self.pipe.text_encoder.load_state_dict(prev_encoder.state_dict(), strict=False)
else: # we need more hidden layers
self.pipe.text_encoder = CLIPTextModel.from_pretrained(self.model_id, subfolder="text_encoder",
num_hidden_layers=12 - clip_skip)
self.pipe.text_encoder.to(prev_encoder.device)
self.pipe.text_encoder.to(prev_encoder.dtype)
assert clip_skip >= 0
assert clip_skip <= 10
self.pipe_params['clip_skip'] = clip_skip
if 'scheduler' in args:
# TODO? add scheduler to config?
self.try_set_scheduler(dict(scheduler=args['scheduler']))
Expand Down Expand Up @@ -211,12 +195,17 @@ def setup(self, width=768, height=768, guidance_scale=7.5, **args):
"guidance_scale": guidance_scale
})

def gen(self, inputs):
inputs = {**inputs}
inputs.update(self.pipe_params)
def gen(self, inputs: dict):
kwargs = self.pipe_params.copy()
# we can override pipe parameters
# so we update kwargs with inputs after pipe_params
kwargs.update(inputs)
# allow for scheduler overwrite
self.try_set_scheduler(inputs)
image = self.pipe(**inputs).images[0]
if 'clip_skip' in kwargs and 'StableDiffusionLongPromptWeightingPipeline' in str(type(self.pipe)):
# not supported
kwargs.pop('clip_skip')
image = self.pipe(**kwargs).images[0]
return image


Expand Down Expand Up @@ -250,12 +239,14 @@ def get_config(self):
cfg.update(self.pipe_params)
return cfg

def gen(self, inputs):
inputs = {**inputs}
inputs.update(self.pipe_params)
inputs.update({"image": self._input_image})
self.try_set_scheduler(inputs)
image = self.pipe(**inputs).images[0]
def gen(self, inputs: dict):
kwargs = self.pipe_params.copy()
# we can override pipe parameters
# so we update kwargs with inputs after pipe_params
kwargs.update({"image": self._input_image})
kwargs.update(inputs)
self.try_set_scheduler(kwargs)
image = self.pipe(**kwargs).images[0]
return image


Expand Down
4 changes: 2 additions & 2 deletions multigen/sessions.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ def gen_sess(self, add_count = 0, save_img=True,
self.last_index = self.confg.count - 1
self.last_conf = {**inputs}
# TODO: multiple inputs?
# inputs['generator'] = torch.Generator().manual_seed(inputs['generator'])
inputs['generator'] = torch.cuda.manual_seed(inputs['generator'])
inputs['generator'] = torch.Generator().manual_seed(inputs['generator'])

image = self.pipe.gen(inputs)
if save_img:
self.last_img_name = self.get_last_file_prefix() + ".png"
Expand Down

0 comments on commit 277aa89

Please sign in to comment.