Skip to content

Commit

Permalink
Add GPUNumber
Browse files Browse the repository at this point in the history
  • Loading branch information
pxl-th committed Jul 6, 2024
1 parent 80b8226 commit cf88dfe
Show file tree
Hide file tree
Showing 6 changed files with 52 additions and 12 deletions.
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
version = "10.2.3"

[deps]
AbstractNumbers = "85c772de-338a-5e7f-b815-41e76c26ac1f"
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
LLVM = "929cbde3-209d-540e-8aea-75f648917ca0"
Expand Down
3 changes: 3 additions & 0 deletions src/GPUArrays.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
module GPUArrays

import AbstractNumbers as AN

using Serialization
using Random
using LinearAlgebra
Expand All @@ -25,6 +27,7 @@ include("device/synchronization.jl")
# host abstractions
include("host/abstractarray.jl")
include("host/construction.jl")
include("host/gpunumber.jl")
## integrations and specialized methods
include("host/base.jl")
include("host/indexing.jl")
Expand Down
30 changes: 30 additions & 0 deletions src/host/gpunumber.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# Custom GPU-compatible `Number` interface.
struct GPUNumber{T <: AbstractGPUArray} <: AN.AbstractNumber{T}
val::T

function GPUNumber(val::T) where T <: AbstractGPUArray
length(val) != 1 && error(
"`GPUNumber` accepts only 1-element GPU arrays, " *
"instead `$(length(val))`-element array was given.")
new{T}(val)
end
end

AN.number(g::GPUNumber) = @allowscalar g.val[]

maybe_number(g::GPUNumber) = AN.number(g)
maybe_number(g) = g

number_type(::GPUNumber{T}) where T = eltype(T)

# When operations involve other `::Number` types,
# do not convert back to `GPUNumber`.
AN.like(::Type{<: GPUNumber}, x) = x

# When broadcasting, just pass the array itself.
Base.broadcastable(g::GPUNumber) = g.val

# Overload to avoid copies.
Base.one(g::GPUNumber) = one(number_type(g))
Base.zero(g::GPUNumber) = zero(number_type(g))
Base.identity(g::GPUNumber) = g
7 changes: 4 additions & 3 deletions src/host/indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ function Base.findfirst(f::Function, A::AnyGPUArray)
end

res = mapreduce((x, y)->(f(x), y), reduction, A, indices;
init = (false, dummy_index))
init = (false, dummy_index)) |> AN.number
if res[1]
# out of consistency with Base.findarray, return a CartesianIndex
# when the input is a multidimensional array
Expand All @@ -230,14 +230,15 @@ function findminmax(binop, A::AnyGPUArray; init, dims)
end

if dims == Colon()
res = mapreduce(tuple, reduction, A, indices; init = (init, dummy_index))
res = mapreduce(tuple, reduction, A, indices;
init = (init, dummy_index)) |> AN.number

# out of consistency with Base.findarray, return a CartesianIndex
# when the input is a multidimensional array
return (res[1], ndims(A) == 1 ? res[2] : CartesianIndices(A)[res[2]])
else
res = mapreduce(tuple, reduction, A, indices;
init = (init, dummy_index), dims=dims)
init = (init, dummy_index), dims=dims) |> maybe_number
vals = map(x->x[1], res)
inds = map(x->ndims(A) == 1 ? x[2] : CartesianIndices(A)[x[2]], res)
return (vals, inds)
Expand Down
20 changes: 11 additions & 9 deletions src/host/mapreduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,20 +68,20 @@ function _mapreduce(f::F, op::OP, As::Vararg{Any,N}; dims::D, init) where {F,OP,
end

if dims === Colon()
@allowscalar R[]
GPUNumber(R)
else
R
end
end

Base.any(A::AnyGPUArray{Bool}) = mapreduce(identity, |, A)
Base.all(A::AnyGPUArray{Bool}) = mapreduce(identity, &, A)
Base.any(A::AnyGPUArray{Bool}) = mapreduce(identity, |, A) |> AN.number
Base.all(A::AnyGPUArray{Bool}) = mapreduce(identity, &, A) |> AN.number

Base.any(f::Function, A::AnyGPUArray) = mapreduce(f, |, A)
Base.all(f::Function, A::AnyGPUArray) = mapreduce(f, &, A)
Base.any(f::Function, A::AnyGPUArray) = mapreduce(f, |, A) |> AN.number
Base.all(f::Function, A::AnyGPUArray) = mapreduce(f, &, A) |> AN.number

Base.count(pred::Function, A::AnyGPUArray; dims=:, init=0) =
mapreduce(pred, Base.add_sum, A; init=init, dims=dims)
mapreduce(pred, Base.add_sum, A; init=init, dims=dims) |> maybe_number

# avoid calling into `initarray!`
for (fname, op) in [(:sum, :(Base.add_sum)), (:prod, :(Base.mul_prod)),
Expand All @@ -94,7 +94,8 @@ for (fname, op) in [(:sum, :(Base.add_sum)), (:prod, :(Base.mul_prod)),
end
end

LinearAlgebra.ishermitian(A::AbstractGPUMatrix) = mapreduce(==, &, A, adjoint(A))
LinearAlgebra.ishermitian(A::AbstractGPUMatrix) =
mapreduce(==, &, A, adjoint(A)) |> AN.number


# comparisons
Expand All @@ -105,7 +106,7 @@ function Base.isequal(A::AnyGPUArray, B::AnyGPUArray)
if axes(A) != axes(B)
return false
end
mapreduce(isequal, &, A, B; init=true)
mapreduce(isequal, &, A, B; init=true) |> AN.number
end

# returns `missing` when missing values are involved
Expand All @@ -129,6 +130,7 @@ function Base.:(==)(A::AnyGPUArray, B::AnyGPUArray)
(; is_missing=false, is_equal=a.is_equal & b.is_equal)
end
end
res = mapreduce(mapper, reducer, A, B; init=(; is_missing=false, is_equal=true))
res = mapreduce(mapper, reducer, A, B;
init=(; is_missing=false, is_equal=true)) |> AN.number
res.is_missing ? missing : res.is_equal
end
3 changes: 3 additions & 0 deletions test/testsuite.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ function test_result(as::NTuple{N,Any}, bs::NTuple{N,Any}; kwargs...) where {N}
test_result(a, b; kwargs...)
end
end
# Special case for `extrema` accross all dims.
test_result(as::NTuple{N,Any}, bs::GPUArrays.GPUNumber; kwargs...) where {N} =
test_result(as, GPUArrays.maybe_number(bs))

function compare(f, AT::Type{<:AbstractGPUArray}, xs...; kwargs...)
# copy on the CPU, adapt on the GPU, but keep Ref's
Expand Down

0 comments on commit cf88dfe

Please sign in to comment.