Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimizations on to_indices #227

Merged
merged 13 commits into from
Nov 28, 2021
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ArrayInterface"
uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
version = "3.1.40"
version = "3.2"

[deps]
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Expand Down
4 changes: 2 additions & 2 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@ ArrayInterface.fast_scalar_indexing
ArrayInterface.has_dimnames
ArrayInterface.has_parent
ArrayInterface.has_sparsestruct
ArrayInterface.is_canonical
ArrayInterface.is_column_major
ArrayInterface.is_lazy_conjugate
ArrayInterface.ismutable
ArrayInterface.issingular
ArrayInterface.isstructured
ArrayInterface.is_splat_index
ArrayInterface.known_first
ArrayInterface.known_last
ArrayInterface.known_length
Expand All @@ -31,6 +31,7 @@ ArrayInterface.known_offsets
ArrayInterface.known_size
ArrayInterface.known_step
ArrayInterface.known_strides
ArrayInterface.ndims_index
```

## Functions
Expand All @@ -43,7 +44,6 @@ ArrayInterface.axes
ArrayInterface.axes_types
ArrayInterface.broadcast_axis
ArrayInterface.buffer
ArrayInterface.canonicalize
ArrayInterface.deleteat
ArrayInterface.dense_dims
ArrayInterface.findstructralnz
Expand Down
17 changes: 17 additions & 0 deletions docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ end
```

Most traits in `ArrayInterface` are a variant on this pattern.
If the trait in question may be altered by a wrapper array, this pattern should be altered or may be inappropriate.

## Static Traits

