Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] [BlockSparseArrays] Change behavior of non-blocked slicing #1484

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
@eval module $(gensym())
using Compat: Returns
using Test: @test, @testset, @test_broken
using BlockArrays: Block, blocksize
using BlockArrays: Block, blocklength, blocksize
using NDTensors.BlockSparseArrays: BlockSparseArray, block_nstored
using NDTensors.GradedAxes: GradedAxes, GradedUnitRange, UnitRangeDual, dual, gradedrange
using NDTensors.LabelledNumbers: label
Expand Down Expand Up @@ -50,14 +50,12 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})

b = a[2:3, 2:3, 2:3, 2:3]
@test size(b) == (2, 2, 2, 2)
@test blocksize(b) == (2, 2, 2, 2)
@test nstored(b) == 2
@test block_nstored(b) == 2
@test blocksize(b) == (1, 1, 1, 1)
@test nstored(b) == length(b)
@test block_nstored(b) == blocklength(b)
for i in 1:ndims(a)
@test axes(b, i) isa GradedUnitRange
@test axes(b, i) isa Base.OneTo{Int}
end
@test label(axes(b, 1)[Block(1)]) == U1(0)
@test label(axes(b, 1)[Block(2)]) == U1(1)
@test Array(a) isa Array{elt}
@test Array(a) == a
end
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using ArrayLayouts: LayoutArray
using BlockArrays: blockisequal
using ..GradedAxes: blocked_getindex
using LinearAlgebra: Adjoint, Transpose
using ..SparseArrayInterface:
SparseArrayInterface,
Expand All @@ -26,23 +27,42 @@ end
# This is type piracy, try to avoid this, maybe requires defining `map`.
## Base.promote_shape(a1::Tuple{Vararg{BlockedUnitRange}}, a2::Tuple{Vararg{BlockedUnitRange}}) = combine_axes(a1, a2)

function SparseArrayInterface.sparse_map!(
::BlockSparseArrayStyle, f, a_dest::AbstractArray, a_srcs::Vararg{AbstractArray}
)
# Work around issue that:
# ```julia
# julia> using BlockArrays: blocks
#
# julia> blocks(randn(2, 2))[1, 1]
# 2×2 view(::Matrix{Float64}, BlockSlice(Block(1),Base.OneTo(2)), BlockSlice(Block(1),Base.OneTo(2))) with eltype Float64:
# 0.0534014 -1.1738
# -0.649799 0.128661
# ```
# TODO: Raise an issue with BlockArrays.jl.
function blocks_getindex(a::AbstractArray{<:Any,N}, index::Vararg{Integer,N}) where {N}
return a[index...]
end
function blocks_getindex(
a::BlocksView{<:Any,N,<:Any,<:Array{<:Any,N}}, index::Vararg{Integer,N}
) where {N}
return a.array
end

