Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix perf issue loadmodel! #2241

Merged
merged 4 commits into from
Apr 25, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 15 additions & 12 deletions src/loading.jl
Original file line number Diff line number Diff line change
@@ -1,19 +1,23 @@
loadleaf!(dst, src, err) = dst
loadleaf!(dst::AbstractArray, src, err) =
loadleaf!(dst, src) = dst
loadleaf!(dst::AbstractArray, src) =
error("Tried to copy $src into an array destination; this is not allowed.")
loadleaf!(dst, src::AbstractArray, err) =
loadleaf!(dst, src::AbstractArray) =
error("Tried to copy an array to $dst; this is not allowed.")
function loadleaf!(dst::AbstractArray, src::Bool, err)

function loadleaf!(dst::AbstractArray, src::Bool)
if iszero(src)
dst .= src
else
error("Cannot copy boolean parameter == true to non-zero parameter.")
end
return dst
end
loadleaf!(dst::Bool, src::AbstractArray, err) = iszero(dst) ? dst :

loadleaf!(dst::Bool, src::AbstractArray) = iszero(dst) ? dst :
error("Cannot copy non-zero parameter to boolean parameter == true.")
function loadleaf!(dst::AbstractArray, src::AbstractArray, err)

function loadleaf!(dst::AbstractArray, src::AbstractArray)
err = DimensionMismatch("Tried to load size $(size(src)) array into size $(size(dst))")
(size(dst) == size(src)) || throw(err)
copyto!(dst, src)
end
Expand Down Expand Up @@ -82,20 +86,19 @@ Likewise, copying a `src` value of `false` to any `dst` array is valid,
but copying a `src` value of `true` will error.
"""
function loadmodel!(dst, src; filter = _ -> true, cache = Base.IdSet())
ldsts = _filter_children(filter, functor(dst)[1])
lsrcs = _filter_children(filter, functor(src)[1])
CarloLucibello marked this conversation as resolved.
Show resolved Hide resolved
ldsts = _filter_children(filter, Functors.children(dst))
lsrcs = _filter_children(filter, Functors.children(src))
(keys(ldsts) == keys(lsrcs)) ||
throw(ArgumentError("Tried to load $src into $dst but the structures do not match."))
throw(ArgumentError("Tried to load $(keys(lsrcs)) into $(keys(ldsts)) but the structures do not match."))

err = DimensionMismatch("Tried to load $src into $dst but the parameter sizes do not match.")
foreach(ldsts, lsrcs) do ldst, lsrc
if ldst in cache # we already loaded this parameter before
_tie_check(ldst, lsrc) && return ldst
elseif Functors.isleaf(ldst) # our first time loading this leaf
push!(cache, ldst)
loadleaf!(ldst, lsrc, err)
loadleaf!(ldst, lsrc)
else # this isn't a leaf
loadmodel!(ldst, lsrc; filter = filter, cache = cache)
loadmodel!(ldst, lsrc; filter, cache)
end
end

Expand Down