Skip to content

Commit

Permalink
Format
Browse files Browse the repository at this point in the history
  • Loading branch information
mtfishman committed Nov 26, 2024
1 parent b35c964 commit 8a22fe1
Show file tree
Hide file tree
Showing 2 changed files with 135 additions and 164 deletions.
263 changes: 119 additions & 144 deletions src/NestedPermutedDimsArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,20 +47,19 @@ export NestedPermutedDimsArray

# Some day we will want storage-order-aware iteration, so put perm in the parameters
struct NestedPermutedDimsArray{T,N,perm,iperm,AA<:AbstractArray} <: AbstractArray{T,N}
parent::AA
parent::AA

function NestedPermutedDimsArray{T,N,perm,iperm,AA}(
data::AA,
) where {T,N,perm,iperm,AA<:AbstractArray}
(isa(perm, NTuple{N,Int}) && isa(iperm, NTuple{N,Int})) ||
error("perm and iperm must both be NTuple{$N,Int}")
isperm(perm) || throw(
ArgumentError(string(perm, " is not a valid permutation of dimensions 1:", N)),
)
all(d -> iperm[perm[d]] == d, 1:N) ||
throw(ArgumentError(string(perm, " and ", iperm, " must be inverses")))
return new(data)
end
function NestedPermutedDimsArray{T,N,perm,iperm,AA}(
data::AA
) where {T,N,perm,iperm,AA<:AbstractArray}
(isa(perm, NTuple{N,Int}) && isa(iperm, NTuple{N,Int})) ||
error("perm and iperm must both be NTuple{$N,Int}")
isperm(perm) ||
throw(ArgumentError(string(perm, " is not a valid permutation of dimensions 1:", N)))
all(d -> iperm[perm[d]] == d, 1:N) ||
throw(ArgumentError(string(perm, " and ", iperm, " must be inverses")))
return new(data)
end
end

## TODO: Fix this docstring.
Expand All @@ -87,37 +86,35 @@ end
## ```
## """
Base.@constprop :aggressive function NestedPermutedDimsArray(
data::AbstractArray{T,N},
perm,
data::AbstractArray{T,N}, perm
) where {T,N}
length(perm) == N || throw(
ArgumentError(string(perm, " is not a valid permutation of dimensions 1:", N)),
)
iperm = invperm(perm)
return NestedPermutedDimsArray{
PermutedDimsArray{eltype(T),N,(perm...,),(iperm...,),T},
N,
(perm...,),
(iperm...,),
typeof(data),
}(
data,
)
length(perm) == N ||
throw(ArgumentError(string(perm, " is not a valid permutation of dimensions 1:", N)))
iperm = invperm(perm)
return NestedPermutedDimsArray{
PermutedDimsArray{eltype(T),N,(perm...,),(iperm...,),T},
N,
(perm...,),
(iperm...,),
typeof(data),
}(
data
)
end

Base.parent(A::NestedPermutedDimsArray) = A.parent
function Base.size(A::NestedPermutedDimsArray{T,N,perm}) where {T,N,perm}
return genperm(size(parent(A)), perm)
return genperm(size(parent(A)), perm)
end
function Base.axes(A::NestedPermutedDimsArray{T,N,perm}) where {T,N,perm}
return genperm(axes(parent(A)), perm)
return genperm(axes(parent(A)), perm)
end
Base.has_offset_axes(A::NestedPermutedDimsArray) = Base.has_offset_axes(A.parent)
function Base.similar(A::NestedPermutedDimsArray, T::Type, dims::Base.Dims)
return similar(parent(A), T, dims)
return similar(parent(A), T, dims)
end
function Base.cconvert(::Type{Ptr{T}}, A::NestedPermutedDimsArray{T}) where {T}
return Base.cconvert(Ptr{T}, parent(A))
return Base.cconvert(Ptr{T}, parent(A))
end

# It's OK to return a pointer to the first element, and indeed quite
Expand All @@ -126,177 +123,155 @@ end
# storage order, a linear offset is ambiguous---is it a memory offset
# or a linear index?
function Base.pointer(A::NestedPermutedDimsArray, i::Integer)
throw(
ArgumentError(
"pointer(A, i) is deliberately unsupported for NestedPermutedDimsArray",
),
)
throw(
ArgumentError("pointer(A, i) is deliberately unsupported for NestedPermutedDimsArray")
)
end

