Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Interlace #41

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "LazyArrays"
uuid = "5078a376-72f3-5289-bfd5-ec5146d43c02"
version = "0.14.9"
version = "0.15"


[deps]
Expand Down
4 changes: 3 additions & 1 deletion src/LazyArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ end

export Mul, Applied, MulArray, MulVector, MulMatrix, InvMatrix, PInvMatrix,
Hcat, Vcat, Kron, BroadcastArray, BroadcastMatrix, BroadcastVector, cache, Ldiv, Inv, PInv, Diff, Cumsum,
applied, materialize, materialize!, ApplyArray, ApplyMatrix, ApplyVector, apply, ⋆, @~, LazyArray
applied, materialize, materialize!, ApplyArray, ApplyMatrix, ApplyVector, apply, ⋆, @~, LazyArray,
Interlace, interlace


include("lazyapplying.jl")
Expand All @@ -71,5 +72,6 @@ include("lazyconcat.jl")
include("lazysetoperations.jl")
include("lazyoperations.jl")
include("lazymacro.jl")
include("interlace.jl")

end # module
49 changes: 49 additions & 0 deletions src/interlace.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
struct Interlace{T, N, AA, INDS} <: AbstractArray{T,N}
arrays::AA
inds::INDS
end


_sortunion(inds...) = sort!(union(inds...))

Check warning on line 7 in src/interlace.jl

View check run for this annotation

Codecov / codecov/patch

src/interlace.jl#L7

Added line #L7 was not covered by tests
function _sortunion(inds::Vararg{AbstractRange,N}) where N
all(isequal(N), map(step, inds)) || throw(ArgumentError("incompatible"))
sort([map(first, inds)...]) == OneTo(N) || throw(ArgumentError("incompatible"))
n = mapreduce(length, +, inds)
maximum(map(last, inds)) == n || throw(ArgumentError("incompatible lengths"))
OneTo(n)

Check warning on line 13 in src/interlace.jl

View check run for this annotation

Codecov / codecov/patch

src/interlace.jl#L9-L13

Added lines #L9 - L13 were not covered by tests
end


function check_interlace_inds(a, inds)
map(length,a) == map(length,inds) || throw(ArgumentError("Lengths must be compatible"))
n = mapreduce(length, +, a)
_sortunion(inds...) == OneTo(n) || throw(ArgumentError("Every index must be mapped to"))

Check warning on line 20 in src/interlace.jl

View check run for this annotation

Codecov / codecov/patch

src/interlace.jl#L18-L20

Added lines #L18 - L20 were not covered by tests
end

function Interlace(a::NTuple{M,AbstractVector{T}}, inds::NTuple{M,AbstractVector{Int}}) where {T,M}
check_interlace_inds(a, inds)
Interlace{T,1,typeof(a), typeof(inds)}(a, inds)

Check warning on line 25 in src/interlace.jl

View check run for this annotation

Codecov / codecov/patch

src/interlace.jl#L24-L25

Added lines #L24 - L25 were not covered by tests
end

length(A::Interlace) = sum(map(length,A.arrays))
size(A::Interlace, m) = sum(size.(A.arrays,m))
size(A::Interlace{<:Any,1}) = (size(A,1),)

Check warning on line 30 in src/interlace.jl

View check run for this annotation

Codecov / codecov/patch

src/interlace.jl#L28-L30

Added lines #L28 - L30 were not covered by tests
function getindex(A::Interlace{<:Any,1}, k::Integer)
for (a,ind) in zip(A.arrays, A.inds)
κ = findfirst(isequal(k), ind)
isnothing(κ) || return a[something(κ)]

Check warning on line 34 in src/interlace.jl

View check run for this annotation

Codecov / codecov/patch

src/interlace.jl#L32-L34

Added lines #L32 - L34 were not covered by tests
end
throw(BoundsError(A, k))

Check warning on line 36 in src/interlace.jl

View check run for this annotation

Codecov / codecov/patch

src/interlace.jl#L36

Added line #L36 was not covered by tests
end

