Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add training loop profiler #89

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ version = "0.2.4"

[deps]
BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0"
ColorSchemes = "35d6a980-a343-548e-a6ea-1d62b119f2f4"
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
EarlyStopping = "792122b4-ca99-40de-a6bc-6742525f08b6"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
Expand All @@ -19,13 +22,15 @@ PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
TensorBoardLogger = "899adc3e-224a-11e9-021f-63837185c80f"
UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
ValueHistories = "98cad3c8-aec3-5f06-8e41-884608649ab7"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
BSON = "0.2, 0.3"
DataFrames = "1.3"
DataStructures = "0.18"
EarlyStopping = "0.1, 0.2, 0.3"
Flux = "0.11, 0.12, 0.13"
Expand Down
11 changes: 9 additions & 2 deletions src/FluxTraining.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@ module FluxTraining

using Graphs
using BSON: @load, @save
using Colors: @colorant_str
using ColorSchemes: ColorScheme, colorschemes
using DataFrames: DataFrame, groupby, select, subset, combine
using DataFrames.PooledArrays: PooledArray
using Flux
using Flux: Params, onecold
using Flux.Optimise: update!
Expand All @@ -17,7 +21,7 @@ import OnlineStats
using OnlineStats: EqualWeight, Mean, OnlineStat
using Parameters
using ProgressMeter: Progress, next!
using Statistics: mean
using Statistics: mean, median
using UUIDs
using Zygote
using ParameterSchedulers
Expand All @@ -26,6 +30,7 @@ using Zygote: Grads, gradient
using ValueHistories
using DataStructures: DefaultDict
using PrettyTables
using StructArrays

# functional
include("./functional/metrics.jl")
Expand All @@ -37,6 +42,7 @@ include("./callbacks/events.jl")
include("./callbacks/callback.jl")
include("./callbacks/graph.jl")
include("./callbacks/execution.jl")
include("./callbacks/runners/profiler.jl")

# logging
include("./callbacks/logging/Loggables.jl")
Expand Down Expand Up @@ -105,5 +111,6 @@ export AbstractCallback,
step!,
onecycle,
loadmodel,
savemodel
savemodel,
ProfileRunner
end # module
234 changes: 234 additions & 0 deletions src/callbacks/runners/profiler.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,234 @@

# Data structures

# Record of the wall-clock time elapsed between two consecutive events
# within the same phase.
# NOTE(review): the profiler below stores its measurements in `DataFrame`s
# (`df_fit`/`df_cb`) — confirm whether this struct is still used anywhere.
struct TimingBetween
    phase::Any       # phase type the interval was recorded in
    eventstart::Any  # event type at the start of the interval
    eventend::Any    # event type at the end of the interval
    timestart::Any   # wall-clock time at `eventstart` (seconds)
    timeend::Any     # wall-clock time at `eventend` (seconds)
    duration::Any    # `timeend - timestart`
end

# Convenience constructor deriving `duration` from the start and end times.
TimingBetween(phase, es, ee, ts, te) = TimingBetween(phase, es, ee, ts, te, te - ts)

# Record of how long a single callback handler took to process one event.
# NOTE(review): like `TimingBetween`, this appears superseded by the
# `DataFrame`-based storage in `ProfileRunner` — confirm before keeping.
struct TimingCallback
    phase::Any      # phase type the handler ran in
    cb::Any         # the callback instance that was invoked
    event::Any      # event type that was dispatched
    timestart::Any  # wall-clock time before the handler ran (seconds)
    timeend::Any    # wall-clock time after the handler returned (seconds)
    duration::Any   # `timeend - timestart`
end

# Convenience constructor deriving `duration` from the start and end times.
TimingCallback(phase, cb, e, ts, te) = TimingCallback(phase, cb, e, ts, te, te - ts)


# Runner