Expand Down Expand Up @@ -174,3 +175,19 @@ Defining these two methods ensures that other array types that wrap `OffsetArray
It is entirely optional to define `ArrayInterface.size` for `OffsetArray` because the size can be derived from the axes.
However, in this particularly case we should also define
`ArrayInterface.size(A::OffsetArray) = ArrayInterface.size(parent(A))` because the relative offsets attached to `OffsetArray` do not change the size but may hide static sizes if using a relative offset that is defined with an `Int`.

## Processing Indices (`to_indices`)

For most users, the only reason you should use `ArrayInterface.to_indices` over `Base.to_indices` is that it's faster and perhaps some of the more detailed benefits described in the [`to_indices`](@ref) doc string.
For those interested in how this is accomplished, the following steps (beginning with the `to_indices(A::AbstractArray, I::Tuple)`) are used to accomplish this:

1. The number of dimensions that each indexing argument in `I` corresponds to is determined using using the [`ndims_index`](@ref) and [`is_splat_index`](@ref) traits.
2. A non-allocating reference to each axis of `A` is created (`lazy_axes(A) -> axs`). These are aligned to each the index arguments using information from the first step. For example, if an index argument maps to a single dimension then it is paired with `axs[dim]`. In the case of multiple dimensions it is paired with `CartesianIndices(axs[dim_1], ... axs[dim_n])`. These pairs are further processed using `to_index(axis, I[n])`.
3. Tuples returned from `to_index` are flattened out so that there are no nested tuples returned from `to_indices`.

Entry points:

* `to_indices(::ArrayType, indices)` : dispatch on unique array type `ArrayType`
* `to_index(axis, ::IndexType)` : dispatch on a unique indexing type, `IndexType`. `ArrayInterface.ndims_index(::Type{IndexType})` should also be defined in this case.
* `to_index(S::IndexStyle, axis, index)` : The index style `S` that corresponds to `axis`. This is

10 changes: 4 additions & 6 deletions src/ArrayInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ using Base: @propagate_inbounds, tail, OneTo, LogicalIndex, Slice, ReinterpretAr
ReshapedArray, AbstractCartesianIndex

const CanonicalInt = Union{Int,StaticInt}
canonicalize(x::Integer) = Int(x)
canonicalize(@nospecialize(x::StaticInt)) = x

@static if isdefined(Base, :ReshapedReinterpretArray)
_is_reshaped(::Type{<:Base.ReshapedReinterpretArray}) = true
Expand All @@ -29,8 +31,6 @@ parameterless_type(x::Type) = __parameterless_type(x)

const VecAdjTrans{T,V<:AbstractVector{T}} = Union{Transpose{T,V},Adjoint{T,V}}
const MatAdjTrans{T,M<:AbstractMatrix{T}} = Union{Transpose{T,M},Adjoint{T,M}}
const UpTri{T,M} = Union{UpperTriangular{T,M},UnitUpperTriangular{T,M}}
const LoTri{T,M} = Union{LowerTriangular{T,M},UnitLowerTriangular{T,M}}

@inline static_length(a::UnitRange{T}) where {T} = last(a) - first(a) + oneunit(T)
@inline static_length(x) = Static.maybe_static(known_length, length, x)
Expand All @@ -55,15 +55,13 @@ parent_type(::Type{<:LinearAlgebra.AbstractTriangular{T,S}}) where {T,S} = S
parent_type(::Type{<:PermutedDimsArray{T,N,I1,I2,A}}) where {T,N,I1,I2,A} = A
parent_type(::Type{Slice{T}}) where {T} = T
parent_type(::Type{T}) where {T} = T
parent_type(::Type{R}) where {S,T,A,N,R<:Base.ReinterpretArray{T,N,S,A}} = A
parent_type(::Type{LoTri{T,M}}) where {T,M} = M
parent_type(::Type{UpTri{T,M}}) where {T,M} = M
parent_type(::Type{R}) where {S,T,A,N,R<:ReinterpretArray{T,N,S,A}} = A
parent_type(::Type{Diagonal{T,V}}) where {T,V} = V

"""
has_parent(::Type{T}) -> StaticBool

Returns `True` if `parent_type(T)` a type unique to `T`.
Returns `static(true)` if `parent_type(T)` a type unique to `T`.
"""
has_parent(x) = has_parent(typeof(x))
has_parent(::Type{T}) where {T} = _has_parent(parent_type(T), T)
Expand Down
17 changes: 2 additions & 15 deletions src/dimensions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,20 +22,6 @@ function is_increasing(perm::Tuple{StaticInt{X},StaticInt{Y}}) where {X, Y}
end
is_increasing(::Tuple{StaticInt{X}}) where {X} = True()

#=
ndims_index(::Type{I})::StaticInt

The number of dimensions an instance of `I` maps to when indexing an instance of `A`.
=#
ndims_index(i) = ndims_index(typeof(i))
ndims_index(::Type{I}) where {I} = static(1)
ndims_index(::Type{I}) where {N,I<:AbstractCartesianIndex{N}} = static(N)
ndims_index(::Type{I}) where {I<:AbstractArray} = ndims_index(eltype(I))
ndims_index(::Type{I}) where {I<:AbstractArray{Bool}} = static(ndims(I))
ndims_index(::Type{I}) where {N,I<:LogicalIndex{<:Any,<:AbstractArray{Bool,N}}} = static(N)
_ndims_index(::Type{I}, i::StaticInt) where {I} = ndims_index(_get_tuple(I, i))
ndims_index(::Type{I}) where {N,I<:Tuple{Vararg{Any,N}}} = eachop(_ndims_index, nstatic(Val(N)), I)

"""
from_parent_dims(::Type{T}) -> Tuple{Vararg{Union{Int,StaticInt}}}
from_parent_dims(::Type{T}, dim) -> Union{Int,StaticInt}
Expand Down Expand Up @@ -191,7 +177,8 @@ end
This returns the dimension(s) of `x` corresponding to `d`.
"""
to_dims(x, dim) = to_dims(typeof(x), dim)
to_dims(::Type{T}, dim::Integer) where {T} = canonicalize(dim)
to_dims(::Type{T}, dim::StaticInt) where {T} = dim
to_dims(::Type{T}, dim::Integer) where {T} = Int(dim)
to_dims(::Type{T}, dim::Colon) where {T} = dim
function to_dims(::Type{T}, dim::StaticSymbol) where {T}
i = find_first_eq(dim, dimnames(T))
Expand Down
Loading