From 1839abc76d80a8327f9f7f956bdb7fefcfe4764b Mon Sep 17 00:00:00 2001 From: Casper Date: Wed, 2 Mar 2022 15:20:00 +0100 Subject: [PATCH 1/3] show status time --- src/training.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/training.jl b/src/training.jl index b10e2b72..11c63c3d 100644 --- a/src/training.jl +++ b/src/training.jl @@ -223,6 +223,7 @@ function learning_step!(env::Env, handler) Handlers.updates_started(handler, status) dlosses, dttrain = @timed batch_updates!(trainer, nbatches) status, dtloss = @timed learning_status(trainer) + @show dtloss Handlers.updates_finished(handler, status) tloss += dtloss ttrain += dttrain From 39232c3b82ce83def6e98a278e18ad60ffc6edfd Mon Sep 17 00:00:00 2001 From: Casper Date: Wed, 2 Mar 2022 15:20:09 +0100 Subject: [PATCH 2/3] single for loop --- src/learning.jl | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/learning.jl b/src/learning.jl index 1702ca2d..0d5c20b5 100644 --- a/src/learning.jl +++ b/src/learning.jl @@ -154,6 +154,7 @@ end function learning_status(tr::Trainer, samples) # As done now, this is slighly inefficient as we solve the # same neural network inference problem twice + samples = Network.convert_input_tuple(tr.network, samples) W, X, A, P, V = samples regws = Network.regularized_params(tr.network) Ls = losses(tr.network, regws, tr.params, tr.Wmean, tr.Hp, samples) @@ -167,11 +168,12 @@ end function learning_status(tr::Trainer) batchsize = min(tr.params.loss_computation_batch_size, num_samples(tr)) batches = Flux.Data.DataLoader(tr.data; batchsize, partial=true) - reports = map(batches) do batch - batch = Network.convert_input_tuple(tr.network, batch) - return learning_status(tr, batch) + reports = [] + ws = [] + for batch in batches + push!(reports, learning_status(tr, batch)) + push!(ws, sum(batch.W)) end - ws = [sum(batch.W) for batch in batches] return mean_learning_status(reports, ws) end From 4de5a261630c0c5f968b68f3e73dfe3a42e98884 Mon Sep 17 00:00:00 2001 From: Casper Date: Wed, 2 Mar 2022 15:20:00 +0100 Subject: [PATCH 3/3] Revert "show status time" This reverts commit 1839abc76d80a8327f9f7f956bdb7fefcfe4764b. --- src/training.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/training.jl b/src/training.jl index 11c63c3d..b10e2b72 100644 --- a/src/training.jl +++ b/src/training.jl @@ -223,7 +223,6 @@ function learning_step!(env::Env, handler) Handlers.updates_started(handler, status) dlosses, dttrain = @timed batch_updates!(trainer, nbatches) status, dtloss = @timed learning_status(trainer) - @show dtloss Handlers.updates_finished(handler, status) tloss += dtloss ttrain += dttrain