Skip to content

Commit

Permalink
Support non-number vcat
Browse files Browse the repository at this point in the history
  • Loading branch information
dlfivefifty committed Dec 7, 2024
1 parent 97361a4 commit 92585c9
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 11 deletions.
36 changes: 25 additions & 11 deletions src/lazyconcat.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,21 +21,35 @@ Vcat() = Vcat{Any}()
@inline function applied_instantiate(::typeof(vcat), args...)
iargs = map(instantiate, args)
if !isempty(iargs)
m = size(iargs[1],2)
m = _vcat_size(iargs[1],2)
for k=2:length(iargs)
size(iargs[k],2) == m || throw(ArgumentError("number of columns of each array must match (got $(map(x->size(x,2), args)))"))
_vcat_size(iargs[k],2) == m || throw(ArgumentError("number of columns of each array must match (got $(map(x->_vcat_size(x,2), args)))"))
end
end
vcat, iargs
end

_vcat_axes(a, k) = Base.OneTo(1)
_vcat_axes(a::AbstractArray, k) = axes(a, k)
_vcat_size(a, k) = 1
_vcat_size(a::AbstractArray, k) = size(a, k)
_vcat_ndims(a) = 1
_vcat_ndims(a::AbstractArray) = ndims(a)
_vcat_eltype(a) = typeof(a)
_vcat_eltype(a::AbstractArray) = eltype(a)
_vcat_length(a) = 1
_vcat_length(a::AbstractArray) = length(a)
_vcat_getindex(a, k...) = a
_vcat_getindex(a::AbstractArray, k...) = a[k...]


@inline applied_eltype(::typeof(vcat)) = Any
@inline applied_eltype(::typeof(vcat), args...) = promote_type(map(eltype, args)...)
@inline applied_ndims(::typeof(vcat), args...) = max(1,maximum(map(ndims,args)))
@inline applied_eltype(::typeof(vcat), args...) = promote_type(map(_vcat_eltype, args)...)
@inline applied_ndims(::typeof(vcat), args...) = max(1,maximum(map(_vcat_ndims,args)))
@inline applied_ndims(::typeof(vcat)) = 1
@inline axes(f::Vcat{<:Any,1,Tuple{}}) = (OneTo(0),)
@inline axes(f::Vcat{<:Any,1}) = tuple(oneto(+(map(length,f.args)...)))
@inline axes(f::Vcat{<:Any,2}) = (oneto(+(map(a -> size(a,1), f.args)...)), axes(f.args[1],2))
@inline axes(f::Vcat{<:Any,1}) = tuple(oneto(+(map(_vcat_length,f.args)...)))
@inline axes(f::Vcat{<:Any,2}) = (oneto(+(map(a -> _vcat_size(a,1), f.args)...)), _vcat_axes(f.args[1],2))
@inline size(f::Vcat) = map(length, axes(f))


Expand All @@ -57,17 +71,17 @@ end
f, idx::Tuple{Integer}, A, args...)
k, = idx
T = eltype(f)
n = length(A)
k n && return convert(T, A[k])::T
n = _vcat_length(A)
k n && return convert(T, _vcat_getindex(A,k))::T
vcat_getindex_recursive(f, (k - n, ), args...)
end

@propagate_inbounds @inline function vcat_getindex_recursive(
f, idx::Tuple{Integer,Integer}, A, args...)
k, j = idx
T = eltype(f)
n = size(A, 1)
k n && return convert(T, A[k, j])::T
n = _vcat_size(A, 1)
k n && return convert(T, _vcat_getindex(A, k, j))::T
vcat_getindex_recursive(f, (k - n, j), args...)
end

Expand Down Expand Up @@ -912,7 +926,7 @@ _replace_in_print_matrix(_, k, j, s) = s
function layout_replace_in_print_matrix(LAY::ApplyLayout{typeof(vcat)}, f::AbstractVecOrMat, k, j, s)
κ = k
for A in arguments(LAY, f)
n = size(A,1)
n = _vcat_size(A,1)
κ n && return _replace_in_print_matrix(A, κ, j, s)
κ -= n
end
Expand Down
6 changes: 6 additions & 0 deletions test/concattests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -657,6 +657,12 @@ import LazyArrays: MemoryLayout, DenseColumnMajor, materialize!, call, paddeddat
A = Vcat([1 2 3], [4 5 6])
@test A[2:-1:1,1:-1:1] == [4; 1 ;;]
end

@testset "general types" begin
@test Vcat("hi", "bye") == ["hi", "bye"]
@test Vcat(["hi" "bye"], [2 3]) == ["hi" "bye"; 2 3]
@test Vcat("hi", [2;;]) == ["hi"; 2 ;;]
end
end

end # module

0 comments on commit 92585c9

Please sign in to comment.