"""
ProfileRunner() <: CallbackRunner

A profiling callback runner that measures times for callback
handlers and times between events. This allows for granular
benchmarking of any training loop.

## Examples

To use, pass as `cbrunner` argument to `Learner`:

```julia
cbrunner = ProfileRunner()
learner = Learner(model, data, opt, lossfn; cbrunner=cbrunner)
fit!(learner, 10)
```

After having trained, you can access the timings on fields:

- `cbrunner.timesbetween`: Stores timings between events
- `cbrunner.timescallbacks`: Stores timings for callback handlers
"""
mutable struct ProfileRunner <: FluxTraining.CallbackRunner
df_fit::DataFrame
df_cb::DataFrame
_last::Any
end


ProfileRunner() = ProfileRunner(_new_df_fit(), _new_df_cb(), nothing)


# Compact display for a `ProfileRunner`: show only summaries of the two
# timing tables instead of their full contents.
function Base.show(io::IO, runner::ProfileRunner)
    fitinfo = summary(runner.df_fit)
    cbinfo = summary(runner.df_cb)
    print(io, "ProfileRunner(df_fit = ", fitinfo, ", df_cb = ", cbinfo, ")")
end


# Empty table for inter-event timings. Phase and event columns hold types
# drawn from a small fixed set, so they are stored as `PooledArray`s with
# `UInt8` reference codes to save memory.
function _new_df_fit()
    return DataFrame(
        phase = PooledArray(Type{<:Phase}[], UInt8),
        eventstart = PooledArray(Type{<:Event}[], UInt8),
        eventend = PooledArray(Type{<:Event}[], UInt8),
        timestart = Float64[],
        timeend = Float64[],
    )
end

# Empty table for per-callback-handler timings. Like `_new_df_fit`, the
# phase and event columns are pooled since only a handful of distinct
# types ever appear.
function _new_df_cb()
    return DataFrame(
        phase = PooledArray(Type{<:Phase}[], UInt8),
        event = PooledArray(Type{<:Event}[], UInt8),
        callback = FluxTraining.Callback[],
        timestart = Float64[],
        timeend = Float64[],
    )
end



"""
    FluxTraining.handle(runner::ProfileRunner, event, phase, learner)

Dispatch `event` to all of `learner`'s callbacks while recording
wall-clock timings: the duration of each callback handler (appended to
`runner.df_cb`) and the time elapsed between consecutive events of the
same phase type (appended to `runner.df_fit`).

Timing and bookkeeping are wrapped in `Zygote.ignore` so they are not
traced when this runs inside a gradient computation.
"""
function FluxTraining.handle(
    runner::ProfileRunner,
    event::E,
    phase::P,
    learner,
) where {E<:Event,P<:Phase}
    # Record the interval since the previous event, but only when both
    # events belong to the same phase type.
    last = runner._last
    if last !== nothing
        timeend = Zygote.ignore(() -> Base.time())
        lastevent, lasttime, lastphase = last
        if lastphase == P
            Zygote.ignore() do
                push!(
                    runner.df_fit,
                    (;
                        phase = P,
                        eventstart = lastevent,
                        eventend = E,
                        timestart = lasttime,
                        timeend = timeend,
                    ),
                )
            end
        end
    end

    # Execute callbacks in dependency order, timing each handler.
    # The module depends on Graphs.jl (`using Graphs`), not the deprecated
    # LightGraphs.jl, so the topological sort must be qualified with `Graphs`.
    idxs = Zygote.ignore() do
        Graphs.topological_sort_by_dfs(learner.callbacks.graph)
    end
    for i in idxs
        cb = learner.callbacks.cbs[i]
        timestart = Zygote.ignore(() -> Base.time())
        FluxTraining._on(event, phase, cb, learner)
        Zygote.ignore() do
            timeend = Base.time()
            push!(
                runner.df_cb,
                (;
                    phase = P,
                    event = E,
                    callback = cb,
                    timestart = timestart,
                    timeend = timeend,
                ),
            )
        end
    end

    # Remember this event so the next invocation can measure the in-between time.
    runner._last = (E, Zygote.ignore(() -> Base.time()), P)
    return nothing
end


# ### Data transformations
#
# Get the data into a usable shape for further analysis.

