diff --git a/src/ensemble/basic_ensemble_solve.jl b/src/ensemble/basic_ensemble_solve.jl index 0578ba251..adf3003c4 100644 --- a/src/ensemble/basic_ensemble_solve.jl +++ b/src/ensemble/basic_ensemble_solve.jl @@ -26,28 +26,45 @@ struct EnsembleSerial <: BasicEnsembleAlgorithm end mutable struct AggregateLogger{T<:AbstractLogger} <: AbstractLogger progress::Dict{Symbol, Float64} + done_counter::Int total::Float64 + print_time::Float64 lock::ReentrantLock logger::T end -AggregateLogger(logger::AbstractLogger) = AggregateLogger(Dict{Symbol, Float64}(), 0.0, ReentrantLock(), logger) +AggregateLogger(logger::AbstractLogger) = AggregateLogger(Dict{Symbol, Float64}(),0 , 0.0, 0.0, ReentrantLock(), logger) function Logging.handle_message(l::AggregateLogger, level, message, _module, group, id, file, line; kwargs...) if convert(LogLevel, level) == LogLevel(-1) && haskey(kwargs, :progress) pr = kwargs[:progress] if trylock(l.lock) || (pr == "done" && lock(l.lock)===nothing) try - if pr isa Number - l.progress[id] = pr - elseif pr == "done" - l.progress[id] = 1.0 + if pr == "done" + pr = 1.0 + l.done_counter += 1 end - tot = sum(values(l.progress))/length(l.progress) - tot < 1.0 && isapprox(tot, l.total; atol=0.001) && return # less than 0.1% change - l.total = tot - if tot>=1.0 + len = length(l.progress) + if haskey(l.progress, id) + l.total += (pr-l.progress[id])/len + else + l.total = l.total*(len/(len+1)) + pr/(len+1) + len += 1 + end + l.progress[id] = pr + # validation check (slow) + # tot = sum(values(l.progress))/length(l.progress) + # @show tot l.total l.total ≈ tot + curr_time = time() + if l.done_counter >= len tot="done" empty!(l.progress) + l.done_counter = 0 + l.print_time = 0.0 + elseif curr_time-l.print_time > 0.1 + tot = l.total + l.print_time = curr_time + else + return end id=:total message="Total"