delete --yolo_discretization and --discretization in favour of constructing CompVisDenoiser with quantize=True

this means we don't need a custom fork of k-diffusion (except for MPS fixes). the only downside compared to my original approach is that we cannot set churn>0 (see crowsonkb/k-diffusion#23 (comment)), but we never used that. I reckon the ability to quantize sigma_hat will be added to mainline k-diffusion eventually (discussed here: crowsonkb/k-diffusion#23 (comment)), so I think it's best to keep the k-diffusion branch free of bespoke changes (except the MPS fixes), to keep it easy to rebase onto mainline. this also removes the ability to opt in/out of discretization, now that I've finished comparing them (crowsonkb/k-diffusion#23) -- the difference is barely perceptible, but discretization is the better choice in theory.
Birch-san committed Sep 3, 2022
1 parent 50476fa commit f52429a
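
For readers who haven't followed the k-diffusion thread: quantize=True makes the wrapped denoiser snap each sigma it is asked to denoise at to the nearest sigma the model was trained on. A minimal sketch of that rounding, assuming a 1-D tensor of trained sigmas (the function below is illustrative, not k-diffusion's internal API):

import torch

def quantize_to_trained_sigmas(schedule: torch.Tensor, trained_sigmas: torch.Tensor) -> torch.Tensor:
    """Snap each sigma in a proposed noise schedule to the nearest of the
    sigmas the model was trained on (for Stable Diffusion, its 1000 DDIM sigmas)."""
    # pairwise distances |schedule_i - trained_j| over a broadcast grid,
    # then pick the closest trained sigma for each step of the schedule
    dists = (schedule.reshape(-1, 1) - trained_sigmas.reshape(1, -1)).abs()
    return trained_sigmas[dists.argmin(dim=1)]

This is the same nearest-sigma rounding the deleted --yolo_discretization path applied by hand, as the removed one-liner in scripts/txt2img_fork.py below shows.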
Showing 2 changed files with 53 additions and 94 deletions.
42 changes: 22 additions & 20 deletions .vscode/launch.json
@@ -62,10 +62,11 @@
"console": "integratedTerminal",
"justMyCode": true,
"args": [
// "--prompt",
"--prompt",
// "Kunkka enjoys a nice coffee with Crystal Maiden. pixiv",
// "masterpiece character portrait of a blonde girl, full resolution, 4 k, mizuryuu kei, akihiko. yoshida, Pixiv featured, baroque scenic, by artgerm, sylvain sarrailh, rossdraws, wlop, global illumination",
"masterpiece character portrait of a blonde girl, full resolution, 4k, mizuryuu kei, akihiko. yoshida, Pixiv featured, baroque scenic, by artgerm, sylvain sarrailh, rossdraws, wlop, global illumination, vaporwave",
// "masterpiece character portrait of a blonde girl, full resolution, 4k, mizuryuu kei, akihiko. yoshida, Pixiv featured, baroque scenic, by artgerm, sylvain sarrailh, rossdraws, wlop, global illumination, vaporwave",
"masterpiece character portrait of a girl, yukata, full resolution, 4k, artgerm, wlop, sarrailh, kuvshinov, global illumination, vaporwave, neon night",
// "masterpiece character portrait of a beautiful shrine maiden with long black hair, full resolution, 4k, mizuryuu kei, akihiko. yoshida, Pixiv featured, baroque scenic, by artgerm, sylvain sarrailh, rossdraws, wlop, global illumination, vaporwave",
// "masterpiece character portrait of a shrine maiden with long black hair, full resolution, 4k, mizuryuu kei, akihiko. yoshida, Pixiv featured, baroque scenic, by artgerm, sylvain sarrailh, rossdraws, wlop, global illumination, vaporwave",
// "masterpiece character portrait of a shrine maiden, full resolution, 4k, mizuryuu kei, akihiko. yoshida, Pixiv featured, baroque scenic, by artgerm, sylvain sarrailh, rossdraws, wlop, global illumination, vaporwave",
@@ -74,30 +75,31 @@
// "--filename_sample_ix",
// "--filename_seed",
"--filename_sampling",
// "--filename_sigmas",
"--filename_sigmas",
"--n_iter", "1",
"--n_samples", "1",
"--steps", "8",
"--steps", "50",
// "--steps", "50",
// "--sampler", "plms",
"--steps", "7",
"--sampler", "heun",
"--karras_noise",
"--discretization",
"--no_zero_sigma",
"--end_karras_ramp_early",
"--seed",
"1396704121",
"--scale",
"20",
"--dynamic_thresholding",
"--dynamic_thresholding_percentile",
"0.9"
// "1396704121",
"68673924",
// "--scale",
// "20",
// "--dynamic_thresholding",
// "--dynamic_thresholding_percentile",
// "0.9",
// "--fixed_code",
"--init_img",
"/Users/birch/dall-e-mega/fumo plush of lina from dota 1.jpg",
"--prompt",
"plush doll maiden of magic fire",
"--strength", "0.3",
"--f",
"16"
// "--init_img",
// "/Users/birch/dall-e-mega/fumo plush of lina from dota 1.jpg",
// "--prompt",
// "plush doll maiden of magic fire",
// "--strength", "0.3",
// "--f",
// "16"
]
},
{
105 changes: 31 additions & 74 deletions scripts/txt2img_fork.py
@@ -15,12 +15,13 @@
from torch import autocast, nn
from contextlib import contextmanager, nullcontext
from random import randint
from typing import Optional

from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler

from k_diffusion.sampling import sample_lms, sample_dpm_2, sample_dpm_2_ancestral, sample_euler, sample_euler_ancestral, sample_heun, get_sigmas_karras, make_quantizer
from k_diffusion.sampling import sample_lms, sample_dpm_2, sample_dpm_2_ancestral, sample_euler, sample_euler_ancestral, sample_heun, get_sigmas_karras, append_zero
from k_diffusion.external import CompVisDenoiser

def get_device():
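
The import swap above trades make_quantizer for append_zero. append_zero just terminates a sigma schedule with the final 0.0 the samplers expect; a one-line restatement of that behaviour, inferred from how it is used later in this diff:

import torch

def append_zero(sigmas: torch.Tensor) -> torch.Tensor:
    # concatenate a single 0.0 onto the end of a 1-D sigma schedule,
    # so sampling finishes at zero noise
    return torch.cat([sigmas, sigmas.new_zeros(1)])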
@@ -191,12 +192,12 @@ def main():
default="k_lms"
)
# my recommendations for each sampler are:
# implement samplers from Karras et al paper using Karras noise schedule, discretize timesteps.
# if your step count is low (for example 7 or 8) you should use add ----end_karras_ramp_early too.
# --heun --karras_noise --discretization
# --euler --karras_noise --discretization
# --dpm2 --karras_noise --discretization
# I assume Karras noise schedule is generally applicable, so is suitable for use with any k-diffusion sampler. there's no guidance on how to apply discretization to these algorithms, so we don't.
# implement samplers from Karras et al paper using Karras noise schedule
# if your step count is low (for example 7 or 8) you should add --end_karras_ramp_early too.
# --heun --karras_noise
# --euler --karras_noise
# --dpm2 --karras_noise
# I assume Karras noise schedule is generally applicable, so is suitable for use with any k-diffusion sampler.
# --k_lms --karras_noise
# --euler_ancestral --karras_noise
# --dpm2_ancestral --karras_noise
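
For reference, the recommended flag combinations above map onto k-diffusion calls roughly as follows. A sketch only: the latent shape and the conditioning passed via extra_args are stand-ins, not this script's exact wiring:

import torch
from k_diffusion.sampling import get_sigmas_karras, sample_heun

# --heun --karras_noise at a low step count: build a Karras schedule between
# Stable Diffusion's trained sigma extremes, then integrate with Heun's method.
sigmas = get_sigmas_karras(n=7, sigma_min=0.0292, sigma_max=14.6146, rho=7., device='cpu')
x = torch.randn(1, 4, 64, 64) * sigmas[0]  # start from noise scaled by the first sigma
# denoiser would be the CompVisDenoiser(model, quantize=True) wrapper constructed below:
# samples = sample_heun(denoiser, x, sigmas, extra_args=conditioning_kwargs)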
@@ -208,21 +209,11 @@
action='store_true',
help=f"use noise schedule from arXiv:2206.00364. Implemented for k-diffusion samplers, {K_DIFF_SAMPLERS}. but you should probably use it with one of the samplers introduced in the same paper: {KARRAS_SAMPLERS}.",
)
parser.add_argument(
"--discretization",
action='store_true',
help=f"implements the time-step discretization from arXiv:2206.00364 section C.3.4. Implemented in Karras samplers only, {KARRAS_SAMPLERS}. Rounds each sigma proposed by your noise schedule, to the closest sigma among the 1000 on which Stable Diffusion's DDIM sampler was trained.",
)
parser.add_argument(
"--end_karras_ramp_early",
action='store_true',
help=f"when --karras_noise is enabled: ramp from sigma_max (14.6146) to a sigma *slightly above* sigma_min (0.0292), instead of including sigma_min in our ramp. because the effect on the image of sampling sigma_min is not very big, and every sigma counts when our step counts are low. use this to get good results with {KARRAS_SAMPLERS} at step counts as low as 7 or 8.",
)
parser.add_argument(
"--yolo_discretization",
action='store_true',
help=f"if you are using a non-Karras k-diffusion sampler, {NON_KARRAS_K_DIFF_SAMPLERS}, and you're feeling spicy: rounds your noise schedule to the DDIM's nearest sigma before we give it to the sampler. this might be a crazy thing to do, or it might actually help.",
)
parser.add_argument(
"--dynamic_thresholding",
action='store_true',
@@ -356,11 +347,6 @@ def main():
action='store_true',
help="include sigmas in file name",
)
parser.add_argument(
"--print_sigma_hats",
action='store_true',
help=f"print sigmas before/after sampler quantizes them (for {KARRAS_SAMPLERS} samplers)",
)
parser.add_argument(
"--init_img",
type=str,
@@ -390,7 +376,7 @@ def main():
model = model.to(device)

if opt.sampler in K_DIFF_SAMPLERS:
model_k_wrapped = CompVisDenoiser(model)
model_k_wrapped = CompVisDenoiser(model, quantize=True)
model_k_config = KCFGDenoiser(model_k_wrapped)
elif opt.sampler in NOT_K_DIFF_SAMPLERS:
if opt.sampler == 'plms':
@@ -430,12 +416,17 @@
start_code = None

karras_noise_active = False
discretization_active = False
yolo_discretization_active = False
end_karras_ramp_early_active = False

def format_sigmas_pretty(sigmas: Tensor) -> str:
return f'[{", ".join("%.4f" % sigma for sigma in sigmas)}]'
def _format_sigma_pretty(sigma: Tensor) -> str:
return "%.4f" % sigma

def format_sigmas_pretty(sigmas: Tensor, summary: bool = False) -> str:
if (summary and sigmas.size(dim=0) > 9):
start = ", ".join(_format_sigma_pretty(sigma) for sigma in sigmas[0:4])
end = ", ".join(_format_sigma_pretty(sigma) for sigma in sigmas[-4:])
return f'[{start}, …, {end}]'
return f'[{", ".join(_format_sigma_pretty(sigma) for sigma in sigmas)}]'

def _compute_common_file_name_portion(sample_ix: str = '', sigmas: str = '') -> str:
seed = ''
@@ -446,17 +437,15 @@ def _compute_common_file_name_portion(sample_ix: str = '', sigmas: str = '') ->
guidance = ''
if opt.filename_sampling:
kna = '_kns' if karras_noise_active else ''
da = '_dcrt' if discretization_active else ''
yda = '_ydcrt' if yolo_discretization_active else ''
nz = '_ek' if end_karras_ramp_early_active else ''
sampling = f"{opt.sampler}{opt.steps}{kna}{da}{yda}{nz}"
sampling = f"{opt.sampler}{opt.steps}{kna}{nz}"
if opt.filename_seed:
seed = f".s{opt.seed}"
if opt.filename_prompt:
prompt = f"_{opt.prompt}_"
if opt.filename_sample_ix:
sample_ix_ = sample_ix
if opt.filename_sigmas:
if opt.filename_sigmas and sigmas is not None:
sigmas_ = f"_{sigmas}_"
if opt.filename_guidance:
guidance = f"_str{opt.strength}_sca{opt.scale}"
@@ -466,7 +455,7 @@ def compute_batch_file_name(sigmas: str = '') -> str:
common_file_name_portion = _compute_common_file_name_portion(sigmas=sigmas)
return f"grid-{grid_count:04}{common_file_name_portion}.png"

def compute_sample_file_name(batch_ix: int, sample_ix_in_batch: int, sigmas: str = '') -> str:
def compute_sample_file_name(batch_ix: int, sample_ix_in_batch: int, sigmas: Optional[str] = None) -> str:
sample_ix=f".n{batch_ix}.i{sample_ix_in_batch}"
common_file_name_portion = _compute_common_file_name_portion(sample_ix=sample_ix, sigmas=sigmas)
return f"{base_count:05}{common_file_name_portion}.png"
Expand Down Expand Up @@ -507,10 +496,6 @@ def compute_sample_file_name(batch_ix: int, sample_ix_in_batch: int, sigmas: str
if opt.sampler in NOT_K_DIFF_SAMPLERS:
if opt.karras_noise:
print(f"[WARN] You have requested --karras_noise, but Karras et al noise schedule is not implemented for {opt.sampler} sampler. Implemented only for {K_DIFF_SAMPLERS}. Using default noise schedule from DDIM.")
if opt.discretization:
print(f"[WARN] You have requested --discretization, but time-step discretization is not implemented for {opt.sampler} sampler. Implemented only for {KARRAS_SAMPLERS}. Using the sigmas from DDIM without applying any rounding.")
if opt.yolo_discretization:
print(f"[WARN] You have requested --yolo_discretization, but questionable time-step discretization is not implemented for {opt.sampler} sampler. Implemented only for {NON_KARRAS_K_DIFF_SAMPLERS}. Using the sigmas from DDIM without applying any rounding.")
if init_latent is None:
samples, _ = sampler.sample(
S=opt.steps,
@@ -523,6 +508,9 @@ def compute_sample_file_name(batch_ix: int, sample_ix_in_batch: int, sigmas: str
eta=opt.ddim_eta,
x_T=start_code
)
# for PLMS and DDIM, sigmas are all 0
sigmas = None
sigmas_quantized = None
else:
z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch_size).to(device))
samples = sampler.decode(
@@ -601,38 +589,11 @@ def get_premature_sigma_min(
if init_latent is not None:
sigmas = sigmas[len(sigmas) - t_enc - 1 :]

print('sigmas:')
sigmas_pretty = format_sigmas_pretty(sigmas)
print(sigmas_pretty)

if opt.sampler in KARRAS_SAMPLERS:
if opt.discretization:
noise_schedule_sampler_args['decorate_sigma_hat'] = make_quantizer(model_k_wrapped.sigmas)
discretization_active = True
else:
print(f"[WARN] you should really enable --discretization to get the full benefits of your Karras sampler, {opt.sampler}. time step discretization implements section C.3.4 of arXiv:2206.00364, the paper in which your sampler was proposed.")
if opt.yolo_discretization:
print(f"[WARN] You have requested --yolo_discretization, but we will ignore this because your {opt.sampler} sampler supports the superior --discretization.")
else:
if opt.discretization:
print(f"[WARN] You have requested --discretization, but time-step discretization is not implemented for {opt.sampler} sampler. Implemented only for {KARRAS_SAMPLERS}.")
if opt.yolo_discretization:
print(f"[WARN] you are using the experimental YOLO time-step discretization, which is experimental. if you are lucky, perhaps it will approximate section C.3.4 of arXiv:2206.00364 for your {opt.sampler} sampler.")
# quantize sigmas from noise schedule to closest equivalent in model_k_wrapped.sigmas
sigmas = model_k_wrapped.sigmas[torch.argmin((sigmas.reshape(len(sigmas), 1).repeat(1, len(model_k_wrapped.sigmas)) - model_k_wrapped.sigmas).abs(), dim=1)]
yolo_discretization_active = True

if opt.print_sigma_hats:
orig_decorator = noise_schedule_sampler_args['decorate_sigma_hat'] if 'decorate_sigma_hat' in noise_schedule_sampler_args else lambda x: x
sigmas_before = []
sigmas_after = []
def log_sigma(t: Tensor) -> Tensor:
sigmas_before.append(t.item())
decorated = orig_decorator(t)
sigmas_after.append(decorated.item())
return decorated

noise_schedule_sampler_args['decorate_sigma_hat'] = log_sigma
print('sigmas (before quantization):')
print(format_sigmas_pretty(sigmas))
print('sigmas (after quantization):')
sigmas_quantized = append_zero(model_k_wrapped.sigmas[torch.argmin((sigmas[:-1].reshape(len(sigmas)-1, 1).repeat(1, len(model_k_wrapped.sigmas)) - model_k_wrapped.sigmas).abs(), dim=1)])
print(format_sigmas_pretty(sigmas_quantized))

x = start_code * sigmas[0] # for GPU draw
if init_latent is not None:
@@ -691,18 +652,14 @@ def log_sigma(t: Tensor) -> Tensor:
x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
img = Image.fromarray(x_sample.astype(np.uint8))
# img = put_watermark(img, wm_encoder)
img.save(os.path.join(sample_path, compute_sample_file_name(batch_ix=n, sample_ix_in_batch=ix, sigmas=sigmas_pretty)))
preferred_sigmas = sigmas_quantized if sigmas_quantized is not None else sigmas
img.save(os.path.join(sample_path, compute_sample_file_name(batch_ix=n, sample_ix_in_batch=ix, sigmas=format_sigmas_pretty(preferred_sigmas, summary=True) if preferred_sigmas is not None else None)))
base_count += 1

if not opt.skip_grid:
all_samples.append(x_checked_image_torch)
iter_toc = time.perf_counter()
print(f'batch {n} generated {batch_size} images in {iter_toc-iter_tic} seconds')
if opt.print_sigma_hats:
print('sigmas before:')
print(format_sigmas_pretty(sigmas_before))
print('sigmas after:')
print(format_sigmas_pretty(sigmas_after))
if not opt.skip_grid:
# additionally, save as grid
grid = torch.stack(all_samples, 0)
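One loose end: get_premature_sigma_min, referenced in a hunk header above, is not shown in this diff. Going only by the --end_karras_ramp_early help text (ramp from sigma_max to a sigma slightly above sigma_min), one plausible implementation is to stop the ramp at the value an ordinary Karras ramp's second-to-last step would land on; the names and details below are my guess, not the commit's code:

def premature_sigma_min(steps: int, sigma_max: float, sigma_min_nominal: float, rho: float = 7.) -> float:
    """Effective sigma_min for --end_karras_ramp_early: the sigma that the
    second-to-last of `steps` evenly spaced Karras ramp positions would produce,
    keeping the lowest sampled sigma slightly above the nominal minimum."""
    min_inv_rho = sigma_min_nominal ** (1 / rho)
    max_inv_rho = sigma_max ** (1 / rho)
    ramp_penultimate = (steps - 2) / (steps - 1)  # second-to-last of `steps` points on [0, 1]
    return (max_inv_rho + ramp_penultimate * (min_inv_rho - max_inv_rho)) ** rho

# e.g. premature_sigma_min(7, 14.6146, 0.0292) ≈ 0.13, which then replaces
# sigma_min in the get_sigmas_karras call.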
