Use fallback for reshape/cat OneHotArray #1459

Merged 5 commits on Jan 29, 2021
Changes from 3 commits
48 changes: 30 additions & 18 deletions src/onehot.jl
@@ -9,13 +9,30 @@ OneHotArray(indices::T, L::Integer) where {T<:Integer} = OneHotArray{T, L, 0, T}
OneHotArray(indices::AbstractArray{T, N}, L::Integer) where {T, N} = OneHotArray{T, L, N, typeof(indices)}(indices)

_indices(x::OneHotArray) = x.indices
+_indices(x::Base.ReshapedArray{<:Any, <:Any, <:OneHotArray}) =
+  reshape(parent(x).indices, x.dims[2:end])

const OneHotVector{T, L} = OneHotArray{T, L, 0, 1, T}
const OneHotMatrix{T, L, I} = OneHotArray{T, L, 1, 2, I}

OneHotVector(idx, L) = OneHotArray(idx, L)
OneHotMatrix(indices, L) = OneHotArray(indices, L)

+# use this type so reshaped arrays hit fast paths
+# e.g. argmax
+const OneHotLike{T, L, N, var"N+1", I} =
+  Union{OneHotArray{T, L, N, var"N+1", I},
Member
Man, this N+1 is tripping me up. I would say we need to remove this soon. Where is it used exactly?

Member
Do you think we could calculate var"N+1" during runtime?

Member Author
I don't like it either! It can't be done at runtime since N and var"N+1" are used in the type specification. N is used to specify the type of the index array, and var"N+1" is used to inherit from AbstractArray{Bool, var"N+1"}. Neither is evaluated at runtime.
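
A minimal sketch of the constraint described above, assuming a simplified stand-in type (`MiniOneHot` and `Np1` are illustrative names, not Flux's definitions). The supertype `AbstractArray{Bool, Np1}` has to be written in terms of a type parameter, so the `N + 1` arithmetic can only happen in a constructor:

```julia
# Hypothetical stand-in for OneHotArray; all names here are illustrative.
# Writing `<: AbstractArray{Bool, N + 1}` directly is rejected, because `N`
# is still an unbound type variable when the supertype is evaluated.
struct MiniOneHot{L, N, Np1, I<:AbstractArray{<:Integer, N}} <: AbstractArray{Bool, Np1}
    indices::I
end

# The outer constructor computes N + 1 once, when the array is built.
MiniOneHot(indices::AbstractArray{<:Integer, N}, L::Integer) where {N} =
    MiniOneHot{Int(L), N, N + 1, typeof(indices)}(indices)

# Minimal AbstractArray interface: size and scalar indexing.
Base.size(x::MiniOneHot{L}) where {L} = (L, size(x.indices)...)
Base.getindex(x::MiniOneHot{L, N, Np1}, i::Vararg{Int, Np1}) where {L, N, Np1} =
    x.indices[Base.tail(i)...] == i[1]

x = MiniOneHot([1, 3, 2], 3)   # three labels out of L = 3 classes
dense = collect(x)             # materialise as a dense Bool matrix
```

Here the extra parameter only carries `N + 1` into the supertype; dispatch on `AbstractMatrix{Bool}` and friends then works without any runtime check.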

Member Author
We could change it to another variable. I don't have strong feelings, but a part of me says that at least this naming signals the intent of the type parameter.

Member
To be fair, I did mean we would have to switch it out during construction, because I don't think it's any better for dispatch to have to do checks on ints than types. To me it suggests that it is a preknown quantity so adding it to the type doesn't win us much.

+        Base.ReshapedArray{Bool, var"N+1", <:OneHotArray{T, L}}}
+
+# when reshaping a OneHotArray and first(dims) != L
+# convert the parent array to Array{Bool}
+# so that the ReshapedArray does not hit fast paths
+function Base.ReshapedArray(parent::OneHotArray{<:Any, L}, dims::NTuple{N,Int}, mi) where {L, N}
+  parent = (first(dims) != L) ? convert(_onehot_bool_type(parent), parent) : parent
+
+  Base.ReshapedArray{Bool,N,typeof(parent),typeof(mi)}(parent, dims, mi)
+end
+
Base.size(x::OneHotArray{<:Any, L}) where L = (Int(L), size(x.indices)...)

_onehotindex(x, i) = (x == i)
@@ -24,37 +41,32 @@ Base.getindex(x::OneHotVector, i::Integer) = _onehotindex(x.indices, i)
Base.getindex(x::OneHotVector{T, L}, ::Colon) where {T, L} = x

Base.getindex(x::OneHotArray, i::Integer, I...) = _onehotindex.(x.indices[I...], i)
-Base.getindex(x::OneHotArray{<:Any, L}, ::Colon, I...) where L = OneHotArray(x.indices[I...], L)
+Base.getindex(x::OneHotLike{<:Any, L}, ::Colon, I...) where L = OneHotArray(_indices(x)[I...], L)
Base.getindex(x::OneHotArray{<:Any, <:Any, <:Any, N}, ::Vararg{Colon, N}) where N = x
Base.getindex(x::OneHotArray, I::CartesianIndex{N}) where N = x[I[1], Tuple(I)[2:N]...]

-_onehot_bool_type(x::OneHotArray{<:Any, <:Any, <:Any, N, <:Union{Integer, AbstractArray}}) where N = Array{Bool, N}
-_onehot_bool_type(x::OneHotArray{<:Any, <:Any, <:Any, N, <:CuArray}) where N = CuArray{Bool, N}
+_onehot_bool_type(x::OneHotLike{<:Any, <:Any, <:Any, N, <:Union{Integer, AbstractArray}}) where N = Array{Bool, N}
+_onehot_bool_type(x::OneHotLike{<:Any, <:Any, <:Any, N, <:CuArray}) where N = CuArray{Bool, N}

-function Base.cat(xs::OneHotArray{<:Any, L}...; dims::Int) where L
+function Base.cat(xs::OneHotLike{<:Any, L}...; dims::Int) where L
   if isone(dims)
-    return throw(ArgumentError("Cannot concat OneHotArray along first dimension. Use collect to convert to Bool array first."))
+    return cat(map(x -> convert(_onehot_bool_type(x), x), xs)...; dims = dims)
   else
     return OneHotArray(cat(_indices.(xs)...; dims = dims - 1), L)
   end
 end

-Base.hcat(xs::OneHotArray...) = cat(xs...; dims = 2)
-Base.vcat(xs::OneHotArray...) = cat(xs...; dims = 1)
-
-Base.reshape(x::OneHotArray{<:Any, L}, dims::Dims) where L =
-  (first(dims) == L) ? OneHotArray(reshape(x.indices, dims[2:end]...), L) :
-                       throw(ArgumentError("Cannot reshape OneHotArray if first(dims) != size(x, 1)"))
-Base._reshape(x::OneHotArray, dims::Tuple{Vararg{Int}}) = reshape(x, dims)
+Base.hcat(xs::OneHotLike...) = cat(xs...; dims = 2)
+Base.vcat(xs::OneHotLike...) = cat(xs...; dims = 1)

batch(xs::AbstractArray{<:OneHotVector{<:Any, L}}) where L = OneHotArray(_indices.(xs), L)

-Adapt.adapt_structure(T, x::OneHotArray{<:Any, L}) where L = OneHotArray(adapt(T, x.indices), L)
+Adapt.adapt_structure(T, x::OneHotLike{<:Any, L}) where L = OneHotArray(adapt(T, _indices(x)), L)

-Base.BroadcastStyle(::Type{<:OneHotArray{<:Any, <:Any, <:Any, N, <:CuArray}}) where N = CUDA.CuArrayStyle{N}()
+Base.BroadcastStyle(::Type{<:OneHotLike{<:Any, <:Any, <:Any, N, <:CuArray}}) where N = CUDA.CuArrayStyle{N}()

-Base.argmax(x::OneHotArray; dims = Colon()) =
-  (dims == 1) ? reshape(CartesianIndex.(x.indices, CartesianIndices(x.indices)), 1, size(x.indices)...) :
+Base.argmax(x::OneHotLike; dims = Colon()) =
+  (dims == 1) ? reshape(CartesianIndex.(_indices(x), CartesianIndices(_indices(x))), 1, size(_indices(x))...) :
                 argmax(convert(_onehot_bool_type(x), x); dims = dims)

"""
@@ -135,11 +147,11 @@ function onecold(y::AbstractArray, labels = 1:size(y, 1))
end

_fast_argmax(x::AbstractArray) = dropdims(argmax(x; dims = 1); dims = 1)
-_fast_argmax(x::OneHotArray) = x.indices
+_fast_argmax(x::OneHotLike) = _indices(x)

@nograd OneHotArray, onecold, onehot, onehotbatch

-function Base.:(*)(A::AbstractMatrix, B::OneHotArray{<:Any, L}) where L
+function Base.:(*)(A::AbstractMatrix, B::OneHotLike{<:Any, L}) where L
   size(A, 2) == L || throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(A, 2)) != $L"))
   return A[:, onecold(B)]
 end
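
A rough illustration of why the `Base.cat` change in this file treats `dims = 1` differently, written against plain `Bool` arrays rather than Flux types (`dense_onehot` is a hypothetical helper introduced only for this sketch). Concatenating along a later dimension is equivalent to concatenating the index arrays, while concatenating along the first dimension yields columns that are no longer one-hot, so a dense fallback is the only faithful result:

```julia
# Hypothetical helper: dense L×n Bool matrix with one `true` per column.
dense_onehot(indices::AbstractVector{<:Integer}, L::Integer) =
    [i == idx for i in 1:L, idx in indices]

ia, ib = [1, 3], [2, 2]
a = dense_onehot(ia, 3)
b = dense_onehot(ib, 3)

# dims = 2: same result as concatenating the indices along dims - 1 = 1
same_as_index_cat = hcat(a, b) == dense_onehot(vcat(ia, ib), 3)

# dims = 1: every column of the stacked array now holds two `true` entries,
# so it cannot be described by a single index per column
stacked = vcat(a, b)
trues_per_column = vec(sum(stacked; dims = 1))
```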
20 changes: 10 additions & 10 deletions test/onehot.jl
@@ -35,7 +35,7 @@ end
end

@testset "OneHotArray" begin
-using Flux: OneHotArray, OneHotVector, OneHotMatrix
+using Flux: OneHotArray, OneHotVector, OneHotMatrix, OneHotLike

ov = OneHotVector(rand(1:10), 10)
om = OneHotMatrix(rand(1:10, 5), 10)
@@ -74,27 +74,27 @@ end
@testset "Concatenating" begin
# vector cat
@test hcat(ov, ov) == OneHotMatrix(vcat(ov.indices, ov.indices), 10)
-@test_throws ArgumentError vcat(ov, ov)
+@test vcat(ov, ov) == vcat(collect(ov), collect(ov))
@test cat(ov, ov; dims = 3) == OneHotArray(cat(ov.indices, ov.indices; dims = 2), 10)

# matrix cat
@test hcat(om, om) == OneHotMatrix(vcat(om.indices, om.indices), 10)
-@test_throws ArgumentError vcat(om, om)
+@test vcat(om, om) == vcat(collect(om), collect(om))
@test cat(om, om; dims = 3) == OneHotArray(cat(om.indices, om.indices; dims = 2), 10)

# array cat
@test cat(oa, oa; dims = 3) == OneHotArray(cat(oa.indices, oa.indices; dims = 2), 10)
-@test_throws ArgumentError cat(oa, oa; dims = 1)
+@test cat(oa, oa; dims = 1) == cat(collect(oa), collect(oa); dims = 1)
end

@testset "Base.reshape" begin
# reshape test
-@test reshape(oa, 10, 25) isa OneHotArray
-@test reshape(oa, 10, :) isa OneHotArray
-@test reshape(oa, :, 25) isa OneHotArray
-@test_throws ArgumentError reshape(oa, 50, :)
-@test_throws ArgumentError reshape(oa, 5, 10, 5)
-@test reshape(oa, (10, 25)) isa OneHotArray
+@test reshape(oa, 10, 25) isa OneHotLike
+@test reshape(oa, 10, :) isa OneHotLike
+@test reshape(oa, :, 25) isa OneHotLike
+@test reshape(oa, 50, :) isa Base.ReshapedArray{<:Any, <:Any, <:Array}
+@test reshape(oa, 5, 10, 5) isa Base.ReshapedArray{<:Any, <:Any, <:Array}
+@test reshape(oa, (10, 25)) isa OneHotLike
end
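
The reshape tests above rely on one property that can be checked with plain `Bool` arrays; the following is a sketch under that assumption and does not use Flux (the variable names are illustrative). When the first dimension stays equal to `L`, reshaping the dense array matches reshaping the indices; otherwise the columns stop being one-hot, which is why the parent is converted to a `Bool` array in that case:

```julia
L = 3
indices = [1 3; 2 2]    # 2×2 array of class labels
# Dense 3×2×2 Bool array with one `true` along the first dimension.
dense = [i == indices[j, k] for i in 1:L, j in 1:2, k in 1:2]

# first(dims) == L: equivalent to reshaping the index array
r = reshape(dense, L, 4)
keeps_onehot = r == [i == idx for i in 1:L, idx in reshape(indices, 4)]

# first(dims) != L: each column now spans two one-hot vectors
r2 = reshape(dense, 6, 2)
per_column = vec(sum(r2; dims = 1))
```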

@testset "Base.argmax" begin