Skip to content

Commit

Permalink
Merge pull request #514 from pepijndevos/pv/progress
Browse files Browse the repository at this point in the history
progress bars for EnsembleProblem
  • Loading branch information
ChrisRackauckas authored Nov 9, 2023
2 parents 06d5c2c + a1370a0 commit 5797257
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 30 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
TruncatedStacktraces = "781d530d-4396-4725-bb49-402e4bee1e77"

[weakdeps]
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
PartialFunctions = "570af359-4316-4cb7-8c74-252c00c2016b"
PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0"
Expand All @@ -51,6 +52,7 @@ SciMLBaseZygoteExt = "Zygote"
[compat]
ADTypes = "0.1.3, 0.2"
ArrayInterface = "6, 7"
ChainRules = "1.57.0"
ChainRulesCore = "1.16"
CommonSolve = "0.2.4"
ConstructionBase = "1"
Expand Down
139 changes: 109 additions & 30 deletions src/ensemble/basic_ensemble_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,64 @@ function merge_stats(us)
reduce(merge, st)
end

"""
    AggregateLogger(logger::Logging.AbstractLogger)

Logger wrapper holding the bookkeeping that `Logging.handle_message` uses to fold many
per-task progress records into a single `"Total"` record before passing it to `logger`.

# Fields
- `progress`: last progress value recorded for each log-record id.
- `done_counter`: number of `"done"` progress records counted so far.
- `total`: running mean of the values stored in `progress`.
- `print_time`: `time()` stamp of the most recent forwarded `"Total"` record.
- `lock`: guards every field above while a progress record is being processed.
- `logger`: the wrapped logger that receives the forwarded records.
"""
mutable struct AggregateLogger{T <: Logging.AbstractLogger} <: Logging.AbstractLogger
    progress::Dict{Symbol, Float64}
    done_counter::Int
    total::Float64
    print_time::Float64
    lock::ReentrantLock
    logger::T
end

# Convenience constructor: wrap `logger` with empty bookkeeping (nothing tracked, nothing
# finished, mean of 0.0, never printed).
function AggregateLogger(logger::Logging.AbstractLogger)
    return AggregateLogger(Dict{Symbol, Float64}(), 0, 0.0, 0.0, ReentrantLock(), logger)
end

"""
    Logging.handle_message(l::AggregateLogger, level, message, _module, group, id,
                           file, line; kwargs...)

Fold per-task progress records into a single `"Total"` progress record and forward it to
the wrapped logger `l.logger`.

A record is treated as a progress record when its level equals `LogLevel(-1)` and it
carries a `progress` keyword; every other record is forwarded unchanged.

For a progress record the value is interpreted as follows:
- `"done"` counts as `1.0` and increments the finished-task counter;
- a non-`NaN` `Real` is used as is;
- anything else (e.g. `nothing` or `NaN`, i.e. indeterminate progress) leaves the value
  already recorded for `id` untouched (`0.0` for an id not seen before), so it can neither
  throw nor turn the running mean into `NaN`.

The forwarded record has id `:total`, message `"Total"` and `progress` set to the mean of
all tracked ids, or to `"done"` once the finished-task counter reaches the number of
tracked ids (at which point the bookkeeping is reset). Mean updates are forwarded at most
once every 0.1 seconds; records arriving in between return `nothing` without forwarding.

Ordinary updates are also dropped (returning `nothing`) when the lock is held by another
caller, whereas a `"done"` record waits for the lock so that it is never lost.
"""
function Logging.handle_message(l::AggregateLogger, level, message, _module, group, id,
        file, line; kwargs...)
    if convert(Logging.LogLevel, level) == Logging.LogLevel(-1) &&
       haskey(kwargs, :progress)
        pr = kwargs[:progress]
        # The type guard keeps `missing == "done"` from leaking into a boolean context.
        isdone = pr isa AbstractString && pr == "done"
        if !trylock(l.lock)
            # Contended: skipping an ordinary update is harmless, but a completion must
            # be counted, so block for it.
            isdone || return
            lock(l.lock)
        end
        try
            if isdone
                pr = 1.0
                l.done_counter += 1
            elseif !(pr isa Real) || isnan(pr)
                # Indeterminate progress carries no fraction: keep what is recorded.
                pr = get(l.progress, id, 0.0)
            end
            len = length(l.progress)
            # Maintain `l.total == sum(values(l.progress)) / length(l.progress)`
            # incrementally instead of re-summing the dictionary on every record.
            if haskey(l.progress, id)
                l.total += (pr - l.progress[id]) / len
            else
                l.total = l.total * (len / (len + 1)) + pr / (len + 1)
                len += 1
            end
            l.progress[id] = pr
            curr_time = time()
            if l.done_counter >= len
                # Every tracked task finished: report completion and start afresh.
                tot = "done"
                empty!(l.progress)
                l.done_counter = 0
                l.total = 0.0
                l.print_time = 0.0
            elseif curr_time - l.print_time > 0.1
                tot = l.total
                l.print_time = curr_time
            else
                # Throttled: too soon since the last forwarded "Total" record.
                return
            end
            # Rewrite the record so the wrapped logger only ever sees the aggregate bar.
            id = :total
            message = "Total"
            kwargs = merge(values(kwargs), (progress = tot,))
        finally
            unlock(l.lock)
        end
    end
    return Logging.handle_message(l.logger, level, message, _module, group, id, file,
        line; kwargs...)
