From 32e06c83ecffe6a382c58f728b72a468b4f8d1e1 Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Tue, 27 Dec 2022 16:15:53 -0500
Subject: [PATCH] Fast path `onehotbatch(::Vector{Int}, ::UnitRange)` (#27)

* add a fast path

* add an error check

* fixup, add tests

* fix 1.6
---
 Project.toml   |  2 +-
 src/onehot.jl  | 14 ++++++++++++++
 test/gpu.jl    | 10 ++++++++++
 test/onehot.jl |  7 +++++++
 4 files changed, 32 insertions(+), 1 deletion(-)

diff --git a/Project.toml b/Project.toml
index 29a5818..94da4fc 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,6 +1,6 @@
 name = "OneHotArrays"
 uuid = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f"
-version = "0.2.1"
+version = "0.2.2"
 
 [deps]
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
diff --git a/src/onehot.jl b/src/onehot.jl
index c225fc4..ca2efa5 100644
--- a/src/onehot.jl
+++ b/src/onehot.jl
@@ -100,6 +100,20 @@ function _onehotbatch(data, labels, default)
   return OneHotArray(indices, length(labels))
 end
 
+# Fast path for integer `data` with unit-range `labels`: each index is computed
+# arithmetically as `value - first(labels) + 1`, after one range check on the extremes.
+function onehotbatch(data::AbstractArray{<:Integer}, labels::AbstractUnitRange{<:Integer})
+  if !isempty(data) # `minimum`/`maximum` throw on an empty collection; nothing to check then
+    # Separate reductions rather than `extrema(data)`, which fails on Julia 1.6.
+    lo, hi = minimum(data), maximum(data)
+    lo < first(labels) && error("Value $lo not found in labels")
+    hi > last(labels) && error("Value $hi not found in labels")
+  end
+  offset = 1 - first(labels)
+  indices = UInt32.(data .+ offset)
+  return OneHotArray(indices, length(labels))
+end
+
 """
     onecold(y::AbstractArray, labels = 1:size(y,1))
 
diff --git a/test/gpu.jl b/test/gpu.jl
index 13c208c..cd04815 100644
--- a/test/gpu.jl
+++ b/test/gpu.jl
@@ -26,6 +26,16 @@ end
   @test_broken gradient(A -> sum(A * y), gA)[1] isa CuArray # fails with JLArray, bug in Zygote?
 end
 
+@testset "onehotbatch(::CuArray, ::UnitRange)" begin
+  y1 = onehotbatch([1, 3, 0, 2], 0:9) |> cu
+  y2 = onehotbatch([1, 3, 0, 2] |> cu, 0:9)
+  @test y1.indices == y2.indices
+  @test_broken y1 == y2
+
+  @test_throws Exception onehotbatch([1, 3, 0, 2] |> cu, 1:10)
+  @test_throws Exception onehotbatch([1, 3, 0, 2] |> cu, -2:2)
+end
+
 @testset "onecold gpu" begin
   y = onehotbatch(ones(3), 1:10) |> cu;
   l = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j']
diff --git a/test/onehot.jl b/test/onehot.jl
index 0628230..fffac19 100644
--- a/test/onehot.jl
+++ b/test/onehot.jl
@@ -27,6 +27,13 @@
   @test onecold(onehot(-0.0, floats)) == 2 # as it uses isequal
   @test onecold(onehot(Inf, floats)) == 5
 
+  # UnitRange fast path
+  @test onehotbatch([1,3,0,4], 0:4) == onehotbatch([1,3,0,4], Tuple(0:4))
+  @test onehotbatch([2 3 7 4], 2:7) == onehotbatch([2 3 7 4], Tuple(2:7))
+  @test size(onehotbatch(Int[], 0:4)) == (5, 0) # empty data must not hit minimum/maximum
+  @test_throws Exception onehotbatch([2, -1], 0:4)
+  @test_throws Exception onehotbatch([2, 5], 0:4)
+
   # inferrabiltiy tests
   @test @inferred(onehot(20, 10:10:30)) == [false, true, false]
   @test @inferred(onehot(40, (10,20,30), 20)) == [false, true, false]