Skip to content

Commit

Permalink
Reduce the number of convert methods
Browse files Browse the repository at this point in the history
  • Loading branch information
jishnub committed Mar 22, 2024
1 parent e08d18f commit 20e1716
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 14 deletions.
14 changes: 6 additions & 8 deletions src/blockaxis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ blockisequal(a::Tuple, b::Tuple) = all(blockisequal.(a, b))


Base.convert(::Type{BlockedUnitRange}, axis::BlockedUnitRange) = axis
Base.convert(::Type{BlockedUnitRange}, axis::AbstractBlockedUnitRange) = _BlockedUnitRange(first(axis), blocklasts(axis))
Base.convert(::Type{BlockedUnitRange}, axis::AbstractBlockedUnitRange{Int}) = _BlockedUnitRange(first(axis), blocklasts(axis))
Base.convert(::Type{BlockedUnitRange}, axis::AbstractUnitRange{Int}) = _BlockedUnitRange(first(axis),last(axis):last(axis))
Base.convert(::Type{BlockedUnitRange{CS}}, axis::BlockedUnitRange{CS}) where CS = axis
Base.convert(::Type{BlockedUnitRange{CS}}, axis::BlockedUnitRange) where CS = _BlockedUnitRange(first(axis), convert(CS, blocklasts(axis)))
Expand All @@ -216,16 +216,14 @@ Base.unitrange(b::AbstractBlockedUnitRange) = first(b):last(b)

Base.promote_rule(::Type{<:AbstractBlockedUnitRange}, ::Type{Base.OneTo{Int}}) = UnitRange{Int}

Base.convert(::Type{BlockedOneTo}, axis::BlockedOneTo) = axis
_convert(::Type{BlockedOneTo}, axis::AbstractBlockedUnitRange) = BlockedOneTo(blocklasts(axis))
_convert(::Type{BlockedOneTo}, axis::AbstractUnitRange{Int}) = BlockedOneTo(last(axis):last(axis))
function Base.convert(::Type{BlockedOneTo}, axis::AbstractUnitRange{Int})
first(axis) == 1 || throw(ArgumentError("first element of range is not 1"))
_convert(BlockedOneTo, axis)
BlockedOneTo(blocklasts(axis))
end
function Base.convert(::Type{BlockedOneTo{CS}}, axis::AbstractUnitRange{Int}) where CS
first(axis) == 1 || throw(ArgumentError("first element of range is not 1"))
BlockedOneTo(convert(CS, blocklasts(axis)))
end
Base.convert(::Type{BlockedOneTo{CS}}, axis::BlockedOneTo{CS}) where CS = axis
Base.convert(::Type{BlockedOneTo{CS}}, axis::BlockedOneTo) where CS = BlockedOneTo(convert(CS, blocklasts(axis)))
Base.convert(::Type{BlockedOneTo{CS}}, axis::AbstractUnitRange{Int}) where CS = convert(BlockedOneTo{CS}, convert(BlockedOneTo, axis))

"""
blockaxes(A::AbstractArray)
Expand Down
19 changes: 13 additions & 6 deletions test/test_blockindices.jl
Original file line number Diff line number Diff line change
Expand Up @@ -451,13 +451,20 @@ end
@testset "convert" begin
b = blockedrange(Fill(2,3))
c = blockedrange([2,2,2])
@test convert(BlockedOneTo, b) === b
@test oftype(b, b) === b
@test blockisequal(convert(BlockedOneTo, Base.OneTo(5)), blockedrange([5]))
@test blockisequal(convert(BlockedOneTo, Base.Slice(Base.OneTo(5))), blockedrange([5]))
@test convert(BlockedOneTo{Vector{Int}}, c) === c
@test blockisequal(convert(BlockedOneTo{Vector{Int}}, b),b)
@test blockisequal(convert(BlockedOneTo{Vector{Int}}, Base.OneTo(5)), blockedrange([5]))
@test blockisequal(convert(BlockedOneTo, blockedrange(1, [1,1,1])), blockedrange([1,1,1]))
@test convert(BlockedOneTo, c) === c
@test convert(typeof(c), c) === c
function test_type_and_blockequal(T, r, res)
s = convert(T, r)
@test s isa T
@test blockisequal(r, res)
end
test_type_and_blockequal(BlockedOneTo, Base.OneTo(5), blockedrange([5]))
test_type_and_blockequal(BlockedOneTo, Base.Slice(Base.OneTo(5)), blockedrange([5]))
test_type_and_blockequal(BlockedOneTo{Vector{Int}}, b, b)
test_type_and_blockequal(BlockedOneTo{Vector{Int}}, Base.OneTo(5), blockedrange([5]))
test_type_and_blockequal(BlockedOneTo, blockedrange(1, [1,1,1]), blockedrange([1,1,1]))
end

@testset "findblock" begin
Expand Down

0 comments on commit 20e1716

Please sign in to comment.