function Base.strides(A::NestedPermutedDimsArray{T,N,perm}) where {T,N,perm}
s = strides(parent(A))
return ntuple(d -> s[perm[d]], Val(N))
s = strides(parent(A))
return ntuple(d -> s[perm[d]], Val(N))
end
function Base.elsize(::Type{<:NestedPermutedDimsArray{<:Any,<:Any,<:Any,<:Any,P}}) where {P}
return Base.elsize(P)
return Base.elsize(P)
end

@inline function Base.getindex(
A::NestedPermutedDimsArray{T,N,perm,iperm},
I::Vararg{Int,N},
A::NestedPermutedDimsArray{T,N,perm,iperm}, I::Vararg{Int,N}
) where {T,N,perm,iperm}
@boundscheck checkbounds(A, I...)
@inbounds val = PermutedDimsArray(getindex(A.parent, genperm(I, iperm)...), perm)
return val
@boundscheck checkbounds(A, I...)
@inbounds val = PermutedDimsArray(getindex(A.parent, genperm(I, iperm)...), perm)
return val
end
@inline function Base.setindex!(
A::NestedPermutedDimsArray{T,N,perm,iperm},
val,
I::Vararg{Int,N},
A::NestedPermutedDimsArray{T,N,perm,iperm}, val, I::Vararg{Int,N}
) where {T,N,perm,iperm}
@boundscheck checkbounds(A, I...)
@inbounds setindex!(A.parent, PermutedDimsArray(val, iperm), genperm(I, iperm)...)
return val
@boundscheck checkbounds(A, I...)
@inbounds setindex!(A.parent, PermutedDimsArray(val, iperm), genperm(I, iperm)...)
return val
end

function Base.isassigned(
A::NestedPermutedDimsArray{T,N,perm,iperm},
I::Vararg{Int,N},
A::NestedPermutedDimsArray{T,N,perm,iperm}, I::Vararg{Int,N}
) where {T,N,perm,iperm}
@boundscheck checkbounds(Bool, A, I...) || return false
@inbounds x = isassigned(A.parent, genperm(I, iperm)...)
return x
@boundscheck checkbounds(Bool, A, I...) || return false
@inbounds x = isassigned(A.parent, genperm(I, iperm)...)
return x
end

@inline genperm(I::NTuple{N,Any}, perm::Dims{N}) where {N} = ntuple(d -> I[perm[d]], Val(N))
@inline genperm(I, perm::AbstractVector{Int}) = genperm(I, (perm...,))

function Base.copyto!(
dest::NestedPermutedDimsArray{T,N},
src::AbstractArray{T,N},
dest::NestedPermutedDimsArray{T,N}, src::AbstractArray{T,N}
) where {T,N}
checkbounds(dest, axes(src)...)
return _copy!(dest, src)
checkbounds(dest, axes(src)...)
return _copy!(dest, src)
end
Base.copyto!(dest::NestedPermutedDimsArray, src::AbstractArray) = _copy!(dest, src)

function _copy!(P::NestedPermutedDimsArray{T,N,perm}, src) where {T,N,perm}
# If dest/src are "close to dense," then it pays to be cache-friendly.
# Determine the first permuted dimension
d = 0 # d+1 will hold the first permuted dimension of src
while d < ndims(src) && perm[d+1] == d + 1
d += 1
end
if d == ndims(src)
copyto!(parent(P), src) # it's not permuted
else
R1 = CartesianIndices(axes(src)[1:d])
d1 = findfirst(isequal(d + 1), perm)::Int # first permuted dim of dest
R2 = CartesianIndices(axes(src)[(d+2):(d1-1)])
R3 = CartesianIndices(axes(src)[(d1+1):end])
_permutedims!(P, src, R1, R2, R3, d + 1, d1)
end
return P
# If dest/src are "close to dense," then it pays to be cache-friendly.
# Determine the first permuted dimension
d = 0 # d+1 will hold the first permuted dimension of src
while d < ndims(src) && perm[d + 1] == d + 1
d += 1
end
if d == ndims(src)
copyto!(parent(P), src) # it's not permuted
else
R1 = CartesianIndices(axes(src)[1:d])
d1 = findfirst(isequal(d + 1), perm)::Int # first permuted dim of dest
R2 = CartesianIndices(axes(src)[(d + 2):(d1 - 1)])
R3 = CartesianIndices(axes(src)[(d1 + 1):end])
_permutedims!(P, src, R1, R2, R3, d + 1, d1)
end
return P
end

