Skip to content

Commit

Permalink
Fix broken SVD tests (#33)
Browse files Browse the repository at this point in the history
* reenable broken tests

* use `isstored` instead of `I in eachstoredindex`

* Bump version and minimal SparseArraysBase compat

* Formatter
  • Loading branch information
lkdvos authored Jan 22, 2025
1 parent 60ec997 commit 37fecf7
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 28 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "BlockSparseArrays"
uuid = "2c9a651f-6452-4ace-a6ac-809f4280fbb4"
authors = ["ITensor developers <[email protected]> and contributors"]
version = "0.2.10"
version = "0.2.11"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down Expand Up @@ -42,7 +42,7 @@ LabelledNumbers = "0.1.0"
LinearAlgebra = "1.10"
MacroTools = "0.5.13"
MapBroadcast = "0.1.5"
SparseArraysBase = "0.2.2"
SparseArraysBase = "0.2.10"
SplitApplyCombine = "1.2.3"
TensorAlgebra = "0.1.0"
Test = "1.10"
Expand Down
7 changes: 2 additions & 5 deletions src/abstractblocksparsearray/views.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,8 @@ function BlockArrays.viewblock(
a::AbstractBlockSparseArray{<:Any,N}, block::Vararg{Block{1},N}
) where {N}
I = CartesianIndex(Int.(block))
# TODO: Use `eachblockstoredindex`.
if I eachstoredindex(blocks(a))
return blocks(a)[I]
end
return BlockView(a, block)
# TODO: isblockstored
return isstored(blocks(a), I) ? blocks(a)[I] : BlockView(a, block)
end

# Specialized code for getting the view of a subblock.
Expand Down
33 changes: 12 additions & 21 deletions test/test_svd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@ using LinearAlgebra: LinearAlgebra
using Random: Random
using Test: @inferred, @testset, @test

function test_svd(a, usv; broken=false)
function test_svd(a, usv)
U, S, V = usv
@test U * diagonal(S) * V' a broken = broken
@test U' * U LinearAlgebra.I
@test V' * V LinearAlgebra.I
return (U * diagonal(S) * V' a) &&
(U' * U LinearAlgebra.I) &&
(V' * V LinearAlgebra.I)
end

# regular matrix
Expand All @@ -19,7 +19,7 @@ eltypes = (Float32, Float64, ComplexF64)
@testset "($m, $n) Matrix{$T}" for ((m, n), T) in Iterators.product(sizes, eltypes)
a = rand(m, n)
usv = @inferred svd(a)
test_svd(a, usv)
@test test_svd(a, usv)
end

# block matrix
Expand All @@ -28,7 +28,7 @@ blockszs = (([2, 2], [2, 2]), ([2, 2], [2, 3]), ([2, 2, 1], [2, 3]), ([2, 3], [2
@testset "($m, $n) BlockMatrix{$T}" for ((m, n), T) in Iterators.product(blockszs, eltypes)
a = mortar([rand(T, i, j) for i in m, j in n])
usv = svd(a)
test_svd(a, usv)
@test test_svd(a, usv)
@test usv.U isa BlockedMatrix
@test usv.Vt isa BlockedMatrix
@test usv.S isa BlockedVector
Expand All @@ -39,17 +39,8 @@ end
@testset "($m, $n) BlockDiagonal{$T}" for ((m, n), T) in
Iterators.product(blockszs, eltypes)
a = BlockDiagonal([rand(T, i, j) for (i, j) in zip(m, n)])
if VERSION v"1.11"
usv = svd(a)
# TODO: `BlockDiagonal * Adjoint` errors
# TODO: This is broken because of https://github.com/JuliaLang/julia/issues/57034,
# fix and reenable.
test_svd(a, usv; broken=true)
else
# `svd(a)` depends on `diagind(::AbstractMatrix, ::IndexStyle)`
# being defined, but it was only introduced in Julia v1.11.
@test svd(a) broken = true
end
usv = svd(a)
@test test_svd(a, usv)
end

# blocksparse
Expand All @@ -60,25 +51,25 @@ end

# test empty matrix
usv_empty = svd(a)
test_svd(a, usv_empty)
@test test_svd(a, usv_empty)

# test blockdiagonal
for i in LinearAlgebra.diagind(blocks(a))
I = CartesianIndices(blocks(a))[i]
a[Block(I.I...)] = rand(T, size(blocks(a)[i]))
end
usv = svd(a)
test_svd(a, usv)
@test test_svd(a, usv)

perm = Random.randperm(length(m))
b = a[Block.(perm), Block.(1:length(n))]
usv = svd(b)
test_svd(b, usv)
@test test_svd(b, usv)

# test permuted blockdiagonal with missing row/col
I_removed = rand(eachblockstoredindex(b))
c = copy(b)
delete!(blocks(c).storage, CartesianIndex(Int.(Tuple(I_removed))))
usv = svd(c)
test_svd(c, usv)
@test test_svd(c, usv)
end

0 comments on commit 37fecf7

Please sign in to comment.