function blocked_map!(f, a_dest::AbstractArray, a_srcs::Vararg{AbstractArray})
for I in union_stored_blocked_cartesianindices(a_dest, a_srcs...)
BI_dest = blockindexrange(a_dest, I)
BI_srcs = map(a_src -> blockindexrange(a_src, I), a_srcs)
# TODO: Investigate why this doesn't work:
# block_dest = @view a_dest[_block(BI_dest)]
block_dest = blocks(a_dest)[Int.(Tuple(_block(BI_dest)))...]
# TODO: Use `blocks_getindex`.
block_dest = blocks_getindex(blocks(a_dest), Int.(Tuple(_block(BI_dest)))...)
# TODO: Investigate why this doesn't work:
# block_srcs = ntuple(i -> @view(a_srcs[i][_block(BI_srcs[i])]), length(a_srcs))
block_srcs = ntuple(length(a_srcs)) do i
return blocks(a_srcs[i])[Int.(Tuple(_block(BI_srcs[i])))...]
# TODO: Use `blocks_getindex`.
return blocks_getindex(blocks(a_srcs[i]), Int.(Tuple(_block(BI_srcs[i])))...)
end
subblock_dest = @view block_dest[BI_dest.indices...]
subblock_srcs = ntuple(i -> @view(block_srcs[i][BI_srcs[i].indices...]), length(a_srcs))
# TODO: Use `map!!` to handle immutable blocks.
# TODO: Use `map!!` to handle immutable blocks, such as FillArrays.
map!(f, subblock_dest, subblock_srcs...)
# Replace the entire block, handles initializing new blocks
# or if blocks are immutable.
Expand All @@ -51,6 +71,20 @@ function SparseArrayInterface.sparse_map!(
return a_dest
end

# Convert a non-block SubArray to a blocked subarray
# using the blocking of the underlying array.
to_blocked(a::AbstractArray) = a
function to_blocked(a::SubArray)
# Returns a `BlockedSubArray`.
return blocked_view(parent(a), parentindices(a)...)
end

function SparseArrayInterface.sparse_map!(
::BlockSparseArrayStyle, f, a_dest::AbstractArray, a_srcs::Vararg{AbstractArray}
)
return blocked_map!(f, to_blocked.((a_dest, a_srcs...))...)
end

# TODO: Implement this.
# function SparseArrayInterface.sparse_mapreduce(::BlockSparseArrayStyle, f, a_dest::AbstractArray, a_srcs::Vararg{AbstractArray})
# end
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@ using SplitApplyCombine: groupcount

using Adapt: Adapt, WrappedArray

const WrappedAbstractBlockSparseArray{T,N} = WrappedArray{
T,N,AbstractBlockSparseArray,AbstractBlockSparseArray{T,N}
const WrappedAbstractBlockSparseArray{T,N} = Union{
WrappedArray{T,N,AbstractBlockSparseArray,AbstractBlockSparseArray{T,N}},
BlockedSubArray{T,N,<:AbstractBlockSparseArray{T,N}},
}

# TODO: Rename `AnyBlockSparseArray`.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using BlockArrays:
AbstractBlockArray,
AbstractBlockVector,
Block,
BlockedUnitRange,
Expand Down Expand Up @@ -265,12 +266,13 @@ end

# Represents the array of arrays of a `SubArray`
# wrapping a block spare array, i.e. `blocks(array)` where `a` is a `SubArray`.
struct SparseSubArrayBlocks{T,N,Array<:SubArray{T,N}} <: AbstractSparseArray{T,N}
# TODO: Define `blockstype`.
struct SparseSubArrayBlocks{T,N,Array<:AbstractArray{T,N}} <: AbstractSparseArray{T,N}
array::Array
end
# TODO: Define this as `blockrange(a::AbstractArray, indices::Tuple{Vararg{AbstractUnitRange}})`.
function blockrange(a::SparseSubArrayBlocks)
blockranges = blockrange.(axes(parent(a.array)), a.array.indices)
blockranges = blockrange.(axes(parent(a.array)), parentindices(a.array))
return map(blockrange -> Int.(blockrange), blockranges)
end
function Base.axes(a::SparseSubArrayBlocks)
Expand All @@ -284,16 +286,17 @@ function Base.getindex(a::SparseSubArrayBlocks{<:Any,N}, I::Vararg{Int,N}) where
parent_block = parent_blocks[I...]
# TODO: Define this using `blockrange(a::AbstractArray, indices::Tuple{Vararg{AbstractUnitRange}})`.
block = Block(ntuple(i -> blockrange(a)[i][I[i]], ndims(a)))
return @view parent_block[blockindices(parent(a.array), block, a.array.indices)...]
return @view parent_block[blockindices(parent(a.array), block, parentindices(a.array))...]
end
# TODO: This should be handled by generic `AbstractSparseArray` code.
function Base.getindex(a::SparseSubArrayBlocks{<:Any,N}, I::CartesianIndex{N}) where {N}
return a[Tuple(I)...]
end
function Base.setindex!(a::SparseSubArrayBlocks{<:Any,N}, value, I::Vararg{Int,N}) where {N}
parent_blocks = view(blocks(parent(a.array)), axes(a)...)
return parent_blocks[I...][blockindices(parent(a.array), Block(I), a.array.indices)...] =
value
return parent_blocks[I...][blockindices(
parent(a.array), Block(I), parentindices(a.array)
)...] = value
end
function Base.isassigned(a::SparseSubArrayBlocks{<:Any,N}, I::Vararg{Int,N}) where {N}
if CartesianIndex(I) ∉ CartesianIndices(a)
Expand All @@ -313,7 +316,53 @@ function SparseArrayInterface.sparse_storage(a::SparseSubArrayBlocks)
return error("Not implemented")
end

# An alternative to `SubArray` where the blocking
# is determined from the parent.
# See https://github.com/JuliaArrays/BlockArrays.jl/issues/347.
struct BlockedSubArray{T,N,P,I} <: AbstractBlockArray{T,N}
parent::P
indices::I
function BlockedSubArray(parent, indices)
return new{eltype(parent),ndims(parent),typeof(parent),typeof(indices)}(parent, indices)
end
end
Base.parent(a::BlockedSubArray) = getfield(a, :parent)
Base.parentindices(a::BlockedSubArray) = getfield(a, :indices)
to_subarray(a::BlockedSubArray) = view(parent(a), parentindices(a)...)
function Base.axes(a::BlockedSubArray)
return ntuple(ndims(a)) do dim
return only(axes(blocked_getindex(axes(parent(a), dim), parentindices(a)[dim])))
end
end
Base.size(a::BlockedSubArray) = map(length, axes(a))
function Base.getindex(a::BlockedSubArray{<:Any,N}, I::Vararg{Int,N}) where {N}
return to_subarray(a)[I...]
end

function blocked_view(
a::AbstractArray{<:Any,N}, indices::Vararg{AbstractUnitRange,N}
) where {N}
return BlockedSubArray(a, indices)
end

function blocked_view(a::AbstractArray{<:Any,N}, indices::Vararg{Any,N}) where {N}
return view(a, indices...)
end

function blocksparse_blocks(a::BlockedSubArray)
return SparseSubArrayBlocks(a)
end

# TODO: Restrict to:
# SubArray{<:Any,<:Any,<:Any,<:Tuple{Vararg{AbstractVector{<:Integer}}}}
# and consider making a trait, like `is_blocked_slice`.
function blocksparse_blocks(a::SubArray)
return BlocksView(a)
end

function blocksparse_blocks(
a::SubArray{<:Any,<:Any,<:Any,<:Tuple{Vararg{AbstractVector{<:Block}}}}
)
return SparseSubArrayBlocks(a)
end

Expand Down
18 changes: 9 additions & 9 deletions NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,10 @@ include("TestBlockSparseArraysUtils.jl")
@test blocksize(b) == (2, 2)
@test nstored(b) == nstored(a)
@test block_nstored(b) == 2
b_view = @view a[[Block(2), Block(1)], [Block(2), Block(1)]]

# TODO: Fix this!
@test_broken show(devnull, MIME("text/plain"), b_view)

a = BlockSparseArray{elt}(undef, ([2, 3], [3, 4]))
a[Block(1, 2)] = randn(elt, size(@view(a[Block(1, 2)])))
Expand Down Expand Up @@ -182,9 +186,9 @@ include("TestBlockSparseArraysUtils.jl")
b = a[2:4, 2:4]
@test b == Array(a)[2:4, 2:4]
@test size(b) == (3, 3)
@test blocksize(b) == (2, 2)
@test nstored(b) == 1 * 1 + 2 * 2
@test block_nstored(b) == 2
@test blocksize(b) == (1, 1)
@test nstored(b) == length(b)
@test block_nstored(b) == blocklength(b)

a = BlockSparseArray{elt}(undef, ([2, 3], [3, 4]))
a[Block(1, 2)] = randn(elt, size(@view(a[Block(1, 2)])))
Expand Down Expand Up @@ -257,18 +261,14 @@ include("TestBlockSparseArraysUtils.jl")
@view(a[Block(2, 2)])[1:1, 1:2] = x
@test @view(a[Block(2, 2)])[1:1, 1:2] == x
@test a[Block(2, 2)][1:1, 1:2] == x

# TODO: This is broken, fix!
@test_broken a[3:3, 4:5] == x
@test a[3:3, 4:5] == x

a = BlockSparseArray{elt}(undef, ([2, 3], [3, 4]))
x = randn(elt, 1, 2)
@views a[Block(2, 2)][1:1, 1:2] = x
@test @view(a[Block(2, 2)])[1:1, 1:2] == x
@test a[Block(2, 2)][1:1, 1:2] == x

# TODO: This is broken, fix!
@test_broken a[3:3, 4:5] == x
@test a[3:3, 4:5] == x

## Broken, need to fix.

Expand Down
21 changes: 14 additions & 7 deletions NDTensors/src/lib/GradedAxes/src/gradedunitrange.jl
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,20 @@ using BlockArrays: block, blockindex
function blockedunitrange_getindices(
a::BlockedUnitRange, indices::AbstractUnitRange{<:Integer}
)
return indices
end

# TODO: Move this to a `BlockArraysExtensions` library.
# Slice a BlockedUnitRange, preserving the blocking.
# See https://github.com/JuliaArrays/BlockArrays.jl/issues/347.
function blocked_getindex(a::AbstractUnitRange, indices)
return a[indices]
end

# TODO: Move this to a `BlockArraysExtensions` library.
# Slice a BlockedUnitRange, preserving the blocking.
# See https://github.com/JuliaArrays/BlockArrays.jl/issues/347.
function blocked_getindex(a::BlockedUnitRange, indices::AbstractUnitRange{<:Integer})
first_blockindex = blockedunitrange_findblockindex(a, first(indices))
last_blockindex = blockedunitrange_findblockindex(a, last(indices))
first_block = block(first_blockindex)
Expand Down Expand Up @@ -247,13 +261,6 @@ function blocklabels(a::AbstractUnitRange, indices)
end
end

function blockedunitrange_getindices(
ga::GradedUnitRange, indices::AbstractUnitRange{<:Integer}
)
a_indices = blockedunitrange_getindices(unlabel_blocks(ga), indices)
return labelled_blocks(a_indices, blocklabels(ga, indices))
end

function blockedunitrange_getindices(ga::GradedUnitRange, indices::BlockRange)
return labelled_blocks(unlabel_blocks(ga)[indices], blocklabels(ga, indices))
end
Expand Down
27 changes: 4 additions & 23 deletions NDTensors/src/lib/GradedAxes/test/test_basics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,34 +74,15 @@ using Test: @test, @test_broken, @testset
# Slicing operations
x = gradedrange(["x" => 2, "y" => 3])
a = x[2:4]
@test a isa GradedUnitRange
@test a isa AbstractUnitRange
@test length(a) == 3
@test blocklength(a) == 2
@test a[Block(1)] == 2:2
@test label(a[Block(1)]) == "x"
@test a[Block(2)] == 3:4
@test label(a[Block(2)]) == "y"
@test isone(first(only(axes(a))))
@test length(only(axes(a))) == length(a)
@test blocklengths(only(axes(a))) == blocklengths(a)
@test a == 2:4

x = gradedrange(["x" => 2, "y" => 3])
a = x[3:4]
@test a isa GradedUnitRange
@test length(a) == 2
@test blocklength(a) == 1
@test a[Block(1)] == 3:4
@test label(a[Block(1)]) == "y"

x = gradedrange(["x" => 2, "y" => 3])
a = x[2:4][1:2]
@test a isa GradedUnitRange
@test a isa AbstractUnitRange
@test length(a) == 2
@test blocklength(a) == 2
@test a[Block(1)] == 2:2
@test label(a[Block(1)]) == "x"
@test a[Block(2)] == 3:3
@test label(a[Block(2)]) == "y"
@test a == 3:4

x = gradedrange(["x" => 2, "y" => 3])
a = x[Block(2)[2:3]]
Expand Down
3 changes: 1 addition & 2 deletions NDTensors/src/lib/GradedAxes/test/test_dual.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,7 @@ GradedAxes.dual(c::U1) = U1(-c.n)
@test ad[4] == 4
@test label(ad[4]) == U1(-1)
@test ad[2:4] == 2:4
@test ad[2:4] isa UnitRangeDual
@test label(ad[2:4][Block(2)]) == U1(-1)
@test ad[2:4] isa AbstractUnitRange
@test ad[[2, 4]] == [2, 4]
@test label(ad[[2, 4]][2]) == U1(-1)
@test ad[Block(2)] == 3:5
Expand Down
Loading