@noinline function _permutedims!(
P::NestedPermutedDimsArray,
src,
R1::CartesianIndices{0},
R2,
R3,
ds,
dp,
P::NestedPermutedDimsArray, src, R1::CartesianIndices{0}, R2, R3, ds, dp
)
ip, is = axes(src, dp), axes(src, ds)
for jo = first(ip):8:last(ip), io = first(is):8:last(is)
for I3 in R3, I2 in R2
for j = jo:min(jo + 7, last(ip))
for i = io:min(io + 7, last(is))
@inbounds P[i, I2, j, I3] = src[i, I2, j, I3]
end
end
ip, is = axes(src, dp), axes(src, ds)
for jo in first(ip):8:last(ip), io in first(is):8:last(is)
for I3 in R3, I2 in R2
for j in jo:min(jo + 7, last(ip))
for i in io:min(io + 7, last(is))
@inbounds P[i, I2, j, I3] = src[i, I2, j, I3]
end
end
end
return P
end
return P
end

@noinline function _permutedims!(P::NestedPermutedDimsArray, src, R1, R2, R3, ds, dp)
ip, is = axes(src, dp), axes(src, ds)
for jo = first(ip):8:last(ip), io = first(is):8:last(is)
for I3 in R3, I2 in R2
for j = jo:min(jo + 7, last(ip))
for i = io:min(io + 7, last(is))
for I1 in R1
@inbounds P[I1, i, I2, j, I3] = src[I1, i, I2, j, I3]
end
end
end
ip, is = axes(src, dp), axes(src, ds)
for jo in first(ip):8:last(ip), io in first(is):8:last(is)
for I3 in R3, I2 in R2
for j in jo:min(jo + 7, last(ip))
for i in io:min(io + 7, last(is))
for I1 in R1
@inbounds P[I1, i, I2, j, I3] = src[I1, i, I2, j, I3]
end
end
end
end
return P
end
return P
end

const CommutativeOps = Union{
typeof(+),
typeof(Base.add_sum),
typeof(min),
typeof(max),
typeof(Base._extrema_rf),
typeof(|),
typeof(&),
typeof(+),
typeof(Base.add_sum),
typeof(min),
typeof(max),
typeof(Base._extrema_rf),
typeof(|),
typeof(&),
}

function Base._mapreduce_dim(
f,
op::CommutativeOps,
init::Base._InitialValue,
A::NestedPermutedDimsArray,
dims::Colon,
f, op::CommutativeOps, init::Base._InitialValue, A::NestedPermutedDimsArray, dims::Colon
)
return Base._mapreduce_dim(f, op, init, parent(A), dims)
return Base._mapreduce_dim(f, op, init, parent(A), dims)
end
function Base._mapreduce_dim(
f::typeof(identity),
op::Union{typeof(Base.mul_prod),typeof(*)},
init::Base._InitialValue,
A::NestedPermutedDimsArray{<:Union{Real,Complex}},
dims::Colon,
f::typeof(identity),
op::Union{typeof(Base.mul_prod),typeof(*)},
init::Base._InitialValue,
A::NestedPermutedDimsArray{<:Union{Real,Complex}},
dims::Colon,
)
return Base._mapreduce_dim(f, op, init, parent(A), dims)
return Base._mapreduce_dim(f, op, init, parent(A), dims)
end