end
# The remaining `AbstractLogger` interface is delegated verbatim to the wrapped logger, so
# wrapping a logger does not change which records are accepted.

# Ask the wrapped logger whether a record should be logged.
function Logging.shouldlog(agg::AggregateLogger, args...)
    return Logging.shouldlog(agg.logger, args...)
end

# Report the wrapped logger's minimum enabled level.
function Logging.min_enabled_level(agg::AggregateLogger)
    return Logging.min_enabled_level(agg.logger)
end

# Report whether the wrapped logger wants exceptions during logging to be caught.
function Logging.catch_exceptions(agg::AggregateLogger)
    return Logging.catch_exceptions(agg.logger)
end

function __solve(prob::AbstractEnsembleProblem,
alg::Union{AbstractDEAlgorithm, Nothing};
kwargs...)
Expand Down Expand Up @@ -59,51 +117,72 @@ end
function __solve(prob::AbstractEnsembleProblem,
alg::A,
ensemblealg::BasicEnsembleAlgorithm;
trajectories, batch_size = trajectories,
trajectories, batch_size = trajectories, progress_aggregate=true,
pmap_batch_size = batch_size ÷ 100 > 0 ? batch_size ÷ 100 : 1, kwargs...) where {A}
num_batches = trajectories ÷ batch_size
num_batches < 1 &&
error("trajectories ÷ batch_size cannot be less than 1, got $num_batches")
num_batches * batch_size != trajectories && (num_batches += 1)

if num_batches == 1 && prob.reduction === DEFAULT_REDUCTION
elapsed_time = @elapsed u = solve_batch(prob, alg, ensemblealg, 1:trajectories,
pmap_batch_size; kwargs...)
_u = tighten_container_eltype(u)
stats = merge_stats(_u)
return EnsembleSolution(_u, elapsed_time, true, stats)
end
logger = progress_aggregate ? AggregateLogger(Logging.current_logger()) : Logging.current_logger()

Logging.with_logger(logger) do
num_batches = trajectories ÷ batch_size
num_batches < 1 &&
error("trajectories ÷ batch_size cannot be less than 1, got $num_batches")
num_batches * batch_size != trajectories && (num_batches += 1)

converged::Bool = false
elapsed_time = @elapsed begin
i = 1
II = (batch_size * (i - 1) + 1):(batch_size * i)
if get(kwargs, :progress, false)
name = get(kwargs, :progress_name, "Ensemble")
for i in 1:trajectories
msg = "$name #$i"
Logging.@logmsg(Logging.LogLevel(-1), msg, _id=Symbol("SciMLBase_$i"), progress=0)
end
end


batch_data = solve_batch(prob, alg, ensemblealg, II, pmap_batch_size; kwargs...)
if num_batches == 1 && prob.reduction === DEFAULT_REDUCTION
elapsed_time = @elapsed u = solve_batch(prob, alg, ensemblealg, 1:trajectories,
pmap_batch_size; kwargs...)
_u = tighten_container_eltype(u)
stats = merge_stats(_u)
return EnsembleSolution(_u, elapsed_time, true, stats)
end

converged::Bool = false
elapsed_time = @elapsed begin
i = 1
II = (batch_size * (i - 1) + 1):(batch_size * i)

u = prob.u_init === nothing ? similar(batch_data, 0) : prob.u_init
u, converged = prob.reduction(u, batch_data, II)
for i in 2:num_batches
converged && break
if i == num_batches
II = (batch_size * (i - 1) + 1):trajectories
else
II = (batch_size * (i - 1) + 1):(batch_size * i)
end
batch_data = solve_batch(prob, alg, ensemblealg, II, pmap_batch_size; kwargs...)

u = prob.u_init === nothing ? similar(batch_data, 0) : prob.u_init
u, converged = prob.reduction(u, batch_data, II)
for i in 2:num_batches
converged && break
if i == num_batches
II = (batch_size * (i - 1) + 1):trajectories
else
II = (batch_size * (i - 1) + 1):(batch_size * i)
end
batch_data = solve_batch(prob, alg, ensemblealg, II, pmap_batch_size; kwargs...)
u, converged = prob.reduction(u, batch_data, II)
end
end
_u = tighten_container_eltype(u)
stats = merge_stats(_u)
return EnsembleSolution(_u, elapsed_time, converged, stats)
end
_u = tighten_container_eltype(u)
stats = merge_stats(_u)
return EnsembleSolution(_u, elapsed_time, converged, stats)
end

function batch_func(i, prob, alg; kwargs...)
iter = 1
_prob = prob.safetycopy ? deepcopy(prob.prob) : prob.prob
new_prob = prob.prob_func(_prob, i, iter)
rerun = true

progress = get(kwargs, :progress, false)
if progress
name = get(kwargs, :progress_name, "Ensemble")
progress_name = "$name #$i"
progress_id = Symbol("SciMLBase_$i")
kwargs = (kwargs..., progress_name=progress_name, progress_id=progress_id)
end
x = prob.output_func(solve(new_prob, alg; kwargs...), i)
if !(x isa Tuple)
rerun_warn()
Expand Down

0 comments on commit 5797257

Please sign in to comment.