From 1b2a74c863f68f3442a5bad180988932c105f520 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Tue, 25 Apr 2023 12:16:21 +0200 Subject: [PATCH 1/4] fix perf issue loadmodel! --- src/loading.jl | 36 +++++++++++++++++++----------------- 1 file changed, 19 insertions(+), 17 deletions(-) diff --git a/src/loading.jl b/src/loading.jl index 9098828a8b..3f95ad03ef 100644 --- a/src/loading.jl +++ b/src/loading.jl @@ -1,9 +1,10 @@ -loadleaf!(dst, src, err) = dst -loadleaf!(dst::AbstractArray, src, err) = +loadleaf!(dst, src) = dst +loadleaf!(dst::AbstractArray, src) = error("Tried to copy $src into an array destination; this is not allowed.") -loadleaf!(dst, src::AbstractArray, err) = +loadleaf!(dst, src::AbstractArray) = error("Tried to copy an array to $dst; this is not allowed.") -function loadleaf!(dst::AbstractArray, src::Bool, err) + +function loadleaf!(dst::AbstractArray, src::Bool) if iszero(src) dst .= src else @@ -11,9 +12,12 @@ function loadleaf!(dst::AbstractArray, src::Bool, err) end return dst end -loadleaf!(dst::Bool, src::AbstractArray, err) = iszero(dst) ? dst : + +loadleaf!(dst::Bool, src::AbstractArray) = iszero(dst) ? dst : error("Cannot copy non-zero parameter to boolean parameter == true.") -function loadleaf!(dst::AbstractArray, src::AbstractArray, err) + +function loadleaf!(dst::AbstractArray, src::AbstractArray) + err = DimensionMismatch("Tried to load size $(size(src)) array into size $(size(dst))") (size(dst) == size(src)) || throw(err) copyto!(dst, src) end @@ -28,9 +32,6 @@ _tie_check(dst, src) = true _bool_tie_check(dst, src) = true -_filter_children(f, children::NamedTuple) = - NamedTuple(filter(kv -> f(kv[2]), pairs(children))) -_filter_children(f, children) = filter(f, children) """ loadmodel!(dst, src) @@ -81,21 +82,22 @@ however, attempting to copy a non-zero array to an inactive parameter will throw Likewise, copying a `src` value of `false` to any `dst` array is valid, but copying a `src` value of `true` will error. """ -function loadmodel!(dst, src; filter = _ -> true, cache = Base.IdSet()) - ldsts = _filter_children(filter, functor(dst)[1]) - lsrcs = _filter_children(filter, functor(src)[1]) - (keys(ldsts) == keys(lsrcs)) || - throw(ArgumentError("Tried to load $src into $dst but the structures do not match.")) +function loadmodel!(dst, src; cache = Base.IdSet()) + ldsts = Functors.children(dst) + lsrcs = Functors.children(src) + kdsts = keys(ldsts) + ksrcs = keys(lsrcs) + (kdsts == ksrcs) || + throw(ArgumentError("Tried to load $ksrcs into $kdsts but the structures do not match.")) - err = DimensionMismatch("Tried to load $src into $dst but the parameter sizes do not match.") foreach(ldsts, lsrcs) do ldst, lsrc if ldst in cache # we already loaded this parameter before _tie_check(ldst, lsrc) && return ldst elseif Functors.isleaf(ldst) # our first time loading this leaf push!(cache, ldst) - loadleaf!(ldst, lsrc, err) + loadleaf!(ldst, lsrc) else # this isn't a leaf - loadmodel!(ldst, lsrc; filter = filter, cache = cache) + loadmodel!(ldst, lsrc; cache = cache) end end From 851be1bed20bb22111c871eff40f2dcc3fe0521e Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Tue, 25 Apr 2023 13:45:47 +0200 Subject: [PATCH 2/4] reinstate filter --- src/loading.jl | 15 ++++++++------- test.jl | 22 ++++++++++++++++++++++ 2 files changed, 30 insertions(+), 7 deletions(-) create mode 100644 test.jl diff --git a/src/loading.jl b/src/loading.jl index 3f95ad03ef..848d2da8db 100644 --- a/src/loading.jl +++ b/src/loading.jl @@ -32,6 +32,9 @@ _tie_check(dst, src) = true _bool_tie_check(dst, src) = true +_filter_children(f, children::NamedTuple) = + NamedTuple(filter(kv -> f(kv[2]), pairs(children))) +_filter_children(f, children) = filter(f, children) """ loadmodel!(dst, src) @@ -82,13 +85,11 @@ however, attempting to copy a non-zero array to an inactive parameter will throw Likewise, copying a `src` value of `false` to any `dst` array is valid, but copying a `src` value of `true` will error. """ -function loadmodel!(dst, src; cache = Base.IdSet()) - ldsts = Functors.children(dst) - lsrcs = Functors.children(src) - kdsts = keys(ldsts) - ksrcs = keys(lsrcs) - (kdsts == ksrcs) || - throw(ArgumentError("Tried to load $ksrcs into $kdsts but the structures do not match.")) +function loadmodel!(dst, src; filter = _ -> true, cache = Base.IdSet()) + ldsts = _filter_children(filter, Functors.children(dst)) + lsrcs = _filter_children(filter, Functors.children(src)) + (keys(ldsts) == keys(lsrcs)) || + throw(ArgumentError("Tried to load $(keys(lsrcs)) into $(keys(ldsts)) but the structures do not match.")) foreach(ldsts, lsrcs) do ldst, lsrc if ldst in cache # we already loaded this parameter before diff --git a/test.jl b/test.jl new file mode 100644 index 0000000000..4e5e2b6ed6 --- /dev/null +++ b/test.jl @@ -0,0 +1,22 @@ +using Metalhead, Flux + +m1 = ResNet(18) +m2 = ResNet(18) +@time Flux.loadmodel!(m2, m1) # warmup +@time Flux.loadmodel!(m2, m1) +# 0.003388 seconds (23.01 k allocations: 2.157 MiB) # this PR + +## SAVE AND LOAD +using Functors + +function state(x) + if Functors.isleaf(x) + return x + else + return map(state, Functors.children(x)) + end +end + +s = state(m1); +@time Flux.loadmodel!(m2, s); # warmup +@time Flux.loadmodel!(m2, s); From 2943cfd3c8c3053c0a04567d1035016b5741ec20 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Tue, 25 Apr 2023 13:46:13 +0200 Subject: [PATCH 3/4] cleanup --- test.jl | 22 ---------------------- 1 file changed, 22 deletions(-) delete mode 100644 test.jl diff --git a/test.jl b/test.jl deleted file mode 100644 index 4e5e2b6ed6..0000000000 --- a/test.jl +++ /dev/null @@ -1,22 +0,0 @@ -using Metalhead, Flux - -m1 = ResNet(18) -m2 = ResNet(18) -@time Flux.loadmodel!(m2, m1) # warmup -@time Flux.loadmodel!(m2, m1) -# 0.003388 seconds (23.01 k allocations: 2.157 MiB) # this PR - -## SAVE AND LOAD -using Functors - -function state(x) - if Functors.isleaf(x) - return x - else - return map(state, Functors.children(x)) - end -end - -s = state(m1); -@time Flux.loadmodel!(m2, s); # warmup -@time Flux.loadmodel!(m2, s); From d1a5dbabe121dbc66804f839b87772975c4c7911 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Tue, 25 Apr 2023 13:47:33 +0200 Subject: [PATCH 4/4] cleanup --- src/loading.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/loading.jl b/src/loading.jl index 848d2da8db..5cdd129936 100644 --- a/src/loading.jl +++ b/src/loading.jl @@ -98,7 +98,7 @@ function loadmodel!(dst, src; filter = _ -> true, cache = Base.IdSet()) push!(cache, ldst) loadleaf!(ldst, lsrc) else # this isn't a leaf - loadmodel!(ldst, lsrc; cache = cache) + loadmodel!(ldst, lsrc; filter, cache) end end