"""
getsteptimings(profilerunner[, Phase]) -> GroupedDataFrame

Group the data of step timings by the events that they occur between.
"""
function getsteptimings(runner::ProfileRunner, P = AbstractTrainingPhase)
return groupby(
subset(
combine(
runner.df_fit,
[:timeend, :timestart] => ((e, s) -> e - s) => :duration,
:phase,
:eventstart,
:eventend,
),
:phase => (ps -> ps .<: P),
:eventstart => (es -> ((es .!= EpochBegin) .& (es .!= EpochEnd))),
:eventend => (es -> ((es .!= EpochBegin) .& (es .!= EpochEnd))),
),
[:eventstart, :eventend],
)
end

# ### Analysis and visualization
#
# Provide helpful analyses that show most important timings and help with
# benchmarking and identifying bottlenecks.

"""
showsteptimings(profilerunner)
showsteptimings(io, profilerunner, P = AbstractTrainingPhase; metrics = [...])


"""
function showsteptimings(
io::IO,
runner::ProfileRunner,
P = AbstractTrainingPhase;
metrics = [median, minimum, maximum],
)
gdf = getsteptimings(runner, P)
rownames = ["$(k.eventstart) => $(k.eventend)" for k in keys(gdf)]
rowdata = [metricfn(eventdf.duration .* 1000) for eventdf in gdf, metricfn in metrics]
pretty_table(
io,
rowdata,
header = (string.(metrics), repeat(["ms"], length(metrics))),
row_names = rownames,
row_name_column_title = "Event",
highlighters = _timinghighlighter(),
formatters = ft_printf("%5.3f"),
)
end
showsteptimings(args...; kwargs...) = showsteptimings(stdout, args...; kwargs...)


# #### PrettyTables.jl utilities

# PrettyTables highlighter that colors every cell (the filter always returns
# `true`). The background runs black → dark orange and the foreground
# gray → white, both scaled by the cell's value relative to the column
# maximum, so larger (slower) timings stand out.
_timinghighlighter() = Highlighter(
    (data, i, j) -> true,
    function (h, data, i, j)
        ext = extrema(data[:, j])
        # Anchor the color scale at 0 instead of the column minimum so cells
        # are shaded on an absolute scale within each column.
        ext = 0., ext[2]
        return Crayon(
            background = _cvtcolor(
                get(
                    ColorScheme(range(colorant"black", colorant"darkorange4")),
                    data[i, j],
                    ext,
                ),
            ),
            foreground = _cvtcolor(
                get(
                    ColorScheme(range(colorant"gray", colorant"white")),
                    data[i, j],
                    ext,
                ),
            ),
        )
    end,
)

# Convert a color to an `(r, g, b)` tuple of 0-255 integers, the form
# accepted by `Crayon`'s `background`/`foreground` options.
# NOTE(review): this references `Colors.red`/`Colors.green`/`Colors.blue`
# and the `Color` type, but FluxTraining.jl only does
# `using Colors: @colorant_str`, which brings neither the `Colors` module
# name nor `Color` into scope — confirm the module's imports cover these.
_cvtcolor(c::Color) = (
    round(Int, Colors.red(c) * 255),
    round(Int, Colors.green(c) * 255),
    round(Int, Colors.blue(c) * 255),
)
3 changes: 2 additions & 1 deletion src/training.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,9 @@ end

"""
    step!(learner, phase::ValidationPhase, batch)

Run a single validation step on `batch`: compute the model's predictions
and the loss, without any gradient computation or parameter update.
"""
function step!(learner, phase::ValidationPhase, batch)
    xs, ys = batch
    runstep(learner, phase, (;xs=xs, ys=ys)) do handle, state
        state.ŷs = learner.model(state.xs)
        # Fire `LossBegin` so callbacks (e.g. a profiler) can time the
        # forward pass separately from the loss computation.
        handle(LossBegin())
        state.loss = learner.lossfn(state.ŷs, state.ys)
    end
end
Expand Down
2 changes: 2 additions & 0 deletions test/profilerunner.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
@testset "ProfileRunner" begin
end