From 423540a7f2a7b500c43669b8015760598cd72d35 Mon Sep 17 00:00:00 2001 From: lorenzoh Date: Mon, 18 Oct 2021 22:47:57 +0200 Subject: [PATCH 1/3] Add `ProfileRunner` --- Project.toml | 1 + src/FluxTraining.jl | 2 +- src/callbacks/runners/profiler.jl | 115 ++++++++++++++++++++++++++++++ 3 files changed, 117 insertions(+), 1 deletion(-) create mode 100644 src/callbacks/runners/profiler.jl diff --git a/Project.toml b/Project.toml index 21a418a5d..85d08b177 100644 --- a/Project.toml +++ b/Project.toml @@ -18,6 +18,7 @@ PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" TensorBoardLogger = "899adc3e-224a-11e9-021f-63837185c80f" UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" ValueHistories = "98cad3c8-aec3-5f06-8e41-884608649ab7" diff --git a/src/FluxTraining.jl b/src/FluxTraining.jl index 0d78592d0..65d99c57f 100644 --- a/src/FluxTraining.jl +++ b/src/FluxTraining.jl @@ -25,7 +25,7 @@ using Zygote: Grads, gradient using ValueHistories using DataStructures: DefaultDict using PrettyTables - +using StructArrays # functional include("./functional/metrics.jl") diff --git a/src/callbacks/runners/profiler.jl b/src/callbacks/runners/profiler.jl new file mode 100644 index 000000000..3b388aacc --- /dev/null +++ b/src/callbacks/runners/profiler.jl @@ -0,0 +1,115 @@ + +# Data structures + +struct TimingBetween + phase + eventstart + eventend + timestart + timeend + duration +end + +TimingBetween(phase, es, ee, ts, te) = TimingBetween(phase, es, ee, ts, te, te-ts) + +struct TimingCallback + phase + cb + event + timestart + timeend + duration +end + +TimingCallback(phase, cb, e, ts, te) = TimingCallback(phase, cb, e, ts, te, te-ts) + + +# Runner + +""" + ProfileRunner() <: CallbackRunner + +A profiling callback runner that measures times for callback +handlers and times between events. This allows for granular +benchmarking of any training loop. + +## Examples + +To use, pass as `cbrunner` argument to `Learner`: + +```julia +cbrunner = ProfileRunner() +learner = Learner(model, data, opt, lossfn; cbrunner=cbrunner) +fit!(learner, 10) +``` + +After having trained, you can access the timings on fields: + +- `cbrunner.timesbetween`: Stores timings between events +- `cbrunner.timescallbacks`: Stores timings for callback handlers +""" +mutable struct ProfileRunner <: FluxTraining.CallbackRunner + timesbetween + timescallbacks + last +end + +ProfileRunner() = ProfileRunner( + StructArray{TimingBetween}( + phase = Phase[], + eventstart = Type{<:Event}[], + eventend = Type{<:Event}[], + timestart = Float64[], + timeend = Float64[], + duration = Float64[], + ), + StructArray{TimingCallback}( + phase = Phase[], + cb = FluxTraining.Callback[], + event = Type{<:Event}[], + timestart = Float64[], + timeend = Float64[], + duration = Float64[], + ), + nothing +) + + +function FluxTraining.handle(runner::ProfileRunner, event, phase, learner) + # add timing for inbetween + last = runner.last + if last !== nothing + t = Zygote.ignore(() -> Base.time()) + lastevent, lasttime, lastphase = last + if lastphase == phase + Zygote.ignore() do + timing = TimingBetween( + phase, typeof(lastevent), typeof(event), lasttime, t) + push!(runner.timesbetween, timing) + end + end + end + + # execute callback and add timing for it + idxs = Zygote.ignore() do + LightGraphs.topological_sort_by_dfs(learner.callbacks.graph) + end + for i in idxs + cb = learner.callbacks.cbs[i] + starttime = Zygote.ignore(() -> Base.time()) + FluxTraining._on(event, phase, cb, learner) + Zygote.ignore() do + timing = TimingCallback(phase, cb, typeof(event), starttime, Base.time()) + push!(runner.timescallbacks, timing) + end + end + + # update `last` so next between time can be measured + runner.last = (event, Zygote.ignore(() -> Base.time()), phase) + nothing +end + + +# Analysis + +# TODO From f98062cce610b004f63e1ae7998d3d9347203ec4 Mon Sep 17 00:00:00 2001 From: lorenzoh Date: Tue, 15 Feb 2022 13:23:49 +0100 Subject: [PATCH 2/3] Store types --- src/FluxTraining.jl | 1 + src/callbacks/runners/profiler.jl | 8 ++++---- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/FluxTraining.jl b/src/FluxTraining.jl index 65d99c57f..4df29ea43 100644 --- a/src/FluxTraining.jl +++ b/src/FluxTraining.jl @@ -37,6 +37,7 @@ include("./callbacks/events.jl") include("./callbacks/callback.jl") include("./callbacks/graph.jl") include("./callbacks/execution.jl") +include("./callbacks/runners/profiler.jl") # logging include("./callbacks/logging/Loggables.jl") diff --git a/src/callbacks/runners/profiler.jl b/src/callbacks/runners/profiler.jl index 3b388aacc..0fca95d23 100644 --- a/src/callbacks/runners/profiler.jl +++ b/src/callbacks/runners/profiler.jl @@ -56,7 +56,7 @@ end ProfileRunner() = ProfileRunner( StructArray{TimingBetween}( - phase = Phase[], + phase = Type{<:Phase}[], eventstart = Type{<:Event}[], eventend = Type{<:Event}[], timestart = Float64[], @@ -64,7 +64,7 @@ ProfileRunner() = ProfileRunner( duration = Float64[], ), StructArray{TimingCallback}( - phase = Phase[], + phase = Type{<:Phase}[], cb = FluxTraining.Callback[], event = Type{<:Event}[], timestart = Float64[], @@ -84,7 +84,7 @@ function FluxTraining.handle(runner::ProfileRunner, event, phase, learner) if lastphase == phase Zygote.ignore() do timing = TimingBetween( - phase, typeof(lastevent), typeof(event), lasttime, t) + typeof(phase), typeof(lastevent), typeof(event), lasttime, t) push!(runner.timesbetween, timing) end end @@ -99,7 +99,7 @@ function FluxTraining.handle(runner::ProfileRunner, event, phase, learner) starttime = Zygote.ignore(() -> Base.time()) FluxTraining._on(event, phase, cb, learner) Zygote.ignore() do - timing = TimingCallback(phase, cb, typeof(event), starttime, Base.time()) + timing = TimingCallback(typeof(phase), cb, typeof(event), starttime, Base.time()) push!(runner.timescallbacks, timing) end end From 82382b84229fda57df0487c98a7044760ad3208a Mon Sep 17 00:00:00 2001 From: lorenzoh Date: Sun, 20 Feb 2022 10:54:37 +0100 Subject: [PATCH 3/3] Add first analysis --- Project.toml | 4 + src/FluxTraining.jl | 9 +- src/callbacks/runners/profiler.jl | 223 +++++++++++++++++++++++------- src/training.jl | 3 +- test/profilerunner.jl | 2 + test/runtests.jl | 1 + 6 files changed, 187 insertions(+), 55 deletions(-) create mode 100644 test/profilerunner.jl diff --git a/Project.toml b/Project.toml index fb79a8355..2ad2b15e2 100644 --- a/Project.toml +++ b/Project.toml @@ -6,6 +6,9 @@ version = "0.2.3" [deps] Animations = "27a7e980-b3e6-11e9-2bcd-0b925532e340" BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0" +ColorSchemes = "35d6a980-a343-548e-a6ea-1d62b119f2f4" +Colors = "5ae59095-9a9b-59fe-a467-6f913c188581" +DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" EarlyStopping = "792122b4-ca99-40de-a6bc-6742525f08b6" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" @@ -27,6 +30,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] Animations = "0.4" BSON = "0.2, 0.3" +DataFrames = "1.3" DataStructures = "0.18" EarlyStopping = "0.1, 0.2, 0.3" Flux = "0.11, 0.12" diff --git a/src/FluxTraining.jl b/src/FluxTraining.jl index 4df29ea43..04d8ef0ae 100644 --- a/src/FluxTraining.jl +++ b/src/FluxTraining.jl @@ -3,6 +3,10 @@ module FluxTraining using LightGraphs using BSON: @load, @save +using Colors: @colorant_str +using ColorSchemes: ColorScheme, colorschemes +using DataFrames: DataFrame, groupby, select, subset, combine +using DataFrames.PooledArrays: PooledArray using Flux using Flux: Params, onecold using Flux.Optimise: update! @@ -16,7 +20,7 @@ import OnlineStats using OnlineStats: EqualWeight, Mean, OnlineStat using Parameters using ProgressMeter: Progress, next! -using Statistics: mean +using Statistics: mean, median using UUIDs using Zygote using Animations @@ -105,5 +109,6 @@ export AbstractCallback, step!, onecycle, loadmodel, - savemodel + savemodel, + ProfileRunner end # module diff --git a/src/callbacks/runners/profiler.jl b/src/callbacks/runners/profiler.jl index 0fca95d23..c8d922831 100644 --- a/src/callbacks/runners/profiler.jl +++ b/src/callbacks/runners/profiler.jl @@ -2,26 +2,26 @@ # Data structures struct TimingBetween - phase - eventstart - eventend - timestart - timeend - duration + phase::Any + eventstart::Any + eventend::Any + timestart::Any + timeend::Any + duration::Any end -TimingBetween(phase, es, ee, ts, te) = TimingBetween(phase, es, ee, ts, te, te-ts) +TimingBetween(phase, es, ee, ts, te) = TimingBetween(phase, es, ee, ts, te, te - ts) struct TimingCallback - phase - cb - event - timestart - timeend - duration + phase::Any + cb::Any + event::Any + timestart::Any + timeend::Any + duration::Any end -TimingCallback(phase, cb, e, ts, te) = TimingCallback(phase, cb, e, ts, te, te-ts) +TimingCallback(phase, cb, e, ts, te) = TimingCallback(phase, cb, e, ts, te, te - ts) # Runner @@ -49,43 +49,65 @@ After having trained, you can access the timings on fields: - `cbrunner.timescallbacks`: Stores timings for callback handlers """ mutable struct ProfileRunner <: FluxTraining.CallbackRunner - timesbetween - timescallbacks - last + df_fit::DataFrame + df_cb::DataFrame + _last::Any end -ProfileRunner() = ProfileRunner( - StructArray{TimingBetween}( - phase = Type{<:Phase}[], - eventstart = Type{<:Event}[], - eventend = Type{<:Event}[], - timestart = Float64[], - timeend = Float64[], - duration = Float64[], - ), - StructArray{TimingCallback}( - phase = Type{<:Phase}[], - cb = FluxTraining.Callback[], - event = Type{<:Event}[], - timestart = Float64[], - timeend = Float64[], - duration = Float64[], - ), - nothing + +ProfileRunner() = ProfileRunner(_new_df_fit(), _new_df_cb(), nothing) + + +function Base.show(io::IO, runner::ProfileRunner) + print(io, "ProfileRunner(df_fit = ") + summary(io, runner.df_fit) + print(io, ", df_cb = ") + summary(io, runner.df_cb) + print(io, ")") +end + + +_new_df_fit() = DataFrame( + phase = PooledArray(Type{<:Phase}[], UInt8), + eventstart = PooledArray(Type{<:Event}[], UInt8), + eventend = PooledArray(Type{<:Event}[], UInt8), + timestart = Float64[], + timeend = Float64[], ) +_new_df_cb() = DataFrame( + phase = PooledArray(Type{<:Phase}[], UInt8), + event = PooledArray(Type{<:Event}[], UInt8), + callback = FluxTraining.Callback[], + timestart = Float64[], + timeend = Float64[], +) -function FluxTraining.handle(runner::ProfileRunner, event, phase, learner) + + +function FluxTraining.handle( + runner::ProfileRunner, + event::E, + phase::P, + learner, +) where {E<:Event,P<:Phase} # add timing for inbetween - last = runner.last + last = runner._last if last !== nothing - t = Zygote.ignore(() -> Base.time()) + timeend = Zygote.ignore(() -> Base.time()) lastevent, lasttime, lastphase = last - if lastphase == phase + if lastphase == P Zygote.ignore() do - timing = TimingBetween( - typeof(phase), typeof(lastevent), typeof(event), lasttime, t) - push!(runner.timesbetween, timing) + push!( + runner.df_fit, + (; + phase = P, + eventstart = lastevent, + eventend = E, + timestart = lasttime, + timeend = timeend, + ), + ) end end end @@ -95,21 +117,118 @@ function FluxTraining.handle(runner::ProfileRunner, event, phase, learner) LightGraphs.topological_sort_by_dfs(learner.callbacks.graph) end for i in idxs - cb = learner.callbacks.cbs[i] - starttime = Zygote.ignore(() -> Base.time()) + cb = learner.callbacks.cbs[i] + timestart = Zygote.ignore(() -> Base.time()) FluxTraining._on(event, phase, cb, learner) - Zygote.ignore() do - timing = TimingCallback(typeof(phase), cb, typeof(event), starttime, Base.time()) - push!(runner.timescallbacks, timing) - end + Zygote.ignore() do + timeend = Base.time() + push!( + runner.df_cb, + (; + phase = P, + event = E, + callback = cb, + timestart = timestart, + timeend = timeend, + ), + ) + end end # update `last` so next between time can be measured - runner.last = (event, Zygote.ignore(() -> Base.time()), phase) - nothing + runner._last = (E, Zygote.ignore(() -> Base.time()), P) + return nothing end -# Analysis +# ### Data transformations +# +# Get the data into a usable shape for further analysis. -# TODO +""" + getsteptimings(profilerunner[, Phase]) -> GroupedDataFrame + +Group the data of step timings by the events that they occur between. +""" +function getsteptimings(runner::ProfileRunner, P = AbstractTrainingPhase) + return groupby( + subset( + combine( + runner.df_fit, + [:timeend, :timestart] => ((e, s) -> e - s) => :duration, + :phase, + :eventstart, + :eventend, + ), + :phase => (ps -> ps .<: P), + :eventstart => (es -> ((es .!= EpochBegin) .& (es .!= EpochEnd))), + :eventend => (es -> ((es .!= EpochBegin) .& (es .!= EpochEnd))), + ), + [:eventstart, :eventend], + ) +end + +# ### Analysis and visualization +# +# Provide helpful analyses that show most important timings and help with +# benchmarking and identifying bottlenecks. + +""" + showsteptimings(profilerunner) + showsteptimings(io, profilerunner, P = AbstractTrainingPhase; metrics = [...]) + + +""" +function showsteptimings( + io::IO, + runner::ProfileRunner, + P = AbstractTrainingPhase; + metrics = [median, minimum, maximum], +) + gdf = getsteptimings(runner, P) + rownames = ["$(k.eventstart) => $(k.eventend)" for k in keys(gdf)] + rowdata = [metricfn(eventdf.duration .* 1000) for eventdf in gdf, metricfn in metrics] + pretty_table( + io, + rowdata, + header = (string.(metrics), repeat(["ms"], length(metrics))), + row_names = rownames, + row_name_column_title = "Event", + highlighters = _timinghighlighter(), + formatters = ft_printf("%5.3f"), + ) +end +showsteptimings(args...; kwargs...) = showsteptimings(stdout, args...; kwargs...) + + +# #### PrettyTables.jl utilities + +_timinghighlighter() = Highlighter( + (data, i, j) -> true, + function (h, data, i, j) + ext = extrema(data[:, j]) + ext = 0., ext[2] + return Crayon( + background = _cvtcolor( + get( + ColorScheme(range(colorant"black", colorant"darkorange4")), + data[i, j], + ext, + ), + ), + foreground = _cvtcolor( + get( + ColorScheme(range(colorant"gray", colorant"white")), + data[i, j], + ext, + ), + ), + ) + end, +) + +_cvtcolor(c::Color) = ( + round(Int, Colors.red(c) * 255), + round(Int, Colors.green(c) * 255), + round(Int, Colors.blue(c) * 255), +) diff --git a/src/training.jl b/src/training.jl index 411f3c338..fbcf3fcdd 100644 --- a/src/training.jl +++ b/src/training.jl @@ -47,8 +47,9 @@ end function step!(learner, phase::ValidationPhase, batch) xs, ys = batch - runstep(learner, phase, (;xs=xs, ys=ys)) do _, state + runstep(learner, phase, (;xs=xs, ys=ys)) do handle, state state.ŷs = learner.model(state.xs) + handle(LossBegin()) state.loss = learner.lossfn(state.ŷs, state.ys) end end diff --git a/test/profilerunner.jl b/test/profilerunner.jl new file mode 100644 index 000000000..1cd23f7f2 --- /dev/null +++ b/test/profilerunner.jl @@ -0,0 +1,2 @@ +@testset "ProfileRunner" begin +end diff --git a/test/runtests.jl b/test/runtests.jl index 4cdd31161..d8d9bb576 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -15,5 +15,6 @@ include("./imports.jl") include("./callbacks/garbagecollect.jl") include("./callbacks/sanitycheck.jl") include("./callbacks/earlystopping.jl") + include("./profilerunner.jl") include("./callbackutils.jl") end