function copyto!(dest::AbstractVector, src::Interlace{<:Any,1})
for (a,ind) in zip(src.arrays, src.inds)
copyto!(view(dest, ind), a)

Check warning on line 41 in src/interlace.jl

View check run for this annotation

Codecov / codecov/patch

src/interlace.jl#L40-L41

Added lines #L40 - L41 were not covered by tests
end
dest

Check warning on line 43 in src/interlace.jl

View check run for this annotation

Codecov / codecov/patch

src/interlace.jl#L43

Added line #L43 was not covered by tests
end

Interlace(a::AbstractVector, b::AbstractVector) =

Check warning on line 46 in src/interlace.jl

View check run for this annotation

Codecov / codecov/patch

src/interlace.jl#L46

Added line #L46 was not covered by tests
Interlace((a,b), (1:2:(2length(a)-1), 2:2:2length(b)))

interlace(a...) = Array(Interlace(a...))

Check warning on line 49 in src/interlace.jl

View check run for this annotation

Codecov / codecov/patch

src/interlace.jl#L49

Added line #L49 was not covered by tests
2 changes: 2 additions & 0 deletions src/lazybroadcasting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
_BroadcastArray(instantiate(Broadcasted{S}(bc.f, _broadcast2broadcastarray(bc.args...))))

BroadcastArray(f, A, As...) = BroadcastArray(broadcasted(f, A, As...))
BroadcastArray{T,N}(f, A...) where {T,N} = BroadcastArray{T,N,typeof(f),typeof(A)}(f, A)

Check warning on line 32 in src/lazybroadcasting.jl

View check run for this annotation

Codecov / codecov/patch

src/lazybroadcasting.jl#L32

Added line #L32 was not covered by tests

BroadcastMatrix(f, A...) = BroadcastMatrix(broadcasted(f, A...))
BroadcastVector(f, A...) = BroadcastVector(broadcasted(f, A...))

Expand Down
30 changes: 15 additions & 15 deletions src/lazyconcat.jl
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@
broadcasted(::LazyArrayStyle, op, A::Vcat) =
Vcat(broadcast(x -> broadcast(op, x), A.args)...)

for Cat in (:Vcat, :Hcat)
for Cat in (:Vcat, :Hcat)
@eval begin
broadcasted(::LazyArrayStyle, op, A::$Cat, c::Number) =
$Cat(broadcast((x,y) -> broadcast(op, x, y), A.args, c)...)
Expand All @@ -298,7 +298,7 @@
broadcasted(::LazyArrayStyle, op, A::$Cat, c::Ref) =
$Cat(broadcast((x,y) -> broadcast(op, x, Ref(y)), A.args, c)...)
broadcasted(::LazyArrayStyle, op, c::Ref, A::$Cat) =
$Cat(broadcast((x,y) -> broadcast(op, Ref(x), y), c, A.args)...)
$Cat(broadcast((x,y) -> broadcast(op, Ref(x), y), c, A.args)...)
end
end

Expand All @@ -318,7 +318,7 @@
B_arrays = _vcat_getindex_eval(B,kr...) # evaluate B at same chunks as A
ApplyVector(vcat, broadcast((a,b) -> broadcast(op,a,b), A.args, B_arrays)...)
end

function broadcasted(::LazyArrayStyle, op, A::AbstractVector, B::Vcat{<:Any,1})
kr = _vcat_axes(axes.(B.args)...)
A_arrays = _vcat_getindex_eval(A,kr...)
Expand Down Expand Up @@ -440,20 +440,20 @@
@eval $op(V::Vcat) = $op($op.(V.args))
end

function in(x, V::Vcat)
function in(x, V::Vcat)
for a in V.args
in(x, a) && return true
end
false
end

_fill!(a, x) = fill!(a,x)
function _fill!(a::Number, x)
function _fill!(a::Number, x)
a == x || throw(ArgumentError("Cannot set $a to $x"))
a
end