function Base.mapreducedim!(
f,
op::CommutativeOps,
B::AbstractArray{T,N},
A::NestedPermutedDimsArray{S,N,perm,iperm},
f, op::CommutativeOps, B::AbstractArray{T,N}, A::NestedPermutedDimsArray{S,N,perm,iperm}
) where {T,S,N,perm,iperm}
C = NestedPermutedDimsArray{T,N,iperm,perm,typeof(B)}(B) # make the inverse permutation for the output
Base.mapreducedim!(f, op, C, parent(A))
return B
C = NestedPermutedDimsArray{T,N,iperm,perm,typeof(B)}(B) # make the inverse permutation for the output
Base.mapreducedim!(f, op, C, parent(A))
return B
end
function Base.mapreducedim!(
f::typeof(identity),
op::Union{typeof(Base.mul_prod),typeof(*)},
B::AbstractArray{T,N},
A::NestedPermutedDimsArray{<:Union{Real,Complex},N,perm,iperm},
f::typeof(identity),
op::Union{typeof(Base.mul_prod),typeof(*)},
B::AbstractArray{T,N},
A::NestedPermutedDimsArray{<:Union{Real,Complex},N,perm,iperm},
) where {T,N,perm,iperm}
C = NestedPermutedDimsArray{T,N,iperm,perm,typeof(B)}(B) # make the inverse permutation for the output
Base.mapreducedim!(f, op, C, parent(A))
return B
C = NestedPermutedDimsArray{T,N,iperm,perm,typeof(B)}(B) # make the inverse permutation for the output
Base.mapreducedim!(f, op, C, parent(A))
return B
end

function Base.showarg(
io::IO,
A::NestedPermutedDimsArray{T,N,perm},
toplevel,
io::IO, A::NestedPermutedDimsArray{T,N,perm}, toplevel
) where {T,N,perm}
print(io, "NestedPermutedDimsArray(")
Base.showarg(io, parent(A), false)
print(io, ", ", perm, ')')
toplevel && print(io, " with eltype ", eltype(A))
return nothing
print(io, "NestedPermutedDimsArray(")
Base.showarg(io, parent(A), false)
print(io, ", ", perm, ')')
toplevel && print(io, " with eltype ", eltype(A))
return nothing
end

end
36 changes: 16 additions & 20 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,22 @@
using NestedPermutedDimsArrays: NestedPermutedDimsArray
using Test: @test, @testset
@testset "NestedPermutedDimsArrays" for elt in (
Float32,
Float64,
Complex{Float32},
Complex{Float64},
Float32, Float64, Complex{Float32}, Complex{Float64}
)
a = map(_ -> randn(elt, 2, 3, 4), CartesianIndices((2, 3, 4)))
perm = (3, 1, 2)
p = NestedPermutedDimsArray(a, perm)
T = PermutedDimsArray{elt,3,perm,invperm(perm),eltype(a)}
@test typeof(p) === NestedPermutedDimsArray{T,3,perm,invperm(perm),typeof(a)}
@test size(p) == (4, 2, 3)
@test eltype(p) === T
for I in eachindex(p)
@test size(p[I]) == (4, 2, 3)
@test p[I] ==
permutedims(a[CartesianIndex(map(i -> Tuple(I)[i], invperm(perm)))], perm)
end
x = randn(elt, 4, 2, 3)
p[3, 1, 2] = x
@test p[3, 1, 2] == x
@test a[1, 2, 3] == permutedims(x, invperm(perm))
a = map(_ -> randn(elt, 2, 3, 4), CartesianIndices((2, 3, 4)))
perm = (3, 1, 2)
p = NestedPermutedDimsArray(a, perm)
T = PermutedDimsArray{elt,3,perm,invperm(perm),eltype(a)}
@test typeof(p) === NestedPermutedDimsArray{T,3,perm,invperm(perm),typeof(a)}
@test size(p) == (4, 2, 3)
@test eltype(p) === T
for I in eachindex(p)
@test size(p[I]) == (4, 2, 3)
@test p[I] == permutedims(a[CartesianIndex(map(i -> Tuple(I)[i], invperm(perm)))], perm)
end
x = randn(elt, 4, 2, 3)
p[3, 1, 2] = x
@test p[3, 1, 2] == x
@test a[1, 2, 3] == permutedims(x, invperm(perm))
end
end

0 comments on commit 8a22fe1

Please sign in to comment.