Skip to content

Commit

Permalink
Merge #1459
Browse files Browse the repository at this point in the history
1459: Use fallback for reshape/cat OneHotArray r=DhairyaLGandhi a=darsnack

This falls back to reshaping a `Bool` array whenever reshaping the first dimension of a `OneHotArray`.

@DhairyaLGandhi @CarloLucibello @simeonschaub 

### PR Checklist

- [x] Tests are added
- [ ] ~~Entry in NEWS.md~~
- [x] Documentation, if applicable


Co-authored-by: Kyle Daruwalla <[email protected]>
Co-authored-by: Kyle Daruwalla <[email protected]>
  • Loading branch information
3 people authored Jan 29, 2021
2 parents fdf7152 + d27139f commit 8f79161
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 28 deletions.
50 changes: 32 additions & 18 deletions src/onehot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,24 @@ OneHotArray(indices::T, L::Integer) where {T<:Integer} = OneHotArray{T, L, 0, T}
OneHotArray(indices::AbstractArray{T, N}, L::Integer) where {T, N} = OneHotArray{T, L, N, typeof(indices)}(indices)

_indices(x::OneHotArray) = x.indices
_indices(x::Base.ReshapedArray{<:Any, <:Any, <:OneHotArray}) =
reshape(parent(x).indices, x.dims[2:end])

const OneHotVector{T, L} = OneHotArray{T, L, 0, 1, T}
const OneHotMatrix{T, L, I} = OneHotArray{T, L, 1, 2, I}

OneHotVector(idx, L) = OneHotArray(idx, L)
OneHotMatrix(indices, L) = OneHotArray(indices, L)

# use this type so reshaped arrays hit fast paths
# e.g. argmax
const OneHotLike{T, L, N, var"N+1", I} =
Union{OneHotArray{T, L, N, var"N+1", I},
Base.ReshapedArray{Bool, var"N+1", <:OneHotArray{T, L, <:Any, <:Any, I}}}

_isonehot(x::OneHotArray) = true
_isonehot(x::Base.ReshapedArray{<:Any, <:Any, <:OneHotArray{<:Any, L}}) where L = (size(x, 1) == L)

Base.size(x::OneHotArray{<:Any, L}) where L = (Int(L), size(x.indices)...)