function fill!(V::Union{Vcat,Hcat}, x)
function fill!(V::Union{Vcat,Hcat}, x)
for a in V.args
_fill!(a, x)
end
Expand Down Expand Up @@ -546,8 +546,8 @@
convert(promote_type(T,V), dot(view(a,1:m), view(b,1:m)))
end

dot(a::CachedArray, b::AbstractArray) = materialize(Dot(a,b))
dot(a::LazyArray, b::AbstractArray) = materialize(Dot(a,b))
dot(a::CachedArray, b::AbstractArray) = materialize(Dot(a,b))
dot(a::LazyArray, b::AbstractArray) = materialize(Dot(a,b))


###
Expand All @@ -571,9 +571,9 @@
# a row-slice of an Hcat is equivalent to a Vcat
sublayout(::ApplyLayout{typeof(hcat)}, ::Type{<:Tuple{Number,AbstractVector}}) = ApplyLayout{typeof(vcat)}()

arguments(::ApplyLayout{typeof(vcat)}, V::SubArray{<:Any,2,<:Any,<:Tuple{<:Slice,<:Any}}) =
arguments(::ApplyLayout{typeof(vcat)}, V::SubArray{<:Any,2,<:Any,<:Tuple{<:Slice,<:Any}}) =
view.(arguments(parent(V)), Ref(:), Ref(parentindices(V)[2]))
arguments(::ApplyLayout{typeof(hcat)}, V::SubArray{<:Any,2,<:Any,<:Tuple{<:Any,<:Slice}}) =
arguments(::ApplyLayout{typeof(hcat)}, V::SubArray{<:Any,2,<:Any,<:Tuple{<:Any,<:Slice}}) =
view.(arguments(parent(V)), Ref(parentindices(V)[1]), Ref(:))

copyto!(dest::AbstractArray{T,N}, src::SubArray{T,N,<:Vcat{T,N}}) where {T,N} =
Expand Down Expand Up @@ -638,7 +638,7 @@
end
ret
end

function sub_materialize(::ApplyLayout{typeof(hcat)}, V)
ret = similar(V)
n = 0
Expand All @@ -652,10 +652,10 @@
end

# temporarily allocate. In the future, we add a loop over arguments
materialize!(M::MatMulMatAdd{<:AbstractColumnMajor,<:ApplyLayout{typeof(vcat)}}) =
materialize!(M::MatMulMatAdd{<:AbstractColumnMajor,<:ApplyLayout{typeof(vcat)}}) =
materialize!(MulAdd(M.α,M.A,Array(M.B),M.β,M.C))
materialize!(M::MatMulVecAdd{<:AbstractColumnMajor,<:ApplyLayout{typeof(vcat)}}) =

Check warning on line 657 in src/lazyconcat.jl

View check run for this annotation

Codecov / codecov/patch

src/lazyconcat.jl#L657

Added line #L657 was not covered by tests
materialize!(MulAdd(M.α,M.A,Array(M.B),M.β,M.C))
materialize!(M::MatMulVecAdd{<:AbstractColumnMajor,<:ApplyLayout{typeof(vcat)}}) =
materialize!(MulAdd(M.α,M.A,Array(M.B),M.β,M.C))

sublayout(::PaddedLayout{L}, ::Type{I}) where {L,I<:Tuple{AbstractUnitRange}} =
PaddedLayout{typeof(sublayout(L(), I))}()
Expand Down Expand Up @@ -702,4 +702,4 @@
κ -= n
end
throw(BoundsError(f, (k,j)))
end
end
9 changes: 9 additions & 0 deletions test/interlacetests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
using LazyArrays, Test

@testset "Interlace" begin
@test_throws ArgumentError Interlace(1:5, 10:15)
@test_throws ArgumentError Interlace(1:5, 10:12)
@test eltype(Interlace(1:5, 10:13)) == Int
@test Interlace(1:5, 10:13) == interlace(1:5,10:13) == [1,10,2,11,3,12,4,13,5]
@test Interlace(1:5, 10:14) == interlace(1:5,10:14) == [1,10,2,11,3,12,4,13,5,14]
end