Implement time step discretization for Karras samplers #23
Conversation
(as described in the Elucidating paper arXiv:2206.00364 section C.3.4 "practical challenge" 3). Also add a way to opt out of receiving a zero in the Karras noise schedule (this makes less sense when discretizing, because 0 can be out-of-range -- i.e. lower than sigma_min -- and you'd round the result back up to sigma_min again). And for what it's worth, I moved the .to(device) to happen a little earlier in get_sigmas_karras(), on the basis that the other get_sigmas_* functions were happy to move to device *before* appending zero.
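To make the schedule construction concrete, here's a torch-free sketch of a Karras schedule with a hypothetical `append_zero` opt-out flag. The function name and flag name are illustrative, not the PR's actual API:

```python
# Sketch (not the library's exact code) of get_sigmas_karras with a
# hypothetical append_zero opt-out; torch-free for illustration.
def get_sigmas_karras_sketch(n, sigma_min, sigma_max, rho=7.0, append_zero=True):
    ramp = [i / (n - 1) for i in range(n)]  # equivalent of linspace(0, 1, n)
    min_inv_rho = sigma_min ** (1 / rho)
    max_inv_rho = sigma_max ** (1 / rho)
    # interpolate in sigma**(1/rho) space, descending from sigma_max to sigma_min
    sigmas = [(max_inv_rho + r * (min_inv_rho - max_inv_rho)) ** rho for r in ramp]
    return (sigmas + [0.0]) if append_zero else sigmas
```

With `append_zero=False` the schedule ends at sigma_min instead of 0, which is the opt-out behaviour described above.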
d'oh, I just noticed you already implemented a quantize option; our solutions are equivalent. But whereas the paper says to quantize sigma_hat, currently k-diffusion passes the unquantized value into the model:
You need the 0 on the end so the sampler outputs a fully denoised image, the ODE needs to be integrated from sigma_max to 0 for this to happen. I think the thing you are observing happens because sigma_min (the last noise level the model is evaluated at) is too low for low step counts. Have you tried increasing sigma_min instead, but keeping the concatenation of 0?
thanks very much @crowsonkb for explaining the importance of the 0! okay, so we need to keep the 0. but ramping all the way down to sigma_min inclusive isn't the best use of our limited sigmas. so one idea is to formalize the wacky way in which that 0.1072 was computed, so we can intentionally use it as our sigma_min. the 0.1072 can be obtained like this:

```python
steps = 7
get_sigmas_karras(
    # there's an argument that steps+1 is wacky, so let's remember to try without the +1 too
    n=steps + 1,
    sigma_max=model.sigmas[-1].item(),  # 14.6146
    sigma_min=model.sigmas[0].item(),   # 0.0292
    rho=7.,
)[-3]  # skip the nth because it's 0; skip the (n-1)th because it's the known-bad sigma_min
```

or more efficiently like this:
```python
# gets the (N-1)th sigma from a Karras noise schedule
def get_awesome_sigma_min(
    steps: int,
    sigma_max: float,
    sigma_min_nominal: float,
    rho: float,
) -> float:
    min_inv_rho = sigma_min_nominal ** (1 / rho)
    max_inv_rho = sigma_max ** (1 / rho)
    ramp = (steps - 2) / (steps - 1)
    sigma_min = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
    return sigma_min

steps = 7
sigma_max = model.sigmas[-1].item()  # 14.6146
sigma_min = get_awesome_sigma_min(
    steps=steps + 1,
    sigma_max=sigma_max,
    sigma_min_nominal=model.sigmas[0].item(),  # 0.0292
    rho=7.,
)
```

having computed a new sigma_min of 0.1072 using get_awesome_sigma_min(), we call the (unmodified) get_sigmas_karras():

```python
sigmas = get_sigmas_karras(
    n=opt.steps,
    sigma_min=sigma_min,
    sigma_max=sigma_max,
    rho=rho,
)
```

it returns the following noise schedule, identical to our first experiment except ending with 0 instead of 0.0292:
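For the curious, the 0.1072 can be re-derived with plain Python from the Karras interpolation formula alone, using only the numbers quoted above (a standalone sketch, no k-diffusion import):

```python
def karras_sigma(ramp, sigma_min, sigma_max, rho=7.0):
    """One point on the Karras schedule: interpolate between sigma_max and
    sigma_min in sigma**(1/rho) space, with ramp in [0, 1]."""
    min_inv_rho = sigma_min ** (1 / rho)
    max_inv_rho = sigma_max ** (1 / rho)
    return (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho

# with n = steps + 1 = 8 sigmas, the second-to-last sits at ramp = 6/7
new_sigma_min = karras_sigma(6 / 7, 0.0292, 14.6146)  # ≈ 0.1072
```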
the sigma_hats get discretized to these before being passed into the model:
Picture still looks good (the power of the 0.1072 sigma, probably). now let's try simplifying away that ugly steps+1. new sigmas out of the oven: this spends more time on the middle sigmas. "more time in the middle" sounds closer to the behaviour of
the sigma_hats discretize to:
Picture still looks good. We got a new floral pattern on the sleeve! Plus some new hair detail.
but overall, keeping 0 seems to make this a nicer algorithm than we started with.
so, I think we don't need the opt-out, but I still think two problems remain regarding adhering to the paper:
If you do model.t_to_sigma(model.sigma_to_t(sigma)) inside the sampler you can get the quantized sigma... but you can't count on those methods being there because the user could just pass in any arbitrarily wrapped model. I'm not really sure what to do tbh.
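To illustrate the round-trip idea with a toy example (the table values and standalone function names here are hypothetical, not CompVisDenoiser's actual implementation):

```python
# a hypothetical discrete model's supported sigmas, ascending (illustrative values)
sigma_table = [0.0292, 0.1072, 0.5, 2.0, 14.6146]

def sigma_to_t(sigma):
    # index of the nearest supported sigma (argmin over absolute distance)
    return min(range(len(sigma_table)), key=lambda i: abs(sigma_table[i] - sigma))

def t_to_sigma(t):
    return sigma_table[t]

# the round-trip snaps an arbitrary sigma to the nearest supported one
quantized = t_to_sigma(sigma_to_t(0.6))
```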
an older factoring of the code that I tried was to expose a given that discrete sigmas are a
factoring out a common core might not be the craziest thing to do, since
The samplers are supposed to be independent of the models, though, that would duplicate a ton of code and I might add new samplers later etc. Is there some reasonable way to guarantee that a wrapper class has all the required methods? The usual idiom here is subclassing but that doesn't really work with the wrapper idiom...
you mean a way to sniff whether the model supports the required methods? if we're ruling out "checking if it extends a class/mixin", then I guess that leaves "check whether it has a particular method decorated with a decorator you provide"?
Maybe there could be a model wrapper class that has all of the methods that the samplers etc. expect, and the default implementation of these methods just forwards to the wrapped model, and users could override these methods to customize the behavior. That is, all model wrappers would subclass this and override methods, maybe just forward() but they could also alter the other methods if they did something more complicated.
yes, that would be a good way to do it. if you're a continuous-time model, you don't want to quantize sigma_hat at all. so maybe a new base model wrapper class would be introduced (from which wrappers would extend); the base model wrapper class would have a
I need to think about which methods to make standard on the wrapper...
Maybe forward(). Oh! Maybe add sigma_to_t() and t_to_sigma(). Maybe also have a
disclaimer: my design patterns are based on Java experience, not Python. I'd start by only implementing stuff that you actually have a user for. I'd start from "who consumes a base class?". will end-users consume this base class? I don't know a use-case that would mean they'd ever see the base class. another consideration vis-à-vis forcing subclasses to adhere to the same method signatures… we already see some divergence here;
of the choices, I prefer
This might be another situation where — for performance reasons — it would be good to support
hmm I guess that's something I'd use (I'm currently resorting to
…ent equivalent/better ramp (--end_karras_ramp_early) without requiring a custom fork of k-diffusion crowsonkb/k-diffusion#23 (comment)
…ucting CompVisDenoiser with quantize=True. this means we don't need a custom fork of k-diffusion (except for MPS fixes). only downside compared to my original approach is that we cannot set churn>0 (see crowsonkb/k-diffusion#23 (comment)), but we never used that. I reckon the ability to quantize sigma_hat will be added to mainline k-diffusion eventually (discussing here: crowsonkb/k-diffusion#23 (comment)), so I think it's best to keep the k-diffusion branch free of bespoke changes (with the exception of MPS), to keep it easy to rebase onto mainline. remove ability to opt in/out of discretization, now that I've finished comparing them (crowsonkb/k-diffusion#23) -- the difference is barely perceptible but discretization is the better choice in theory.
I was thinking about something along the lines of the following:

```python
from torch import nn

class BaseModelWrapper(nn.Module):
    """The base wrapper class for the k-diffusion model wrapper idiom. Model
    wrappers should subclass this class and customize the behavior of the
    wrapped model by implementing or overriding methods."""

    def __init__(self, inner_model):
        super().__init__()
        self.inner_model = inner_model

    def __dir__(self):
        return list(set(super().__dir__() + dir(self.inner_model)))

    def __getattr__(self, name):
        try:
            return super().__getattr__(name)
        except AttributeError:
            return getattr(self.inner_model, name)

    def forward(self, *args, **kwargs):
        return self.inner_model(*args, **kwargs)
```

I'm not sure where the standard methods should go yet, on this base wrapper or separately implemented on the different denoiser wrappers, which would be changed to be subclasses of this.
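To see the delegation at work without torch, here's a stripped-down analogue of the same idiom (class names hypothetical): failed attribute lookups fall through to the wrapped model, and calls forward by default.

```python
class WrapperSketch:
    """Torch-free analogue of the wrapper idiom: attribute lookups that fail
    on the wrapper fall through to the wrapped model."""
    def __init__(self, inner_model):
        self.inner_model = inner_model

    def __getattr__(self, name):
        # only invoked when normal attribute lookup fails, so overrides win
        return getattr(self.inner_model, name)

    def __call__(self, *args, **kwargs):
        return self.inner_model(*args, **kwargs)

class InnerModel:
    sigma_data = 1.0
    def __call__(self, x):
        return x * 2

wrapped = WrapperSketch(InnerModel())
```

Here `wrapped.sigma_data` reaches the inner model's attribute even though the wrapper never defined it, which is the "illusion of subtyping" discussed below.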
hey, sorry for slow reply. okay, so this wrapper creates the illusion that the wrapped instance is a subtype of inner_model. seems reasonable (we want the wrapped instance to be a substitute that can be used in all the same situations).

as for where the standard methods (e.g. sigma_to_t()) should go... if it's a "general model wrapper" (i.e. nothing to do with diffusion, but perhaps with generic responsibilities like logging), then I wouldn't put sigma_to_t() this low. if it's a "diffusion model wrapper" (and I assume it is), then I think it makes sense to put sigma_to_t() this low if (and only if) that's something that every diffusion model needs. if there's a one-size-fits-all implementation of sigma_to_t() that can be put here, it can go here.

generally, the decision of "should I put sigma_to_t() -- or at least an abstract interface for it -- this low" is answered by "what model type will the samplers integrate against?"
It is yeah.
For a k-diffusion native model, sigma(t) = t, so the default implementation can simply return its input, same for t_to_sigma().
okay sure, then yes: let's put a default implementation (identity function) of sigma_to_t() and t_to_sigma() in the base diffusion model wrapper.
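A minimal sketch of that agreement (class names hypothetical): the base diffusion wrapper's conversions are the identity, and a discrete-time subclass overrides them with a table lookup.

```python
class DiffusionWrapperSketch:
    # k-diffusion native models have sigma(t) = t, so both conversions are identity
    def sigma_to_t(self, sigma):
        return sigma

    def t_to_sigma(self, t):
        return t

class DiscreteWrapperSketch(DiffusionWrapperSketch):
    def __init__(self, sigmas):
        self.sigmas = sigmas  # ascending table of supported sigmas

    def sigma_to_t(self, sigma):
        # nearest index in the table
        return min(range(len(self.sigmas)), key=lambda i: abs(self.sigmas[i] - sigma))

    def t_to_sigma(self, t):
        return self.sigmas[t]

base_wrapper = DiffusionWrapperSketch()
discrete_wrapper = DiscreteWrapperSketch([0.0292, 0.5, 14.6146])
```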
Improves support for diffusion models with discrete time-steps (such as Stable Diffusion's DDIM).
I have some questions though, so this may need some iterating.
The user would invoke like so:
Implements the change to "Algorithm 2, line 5" described in the Elucidating paper arXiv:2206.00364 section C.3.4 "iDDPM practical considerations" practical challenge 3.
In other words we round sigmas to the nearest sigma supported by the DDIM.
For your convenience, here's the sigmas supported by Stable Diffusion DDIM:
https://gist.github.com/Birch-san/6cd1574e51871a5e2b88d59f0f3d4fd3
You may be wondering "okay, rounding sigma_hat solves challenge 3, but what about challenge 2?"
There's an argument that solving challenge 3 solves challenge 2 in some situations.
When gamma == 0, rounding sigma_hat is equivalent to rounding sigma (which is what challenge 2 requires you to do for any outputs of get_sigmas_karras()).
Problem here is the final sigma we'll receive, 0. We probably don't want to apply the same rounding rules to that… especially because we have a special case predicated on 0. Should that be predicated on sigma_min instead, or perhaps on "have we reached the final sigma?"
edit: maybe the only reason they special-case 0 is because they want to avoid dividing by zero?
If we do care about satisfying challenge 2 in the gamma > 0 situation, we'd want to round-to-nearest-sigma what comes out of get_sigmas_karras(). I happen to have made a torch snippet for running argmin on every element returned by get_sigmas_karras() simultaneously:
But again, I'm not sure what the implications are for the 0 it returns.
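As an illustration of the idea (not the snippet referenced above, and with a toy sigma table rather than SD's 1000 values), the vectorized rounding amounts to an argmin per schedule element:

```python
# toy table of a model's supported sigmas (illustrative values only)
supported_sigmas = [0.0292, 0.1072, 0.5, 2.0, 14.6146]

def round_to_supported(sigmas, table):
    """Round each schedule sigma to the nearest supported sigma, leaving any
    trailing 0 untouched (it has no counterpart in the table)."""
    return [s if s == 0 else min(table, key=lambda t: abs(t - s)) for s in sigmas]

rounded = round_to_supported([14.6146, 1.8, 0.11, 0.0292, 0.0], supported_sigmas)
```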
Anyway, maybe we can look at the outputs to decide. We'll try with keeping the 0 and without.
I tried to stress this to its limits by using as few steps as I could manage before it looked bad. All images are: 68673924, get_sigmas_karras() noise schedule.

Heun, 7 steps
Excluding 0 from get_sigmas_karras()

The better-looking result was when I excluded the 0 returned by get_sigmas_karras(), in favour of ramping for 1 more step.
Recall that SD's sigmas run from max = 14.6146 to min = 0.0292.
Sigmas returned by get_sigmas_karras():
sample_heun only iterates to n-1, so never touches the 0.0292.
.Time-step discretization enabled
Sigmas (up to n-1) after discretization:
Original k-diffusion behaviour (no discretization)
Not much perceptible difference. The discrete one defines the far sleeve better, but given the other subtle differences it's hard for me to say which is the better generation.
Keeping 0 from get_sigmas_karras()

So the paper didn't mention this, but the result is terrible at low step counts if you actually implement the 0 as they describe. Maybe this is just a problem for discrete-time models?
Sigmas returned by get_sigmas_karras():
sample_heun only iterates to n-1, so never touches the 0.
only iterates to n-1, so never touches the 0.Time-step discretization enabled
Sigmas (up to n-1) after discretization:
Original k-diffusion behaviour (no discretization)
Slightly more perceptible difference. The discrete one did better on the eyes and has slightly more clothing definition.
Conclusion

Removing the "concat 0" from get_sigmas_karras() seems to be hugely beneficial for small numbers of steps. This is not backed up by the literature. The reason I tried this was due to a misunderstanding: I saw that if I discretized the whole schedule, I'd end up with a repeated sigma_min (…, 0.0292, 0.0292]).
I removed the concat 0 to ensure I didn't end up producing duplicates. I didn't realize though that the sampler stops at n-1, so repeats aren't actually a problem. But it seems that, for a different reason, the results are far better.
Discretization of time-steps doesn't have the dramatic impact I was hoping for, but is probably still a sensible thing to do on the basis that the paper recommends it.
Heun, 50 steps, excluding 0
Let's do one more example, at 50 steps.
Time-step discretization enabled
Original k-diffusion behaviour (no discretization)
Discretization seems to be more noticeable over 50 steps. The discretized image seems to have sharper hair and clothing, and highlights are brighter. Not sure I could say which is "better" though.
It's hard to compare images scrolling on GitHub; personally I flicked between these using QuickLook in the Finder.
If you know a better way to evaluate whether this is an improvement: I'm all ears! 👂