_onehotindex(x, i) = (x == i)
Expand All @@ -28,34 +39,30 @@ Base.getindex(x::OneHotArray{<:Any, L}, ::Colon, I...) where L = OneHotArray(x.i
Base.getindex(x::OneHotArray{<:Any, <:Any, <:Any, N}, ::Vararg{Colon, N}) where N = x
Base.getindex(x::OneHotArray, I::CartesianIndex{N}) where N = x[I[1], Tuple(I)[2:N]...]

_onehot_bool_type(x::OneHotArray{<:Any, <:Any, <:Any, N, <:Union{Integer, AbstractArray}}) where N = Array{Bool, N}
_onehot_bool_type(x::OneHotArray{<:Any, <:Any, <:Any, N, <:CuArray}) where N = CuArray{Bool, N}
_onehot_bool_type(x::OneHotLike{<:Any, <:Any, <:Any, N, <:Union{Integer, AbstractArray}}) where N = Array{Bool, N}
_onehot_bool_type(x::OneHotLike{<:Any, <:Any, <:Any, N, <:CuArray}) where N = CuArray{Bool, N}

function Base.cat(xs::OneHotArray{<:Any, L}...; dims::Int) where L
if isone(dims)
return throw(ArgumentError("Cannot concat OneHotArray along first dimension. Use collect to convert to Bool array first."))
function Base.cat(xs::OneHotLike{<:Any, L}...; dims::Int) where L
if isone(dims) || any(x -> !_isonehot(x), xs)
return cat(map(x -> convert(_onehot_bool_type(x), x), xs)...; dims = dims)
else
return OneHotArray(cat(_indices.(xs)...; dims = dims - 1), L)
end
end

Base.hcat(xs::OneHotArray...) = cat(xs...; dims = 2)
Base.vcat(xs::OneHotArray...) = cat(xs...; dims = 1)

Base.reshape(x::OneHotArray{<:Any, L}, dims::Dims) where L =
(first(dims) == L) ? OneHotArray(reshape(x.indices, dims[2:end]...), L) :
throw(ArgumentError("Cannot reshape OneHotArray if first(dims) != size(x, 1)"))
Base._reshape(x::OneHotArray, dims::Tuple{Vararg{Int}}) = reshape(x, dims)
Base.hcat(xs::OneHotLike...) = cat(xs...; dims = 2)
Base.vcat(xs::OneHotLike...) = cat(xs...; dims = 1)

batch(xs::AbstractArray{<:OneHotVector{<:Any, L}}) where L = OneHotArray(_indices.(xs), L)

Adapt.adapt_structure(T, x::OneHotArray{<:Any, L}) where L = OneHotArray(adapt(T, x.indices), L)
Adapt.adapt_structure(T, x::OneHotArray{<:Any, L}) where L = OneHotArray(adapt(T, _indices(x)), L)

Base.BroadcastStyle(::Type{<:OneHotArray{<:Any, <:Any, <:Any, N, <:CuArray}}) where N = CUDA.CuArrayStyle{N}()

Base.argmax(x::OneHotArray; dims = Colon()) =
(dims == 1) ? reshape(CartesianIndex.(x.indices, CartesianIndices(x.indices)), 1, size(x.indices)...) :
argmax(convert(_onehot_bool_type(x), x); dims = dims)
Base.argmax(x::OneHotLike; dims = Colon()) =
(_isonehot(x) && dims == 1) ?
reshape(CartesianIndex.(_indices(x), CartesianIndices(_indices(x))), 1, size(_indices(x))...) :
invoke(argmax, Tuple{AbstractArray}, x; dims = dims)

"""
onehot(l, labels[, unk])
Expand Down Expand Up @@ -135,11 +142,18 @@ function onecold(y::AbstractArray, labels = 1:size(y, 1))
end

_fast_argmax(x::AbstractArray) = dropdims(argmax(x; dims = 1); dims = 1)
_fast_argmax(x::OneHotArray) = x.indices
function _fast_argmax(x::OneHotLike)
if _isonehot(x)
return _indices(x)
else
return _fast_argmax(convert(_onehot_bool_type(x), x))
end
end

@nograd OneHotArray, onecold, onehot, onehotbatch

function Base.:(*)(A::AbstractMatrix, B::OneHotArray{<:Any, L}) where L
function Base.:(*)(A::AbstractMatrix, B::OneHotLike{<:Any, L}) where L
_isonehot(B) || return invoke(*, Tuple{AbstractMatrix, AbstractMatrix}, A, B)
size(A, 2) == L || throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(A, 2)) != $L"))
return A[:, onecold(B)]
end
32 changes: 22 additions & 10 deletions test/onehot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ end
end

@testset "OneHotArray" begin
using Flux: OneHotArray, OneHotVector, OneHotMatrix
using Flux: OneHotArray, OneHotVector, OneHotMatrix, OneHotLike

ov = OneHotVector(rand(1:10), 10)
om = OneHotMatrix(rand(1:10, 5), 10)
Expand Down Expand Up @@ -74,27 +74,39 @@ end
@testset "Concatenating" begin
# vector cat
@test hcat(ov, ov) == OneHotMatrix(vcat(ov.indices, ov.indices), 10)
@test_throws ArgumentError vcat(ov, ov)
@test vcat(ov, ov) == vcat(collect(ov), collect(ov))
@test cat(ov, ov; dims = 3) == OneHotArray(cat(ov.indices, ov.indices; dims = 2), 10)

# matrix cat
@test hcat(om, om) == OneHotMatrix(vcat(om.indices, om.indices), 10)
@test_throws ArgumentError vcat(om, om)
@test vcat(om, om) == vcat(collect(om), collect(om))
@test cat(om, om; dims = 3) == OneHotArray(cat(om.indices, om.indices; dims = 2), 10)

# array cat
@test cat(oa, oa; dims = 3) == OneHotArray(cat(oa.indices, oa.indices; dims = 2), 10)
@test_throws ArgumentError cat(oa, oa; dims = 1)
@test cat(oa, oa; dims = 1) == cat(collect(oa), collect(oa); dims = 1)
end

@testset "Base.reshape" begin
# reshape test
@test reshape(oa, 10, 25) isa OneHotArray
@test reshape(oa, 10, :) isa OneHotArray
@test reshape(oa, :, 25) isa OneHotArray
@test_throws ArgumentError reshape(oa, 50, :)
@test_throws ArgumentError reshape(oa, 5, 10, 5)
@test reshape(oa, (10, 25)) isa OneHotArray
@test reshape(oa, 10, 25) isa OneHotLike
@test reshape(oa, 10, :) isa OneHotLike
@test reshape(oa, :, 25) isa OneHotLike
@test reshape(oa, 50, :) isa OneHotLike
@test reshape(oa, 5, 10, 5) isa OneHotLike
@test reshape(oa, (10, 25)) isa OneHotLike

@testset "w/ cat" begin
r = reshape(oa, 10, :)
@test hcat(r, r) isa OneHotArray
@test vcat(r, r) isa Array{Bool}
end

@testset "w/ argmax" begin
r = reshape(oa, 10, :)
@test argmax(r) == argmax(OneHotMatrix(reshape(oa.indices, :), 10))
@test Flux._fast_argmax(r) == collect(reshape(oa.indices, :))
end
end

@testset "Base.argmax" begin
Expand Down

0 comments on commit 8f79161

Please sign in to comment.