Zygote's compilation scales badly with the number of ~ statements
#1754
Comments
Thanks a lot for posting this! |
Do we have code that checks the compile times of models? How long has Zygote compilation been taking this long? |
Since Julia 1.6 afaik. |
As in, the 1.6 update caused it? Is the compilation faster on 1.5? |
@torfjelde can probably get the right answer, if it matters. There are a lot of packages and versions interacting, so I'm not sure it necessarily matches a particular Julia version. It might take just as long to find the exact combination of Julia, Zygote, ChainRules, Turing, DynamicPPL, etc. versions that caused it as it would to actually solve the problem. |
I don't think so. It's based only on my experience rather than on any benchmarking/profiling (which we need!), so I don't have a good answer to this question. |
Sorry to bump so soon @torfjelde, but do we have any sense of a timeline on this? Are we thinking a week, a month, etc.? We would help out, but I fear this is a little too tightly connected to the PPL macros for us to contribute to. |
I'm sorry, I can't put a timeline on this right now. And I worry and suspect it's not particularly related to Turing, unfortunately 😕 There was one change we made in Turing that I worried could have caused it, but when I tried with an older version the performance was still horrible. The best way to identify what's wrong is to just run the test above for different versions of Turing and Zygote. |
As of right now I have the following results:
Columns with numeric values represent the number of
This is on the following system:
julia> versioninfo()
Julia Version 1.6.1
Commit 6aaedecc44 (2021-04-23 05:59 UTC)
Platform Info:
OS: Linux (x86_64-pc-linux-gnu)
CPU: Intel(R) Core(TM) i7-6850K CPU @ 3.60GHz
WORD_SIZE: 64
LIBM: libopenlibm
LLVM: libLLVM-11.0.1 (ORCJIT, broadwell)
The sad news is that I can't really glean anything useful from the above 😕 Here it seems to be "fine" (taking 5 mins to compile is not good, but it's much better than the reported numbers).
Comments:
Script I'm running
using Pkg; Pkg.activate(mktempdir())
TURING_VERSION = ENV["TURING_VERSION"]
ZYGOTE_VERSION = ENV["ZYGOTE_VERSION"]
@info "Trying to install Turing@$(TURING_VERSION) and Zygote@$(ZYGOTE_VERSION)"
Pkg.add(name="Turing", version=TURING_VERSION)
Pkg.add(name="Zygote", version=ZYGOTE_VERSION)
using Turing, Zygote
pkgversion(mod) = Pkg.TOML.parsefile(joinpath(dirname(Pkg.project().path), "Manifest.toml"))[string(mod)][1]["version"]
@info "Installed Turing@$(pkgversion(Turing)) and Zygote@$(pkgversion(Zygote))"
Turing.setadbackend(:zygote);
adbackend = Turing.Core.ZygoteAD();
spl = DynamicPPL.SampleFromPrior();
results = []
num_tildes = 1
@info "Running" num_tildes
Zygote.refresh()
expr = :(function $(Symbol(:demo, num_tildes))() end) |> Base.remove_linenums!;
mainbody = last(expr.args);
append!(mainbody.args, [:($(Symbol("x", j)) ~ Normal()) for j = 1:num_tildes]);
f = @eval $(DynamicPPL.model(:Main, LineNumberNode(1), expr, false))
model = f();
vi = DynamicPPL.VarInfo(model);
t = @elapsed Turing.Core.gradient_logp(
adbackend,
vi[spl],
vi,
model,
spl
);
push!(results, t)
@info "Result" t
num_tildes = 5
@info "Running" num_tildes
Zygote.refresh()
expr = :(function $(Symbol(:demo, num_tildes))() end) |> Base.remove_linenums!;
mainbody = last(expr.args);
append!(mainbody.args, [:($(Symbol("x", j)) ~ Normal()) for j = 1:num_tildes]);
f = @eval $(DynamicPPL.model(:Main, LineNumberNode(1), expr, false))
model = f();
vi = DynamicPPL.VarInfo(model);
t = @elapsed Turing.Core.gradient_logp(
adbackend,
vi[spl],
vi,
model,
spl
);
push!(results, t)
@info "Result" t
num_tildes = 10
@info "Running" num_tildes
Zygote.refresh()
expr = :(function $(Symbol(:demo, num_tildes))() end) |> Base.remove_linenums!;
mainbody = last(expr.args);
append!(mainbody.args, [:($(Symbol("x", j)) ~ Normal()) for j = 1:num_tildes]);
f = @eval $(DynamicPPL.model(:Main, LineNumberNode(1), expr, false))
model = f();
vi = DynamicPPL.VarInfo(model);
t = @elapsed Turing.Core.gradient_logp(
adbackend,
vi[spl],
vi,
model,
spl
);
push!(results, t)
@info "Result" t
num_tildes = 15
@info "Running" num_tildes
Zygote.refresh()
expr = :(function $(Symbol(:demo, num_tildes))() end) |> Base.remove_linenums!;
mainbody = last(expr.args);
append!(mainbody.args, [:($(Symbol("x", j)) ~ Normal()) for j = 1:num_tildes]);
f = @eval $(DynamicPPL.model(:Main, LineNumberNode(1), expr, false))
model = f();
vi = DynamicPPL.VarInfo(model);
t = @elapsed Turing.Core.gradient_logp(
adbackend,
vi[spl],
vi,
model,
spl
);
push!(results, t)
@info "Result" t
println(join(results, " | "))
if haskey(ENV, "OUTPUT_FILE")
open(ENV["OUTPUT_FILE"], "a") do io
write(io, "| ", join([VERSION; pkgversion(Turing); pkgversion(Zygote); results;], " | "), " |")
write(io, "\n")
end
end
Script for Turing >= 0.21
using Pkg
if any(Base.Fix1(haskey, ENV), ["TURING_VERSION", "ZYGOTE_VERSION"])
# In this case, we create a new env and install the corresponding package versions.
Pkg.activate(mktempdir())
if haskey(ENV, "TURING_VERSION")
TURING_VERSION = ENV["TURING_VERSION"]
@info "Trying to install Turing@$(TURING_VERSION)"
Pkg.add(name="Turing", version=TURING_VERSION)
end
if haskey(ENV, "ZYGOTE_VERSION")
ZYGOTE_VERSION = ENV["ZYGOTE_VERSION"]
@info "Trying to install Zygote@$(ZYGOTE_VERSION)"
Pkg.add(name="Zygote", version=ZYGOTE_VERSION)
end
end
using Turing, Zygote
using Turing: LogDensityProblems
if VERSION < v"1.6.2"
pkginfo(mod) = Pkg.TOML.parsefile(joinpath(dirname(Pkg.project().path), "Manifest.toml"))[string(mod)][1]
else
pkginfo(mod) = Pkg.TOML.parsefile(joinpath(dirname(Pkg.project().path), "Manifest.toml"))["deps"][string(mod)][1]
end
pkgversion(mod) = pkginfo(mod)["version"]
pkghash(mod) = pkginfo(mod)["git-tree-sha1"]
@info "Installed Turing@$(pkgversion(Turing)) [#$(pkghash(Turing))] and Zygote@$(pkgversion(Zygote)) [#$(pkghash(Zygote))]"
Turing.setadbackend(:zygote);
adbackend = Turing.Core.ZygoteAD();
spl = DynamicPPL.SampleFromPrior();
results = []
num_tildes = 1
@info "Running" num_tildes
Zygote.refresh()
expr = :(function $(Symbol(:demo, num_tildes))() end) |> Base.remove_linenums!;
mainbody = last(expr.args);
append!(mainbody.args, [:($(Symbol("x", j)) ~ Normal()) for j = 1:num_tildes]);
f = @eval $(DynamicPPL.model(:Main, LineNumberNode(1), expr, false))
model = f();
vi = DynamicPPL.VarInfo(model);
ℓ = Turing.LogDensityFunction(vi, model, spl, DynamicPPL.DefaultContext());
ℓ_with_grad = LogDensityProblems.ADgradient(:Zygote, ℓ);
t = @elapsed LogDensityProblems.logdensity_and_gradient(ℓ_with_grad, vi[spl]);
push!(results, t)
@info "Result" t
num_tildes = 5
@info "Running" num_tildes
Zygote.refresh()
expr = :(function $(Symbol(:demo, num_tildes))() end) |> Base.remove_linenums!;
mainbody = last(expr.args);
append!(mainbody.args, [:($(Symbol("x", j)) ~ Normal()) for j = 1:num_tildes]);
f = @eval $(DynamicPPL.model(:Main, LineNumberNode(1), expr, false))
model = f();
vi = DynamicPPL.VarInfo(model);
ℓ = Turing.LogDensityFunction(vi, model, spl, DynamicPPL.DefaultContext());
ℓ_with_grad = LogDensityProblems.ADgradient(:Zygote, ℓ);
t = @elapsed LogDensityProblems.logdensity_and_gradient(ℓ_with_grad, vi[spl]);
push!(results, t)
@info "Result" t
num_tildes = 10
@info "Running" num_tildes
Zygote.refresh()
expr = :(function $(Symbol(:demo, num_tildes))() end) |> Base.remove_linenums!;
mainbody = last(expr.args);
append!(mainbody.args, [:($(Symbol("x", j)) ~ Normal()) for j = 1:num_tildes]);
f = @eval $(DynamicPPL.model(:Main, LineNumberNode(1), expr, false))
model = f();
vi = DynamicPPL.VarInfo(model);
ℓ = Turing.LogDensityFunction(vi, model, spl, DynamicPPL.DefaultContext());
ℓ_with_grad = LogDensityProblems.ADgradient(:Zygote, ℓ);
t = @elapsed LogDensityProblems.logdensity_and_gradient(ℓ_with_grad, vi[spl]);
push!(results, t)
@info "Result" t
num_tildes = 15
@info "Running" num_tildes
Zygote.refresh()
expr = :(function $(Symbol(:demo, num_tildes))() end) |> Base.remove_linenums!;
mainbody = last(expr.args);
append!(mainbody.args, [:($(Symbol("x", j)) ~ Normal()) for j = 1:num_tildes]);
f = @eval $(DynamicPPL.model(:Main, LineNumberNode(1), expr, false))
model = f();
vi = DynamicPPL.VarInfo(model);
ℓ = Turing.LogDensityFunction(vi, model, spl, DynamicPPL.DefaultContext());
ℓ_with_grad = LogDensityProblems.ADgradient(:Zygote, ℓ);
t = @elapsed LogDensityProblems.logdensity_and_gradient(ℓ_with_grad, vi[spl]);
push!(results, t)
@info "Result" t
println(join(results, " | "))
if haskey(ENV, "OUTPUT_FILE")
open(ENV["OUTPUT_FILE"], "a") do io
write(io, "| ", join([VERSION; pkgversion(Turing); pkgversion(Zygote); results;], " | "), " |")
write(io, "\n")
end
end
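A side note on the version-reporting helper in the scripts above: they choose the Manifest layout by comparing `VERSION`, but the layout is also recorded in the manifest file itself. As far as I know, the "2.0" format (written by Julia 1.7 and later) has a top-level `manifest_format` key and nests packages under `deps`, while older manifests list packages at the top level. A minimal sketch that keys on that field instead; `manifest_entry` and `manifest_version` are hypothetical helper names, not Pkg API. On Julia 1.9 and later, `Base.pkgversion(mod)` should return the loaded version of a package module directly.

```julia
using TOML

# Hypothetical helper: return the first manifest entry (a Dict) for package `name`.
# "2.0"-format manifests nest packages under "deps"; older ones keep them at top level.
function manifest_entry(manifest::AbstractDict, name::AbstractString)
    format = VersionNumber(get(manifest, "manifest_format", "1.0"))
    deps = format >= v"2" ? manifest["deps"] : manifest
    return first(deps[name])  # each package maps to an array of tables
end

# Hypothetical helper: the recorded version of `name` as a VersionNumber.
manifest_version(manifest::AbstractDict, name::AbstractString) =
    VersionNumber(manifest_entry(manifest, name)["version"])

# Example with a small inline "2.0"-style manifest (made-up contents):
example = TOML.parse("""
manifest_format = "2.0"

[[deps.Zygote]]
version = "0.6.27"
""")
println(manifest_version(example, "Zygote"))

# For the active environment one could instead parse (not run here):
# TOML.parsefile(joinpath(dirname(Base.active_project()), "Manifest.toml"))
```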
|
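Both scripts take their inputs from environment variables (TURING_VERSION, ZYGOTE_VERSION and, optionally, OUTPUT_FILE, to which a results row is appended), so the suggested sweep over Turing and Zygote versions can be driven from a separate Julia session. A minimal sketch, assuming one of the scripts is saved as `bench.jl` (a made-up file name); `sweep_commands` is a hypothetical helper, and the version pairs are just ones mentioned in this thread, not a checked-compatible set.

```julia
# Hypothetical file name for one of the benchmark scripts in this thread.
const BENCH_SCRIPT = "bench.jl"

# Hypothetical helper: one `julia bench.jl` command per (Turing, Zygote) version pair.
# The script itself appends a results row to `output_file` (it opens OUTPUT_FILE with "a").
function sweep_commands(version_pairs; output_file = "results.md", julia = Base.julia_cmd())
    return [addenv(`$julia $BENCH_SCRIPT`,
                   "TURING_VERSION" => turing,
                   "ZYGOTE_VERSION" => zygote,
                   "OUTPUT_FILE" => output_file)
            for (turing, zygote) in version_pairs]
end

# Version pairs taken from comments in this thread; compatibility not verified.
cmds = sweep_commands([("0.19.3", "0.6.33"), ("0.19.3", "0.6.27")])
foreach(println, cmds)

# To actually run the sweep (slow: each run installs packages into a fresh env):
# foreach(run, cmds)
```

Running each configuration in its own process keeps compilation caches from one version pair from leaking into the timing of the next.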
Thanks @torfjelde! Is there any chance this gets better on 1.7 with the latest Zygote? I know Turing doesn't support it yet, but does DynamicPPL let you test on 1.7? |
I can, but it will take a bit of work (I'm currently using Turing functionality to compute the gradient). Do you know when you started experiencing these issues, btw? Also, maybe some of the Zygote people have an idea of what's going on here, @mcabbott? TL;DR: Compilation time of |
Possibly related: FluxML/Zygote.jl#1119 and FluxML/Zygote.jl#1126 EDIT: Seems like it. |
Wow, the EDIT: Even |
@torfjelde Am I reading that correctly that Julia 1.6.5 + Turing 0.19.3 + Zygote 0.6.33 brings it back to sanity? |
No, specifically you need the |
See if FluxML/Zygote.jl#1147 works as well as |
I can confirm that EDIT: preliminary experiments seem to generate similar computing times. |
Will give it a go 👍 Btw, not sure if this is more useful information, but when I try |
I've never heard of Chopin preludes being used as a unit of measurement, but people should do that more often :)
Betting on no, as most performance-sensitive stuff is gated behind a rule. Maybe scalar- or control flow-heavy code, though the generated pullbacks for the latter are likely type unstable anyhow.
Fixed by FluxML/Zygote.jl#909 perhaps? That was in 0.6.27. |
Gave it a try; seems to do the trick! Benchmarks are in the table above. I don't know what effect it has on performance, though, but it seems like it would be worth it. |
Ah, probably! I'll give it a go. EDIT: Seems like indeed 0.6.27 improved things significantly 👍 |
Any progress on this issue by chance? I did a check with Julia 1.7 and the latest Turing, DynamicPPL, and Zygote and am still getting > 30 minute TTFG (time to first gradient) for my model with 20-ish parameters. Just want to make sure everyone knows that the two linked Zygote issues did not fix things. |
@ToucheSir @Keno Any progress on this? |
Unfortunately I have nothing concrete to report, but I have been looking into this over the past couple of months. Any help on grokking compilation latency + Zygote's internals would be much appreciated. I can't speak for Keno, but to my knowledge working on this is not on anyone else's plate. |
@torfjelde I'm not able to repro your latest timings on 1.8.2 locally with the following reduced MWE:
using Turing, Zygote
using Turing: LogDensityProblems
using SnoopCompileCore
# This helps a bit, ~4s
# @eval Turing.DynamicPPL begin
# ChainRulesCore.@non_differentiable is_flagged(::VarInfo, ::VarName, ::String)
# end
Turing.setadbackend(:zygote);
adbackend = Turing.Core.ZygoteAD();
spl = DynamicPPL.SampleFromPrior();
num_tildes = 5
# num_tildes = 10
expr = :(function $(Symbol(:demo, num_tildes))() end) |> Base.remove_linenums!;
mainbody = last(expr.args);
append!(mainbody.args, [:($(Symbol("x", j)) ~ Normal()) for j = 1:num_tildes]);
f = @eval $(DynamicPPL.model(:Main, LineNumberNode(1), expr, false))
model = f();
vi = DynamicPPL.VarInfo(model);
@info "starting eval"
ℓ = Turing.LogDensityFunction(vi, model, spl, DynamicPPL.DefaultContext());
ℓ_with_grad = LogDensityProblems.ADgradient(:Zygote, ℓ);
tinf = @snoopi_deep LogDensityProblems.logdensity_and_gradient(ℓ_with_grad, vi[spl]);
@info "done eval"
using SnoopCompile, ProfileView
@show tinf
Results:
Versioninfo:
|
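For anyone trying to reproduce this: all the scripts in this thread generate the benchmark model the same way, by splicing N `xj ~ Normal()` statements into an empty function before handing it to `DynamicPPL.model`. Below is that construction pulled into a small Base-only function so the generated expression can be inspected without Turing installed. `tilde_model_expr` is a name made up here; `Normal` only appears as a symbol in the expression and is never evaluated.

```julia
# Hypothetical helper: build `function demoN() x1 ~ Normal(); ...; xN ~ Normal() end`
# the same way the benchmark scripts do, without passing it to DynamicPPL.
function tilde_model_expr(n::Integer)
    expr = Base.remove_linenums!(:(function $(Symbol(:demo, n))() end))
    body = last(expr.args)  # the (initially empty) function body, an Expr(:block)
    append!(body.args, [:($(Symbol("x", j)) ~ Normal()) for j in 1:n])
    return expr
end

ex = tilde_model_expr(3)
println(ex)
```

In the scripts, the result of this construction is what gets passed as `expr` to `DynamicPPL.model(:Main, LineNumberNode(1), expr, false)`.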
I had a crack at figuring out why inference times also scale so badly. All outputs of that exploration may be found in this gist. It turns out that we can capture a good chunk of the slowness and poor scaling with just To demonstrate, we should first examine Now let's see the output for the augmented primal function from So what is to be done? The first thing that comes to mind is to figure out Footnotes
|
|
Last I asked there were no plans to support |
Briefly going to comment on this to say: the solution to this issue is to use ReverseDiff or ForwardDiff.jl or (a few years down the line, when it's mature) maybe some other autodiff solution like Enzyme.jl. Development on Zygote/IRTools and source-to-source AD in Julia (rather than LLVM) is effectively dead now. |
To be honest, ReverseDiff has other issues and ForwardDiff is not always an option. Actually, Zygote is developed much more actively (https://github.com/FluxML/Zygote.jl/commits/master) than ReverseDiff (https://github.com/JuliaDiff/ReverseDiff.jl/commits/master) or ForwardDiff (https://github.com/JuliaDiff/ForwardDiff.jl/commits/master). There hasn't been any new ForwardDiff release from master since a breaking change, which downstream packages rejected (correctly, IMO) when it appeared in a non-breaking release, was reapplied to the master branch; nobody wants to deal with, and possibly fix, the downstream issues that still exist, so nobody is willing to tag a new release from the ForwardDiff master branch. |
Although this is true, it's not quite an apples-to-apples comparison because of what those commits are doing. The majority of Zygote work these days is maintenance work, and much of that is mandatory because some change in Julia internals broke the source-to-source AD bit. The rest are primarily filling in holes/edge cases/robustness issues in the existing rule system, which has many, many more of them than either Forward or ReverseDiff.
IMO they should be listed out, because my impression is that many of them are more tractable than the fundamental ones facing Zygote. Some offline experimentation suggests that even the infamous lack of GPU support could be addressed without a complete package rewrite. That said, my personal deal-breaker with ReverseDiff is not any technical issue: it's that there appears to be no appetite for anything other than the most urgent maintenance. I can absolutely appreciate why this is the case, but it is a little disheartening to treat ReverseDiff as a technological dead end when other languages/libraries have managed to take similar ideas further. |
It would be great to have a ReverseDiff2. In my experience, it is very performant and made a good tradeoff between simplicity, generality and performance. There are some weakly-justified pushes to differentiate through everything. However, it is very hard to differentiate through everything, and I am not sure that one wants to do that. A performant, well-tested, and maintainable AD is what's needed. cc @willtebbutt |
ReverseDiff e.g. only supports differentiation of vectors and real numbers, has problems with wrapper types (e.g., JuliaDiff/ReverseDiff.jl#223), faces the general problem of arrays of tracked reals vs tracked arrays, and probably defines too many methods (JuliaDiff/ReverseDiff.jl#226). It has multiple correctness issues (JuliaDiff/ReverseDiff.jl#145, JuliaDiff/ReverseDiff.jl#168, JuliaDiff/ReverseDiff.jl#239, JuliaDiff/ReverseDiff.jl#233) and its ChainRules macro support has multiple bugs (eg JuliaDiff/ReverseDiff.jl#221). Some of these issues might require more changes, some of them might be more easily fixable - but IMO it's really not as good currently as some people seem to think and also would require time and effort that nobody seems to be willing to invest (as you can see from the recent commit history). The main intention of my comment was just: All Julia AD packages have problems, and I think one thing that contributed to the current situation was that people suggested to abandon certain AD packages as soon as a new promising alternative appeared - and only later it was realized that they also have their own set of problems and limitations. I think it would be better
|
Any update on this? |
I don't think there is a good solution yet; this is a general issue with Zygote: FluxML/Zygote.jl#1119 We are rewriting Zygote/ReverseDiff. Hopefully the issue will be resolved when that is complete. |
OK, thanks for the info. What I found to work reasonably well is wrapping all distributions in an arraydist. Starting Julia with, for example, (ignore the 4-element inputs and
# Handling distributions with varying parameters using arraydist
dists = [
InverseGamma(0.1, 2.0, 0.01, 3.0, μσ = true), # z_ea
InverseGamma(0.1, 2.0, 0.025,5.0, μσ = true), # z_eb
InverseGamma(0.1, 2.0, 0.01, 3.0, μσ = true), # z_eg
InverseGamma(0.1, 2.0, 0.01, 3.0, μσ = true), # z_eqs
InverseGamma(0.1, 2.0, 0.01, 3.0, μσ = true), # z_em
InverseGamma(0.1, 2.0, 0.01, 3.0, μσ = true), # z_epinf
InverseGamma(0.1, 2.0, 0.01, 3.0, μσ = true), # z_ew
Beta(0.5, 0.20, μσ = true), # crhoa
Beta(0.5, 0.20, μσ = true), # crhob
Beta(0.5, 0.20, μσ = true), # crhog
Beta(0.5, 0.20, μσ = true), # crhoqs
Beta(0.5, 0.20, μσ = true), # crhoms
Beta(0.5, 0.20, μσ = true), # crhopinf
Beta(0.5, 0.20, μσ = true), # crhow
Beta(0.5, 0.2, μσ = true), # cmap
Beta(0.5, 0.2, μσ = true), # cmaw
Normal(4.0, 1.5, 2.0, 15.0), # csadjcost
Normal(1.50,0.375, 0.25, 3.0), # csigma
Beta(0.7, 0.1, μσ = true), # chabb
Beta(0.5, 0.1, μσ = true), # cprobw
Normal(2.0, 0.75, 0.25, 10.0), # csigl
Beta(0.5, 0.10, μσ = true), # cprobp
Beta(0.5, 0.15, μσ = true), # cindw
Beta(0.5, 0.15, μσ = true), # cindp
Beta(0.5, 0.15, μσ = true), # czcap
Normal(1.25, 0.125, 1.0, 3.0), # cfc
Normal(1.5, 0.25, 1.0, 3.0), # crpi
Beta(0.75, 0.10, μσ = true), # crr
Normal(0.125, 0.05, 0.001, 0.5), # cry
Normal(0.125, 0.05, 0.001, 0.5), # crdy
Gamma(0.625, 0.1, 0.1, 2.0, μσ = true), # constepinf
Gamma(0.25, 0.1, 0.01, 2.0, μσ = true), # constebeta
Normal(0.0, 2.0, -10.0, 10.0), # constelab
Normal(0.4, 0.10, 0.1, 0.8), # ctrend
Normal(0.5, 0.25, 0.01, 2.0), # cgy
Normal(0.3, 0.05, 0.01, 1.0), # calfa
]
Turing.@model function model_loglikelihood_function(data, m, observables,fixed_parameters)
all_params ~ Turing.arraydist(dists)
z_ea, z_eb, z_eg, z_eqs, z_em, z_epinf, z_ew, crhoa, crhob, crhog, crhoqs, crhoms, crhopinf, crhow, cmap, cmaw, csadjcost, csigma, chabb, cprobw, csigl, cprobp, cindw, cindp, czcap, cfc, crpi, crr, cry, crdy, constepinf, constebeta, constelab, ctrend, cgy, calfa = all_params
...
end |
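If the number of ~ statements is what drives Zygote's compile time, as this issue reports, the arraydist trick above works by replacing what would otherwise be 36 separate ~ statements (one per prior) with a single one, and then destructuring the vector. A rough Base-only sketch for counting the ~ statements in a model body expression before handing it to `@model`; `count_tildes` is a hypothetical helper, not a DynamicPPL API, and it only looks at syntax (a ~ inside a loop counts once).

```julia
# Hypothetical helper: count `~` calls anywhere inside an expression.
count_tildes(x) = 0  # symbols, literals, LineNumberNodes, ...
function count_tildes(ex::Expr)
    self = (ex.head === :call && !isempty(ex.args) && ex.args[1] === :~) ? 1 : 0
    return self + sum(count_tildes, ex.args; init = 0)
end

# One ~ per prior (shortened to three parameters for illustration).
one_per_prior = quote
    z_ea ~ InverseGamma(0.1, 2.0)
    z_eb ~ InverseGamma(0.1, 2.0)
    crhoa ~ Beta(0.5, 0.2)
end

# The arraydist pattern: a single ~, then destructuring.
collapsed = quote
    all_params ~ arraydist(dists)
    z_ea, z_eb, crhoa = all_params
end

println(count_tildes(one_per_prior), " vs ", count_tildes(collapsed))
```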
@thorek1 thanks for sharing this! |
The compilation time with
|
That's awesome, but should this issue be closed? It's specifically related to Zygote, no? |
It should have been closed as "not planned" since fixing it requires some fairly fundamental changes to Zygote (e.g. working with optimised IR) |
Oh, I just saw another one trying to estimate the model from the Smets & Wouters paper with NUTS in Julia...
I'm not certain whether or not this can be considered a "Turing.jl issue", but I figured I would at least raise it as an issue here so people are aware.
The compilation time of Zygote scales quite badly with the number of ~ statements.
TL;DR: it takes almost 5 minutes to compile a model with 14 ~ statements. I don't have the result here, but at some point I tried one with 20 ~ statements, and it took a full ~23 mins to compile.
Demo
using Turing, Zygote
Running the following snippet a couple of times we get a sense of the compilation times:
That is, it takes almost 5 minutes to compile a model with 14 ~ statements. I don't have the result here, but at some point I tried one with 20 ~ statements, and it took a full ~23 mins to compile.
Additional info
versioninfo()
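To put a rough number on "scales quite badly": if compile time grows like t ≈ c·n^k in the number n of ~ statements, then k is the slope of log t against log n. A small sketch (`scaling_exponent` is a hypothetical helper) fitted to the two approximate figures quoted in the issue text (about 5 minutes at 14 statements, about 23 minutes at 20). With only two rough points this is an order-of-magnitude guide at best.

```julia
# Hypothetical helper: least-squares slope of log(t) against log(n), i.e. the
# exponent k in a power-law model t ≈ c * n^k.
function scaling_exponent(ns::AbstractVector{<:Real}, ts::AbstractVector{<:Real})
    length(ns) == length(ts) || throw(ArgumentError("ns and ts must have equal length"))
    length(ns) >= 2 || throw(ArgumentError("need at least two data points"))
    x = log.(ns)
    y = log.(ts)
    xmean = sum(x) / length(x)
    ymean = sum(y) / length(y)
    return sum((x .- xmean) .* (y .- ymean)) / sum((x .- xmean) .^ 2)
end

# Approximate figures from the issue text: ~5 min at 14 `~` statements, ~23 min at 20.
k = scaling_exponent([14, 20], [5.0, 23.0])
println("estimated exponent k ≈ ", round(k; digits = 1))
```

With two points the fit reduces to ln(23/5) / ln(20/14) ≈ 4.3, i.e. compile time growing roughly like the fourth power of the number of ~ statements over this range. The time unit cancels in the slope, so minutes or seconds give the same k; with more measurements (e.g. from the benchmark scripts in this thread) one would pass all the points instead.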