From 058a25b7f7c88e4e0c43421e0e8e091c9f4954d7 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Thu, 11 Apr 2024 11:30:09 -0400 Subject: [PATCH] use 5-arg copyto --- src/destructure.jl | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/destructure.jl b/src/destructure.jl index ccc9add..baea015 100644 --- a/src/destructure.jl +++ b/src/destructure.jl @@ -152,14 +152,18 @@ end function _rebuild!(x, off, flat::AbstractVector, len = length(flat); walk = _Trainable_biwalk(), kw...) len == length(flat) || throw(DimensionMismatch("Rebuild expected a vector of length $len, got $(length(flat))")) fmap(x, off; exclude = isnumeric, walk, kw...) do y, o - copyto!(y, _getat(y, o, flat, view)) + # copyto!(y, _getat_view(y, o, flat)) + copyto!(y, 1, flat, o+1, length(y)) end x end -_getat(y::Number, o::Int, flat::AbstractVector, _...) = ProjectTo(y)(flat[o + 1]) -_getat(y::AbstractArray, o::Int, flat::AbstractVector, get=getindex) = - ProjectTo(y)(reshape(get(flat, o .+ (1:length(y))), axes(y))) # ProjectTo is just correcting eltypes +_getat(y::Number, o::Int, flat::AbstractVector) = ProjectTo(y)(flat[o + 1]) +_getat(y::AbstractArray, o::Int, flat::AbstractVector) = + ProjectTo(y)(reshape(flat[o .+ (1:length(y))], axes(y))) # ProjectTo is just correcting eltypes + +# _getat_view(y::AbstractArray, o::Int, flat::AbstractVector) = +# view(flat, o .+ (1:length(y))) struct _Trainable_biwalk <: AbstractWalk end