Skip to content

Commit

Permalink
Add findnz & fix broadcasting (#704)
Browse files Browse the repository at this point in the history
  • Loading branch information
pxl-th authored Dec 1, 2024
1 parent 57d8324 commit a52cbbe
Show file tree
Hide file tree
Showing 5 changed files with 63 additions and 2 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "AMDGPU"
uuid = "21141c5a-9bdb-4563-92ae-f87d6854732e"
authors = ["Julian P Samaroo <[email protected]>", "Valentin Churavy <[email protected]>", "Anton Smirnov <[email protected]>"]
version = "1.1.2"
version = "1.1.3"

[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
Expand Down
13 changes: 12 additions & 1 deletion src/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,18 @@ BroadcastStyle(::Type{<:ROCArray{T, N, B}}) where {T, N, B} =
BroadcastStyle(W::Type{<:AnyROCArray{T, N}}) where {T, N} =
ROCArrayStyle{N, buftype(Adapt.unwrap_type(W))}()

# TODO handle broadcast of different buffer types (use unified memory).
# TODO use unified buffer once we support it.
# Broadcast of two different buffers - choose `HIPBuffer`.
BroadcastStyle(
::ROCArrayStyle{N1, B1},
::ROCArrayStyle{N2, B2},
) where {N1,N2,B1,B2} = ROCArrayStyle{max(N1,N2), Mem.HIPBuffer}()

# Different N, same buffer type.
BroadcastStyle(
::ROCArrayStyle{N1, B},
::ROCArrayStyle{N2, B},
) where {N1,N2,B} = ROCArrayStyle{max(N1,N2), B}()

# Allocation of output arrays.
function Base.similar(
Expand Down
19 changes: 19 additions & 0 deletions src/sparse/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -262,12 +262,31 @@ Base.eltype(g::ROCSparseMatrix{T}) where T = T

## sparse array interface

SparseArrays.sparsevec(I::ROCArray{Ti}, V::ROCArray{Tv}, n::Integer) where {Ti,Tv} =
ROCSparseVector(I, V, n)

function SparseArrays.findnz(S::T) where {T <: AbstractROCSparseMatrix}
S2 = ROCSparseMatrixCOO(S)
I = S2.rowInd
J = S2.colInd
V = S2.nzVal

# To make it compatible with the SparseArrays.jl version
idxs = sortperm(J)
I = I[idxs]
J = J[idxs]
V = V[idxs]

return (I, J, V)
end

SparseArrays.nnz(g::AbstractROCSparseArray) = g.nnz
SparseArrays.nonzeros(g::AbstractROCSparseArray) = g.nzVal

SparseArrays.nonzeroinds(g::AbstractROCSparseVector) = g.iPtr

SparseArrays.rowvals(g::ROCSparseMatrixCSC) = g.rowVal
SparseArrays.getcolptr(g::ROCSparseMatrixCSC) = g.colPtr

LinearAlgebra.issymmetric(M::Union{ROCSparseMatrixCSC,ROCSparseMatrixCSR}) = false
LinearAlgebra.ishermitian(M::Union{ROCSparseMatrixCSC,ROCSparseMatrixCSR}) = false
Expand Down
8 changes: 8 additions & 0 deletions test/rocarray/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,14 @@ end
AMDGPU.unsafe_free!(xd2)
@test_throws ArgumentError pointer(xd2)
end

@testset "Broadcasting different buffer types" begin
x = rand(Float32, 4, 16, 16)
xd = unsafe_wrap(ROCArray, pointer(x), size(x))
y = AMDGPU.zeros(Float32, 3, 16, 16)
y .= @view(xd[1:3, :, :])
@test Array(y) @view(x[1:3, :, :])
end
end

@testset "unsafe_free" begin
Expand Down
23 changes: 23 additions & 0 deletions test/rocsparse/interfaces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -203,3 +203,26 @@
@test (dD - dA) isa typ
end
end

@testset "SparseArrays.jl" begin
@testset "findnz" begin
n = 35
A = sprand(n, n, 0.2)
d_A = ROCSparseMatrixCSC(A)
@test Array(SparseArrays.getcolptr(d_A)) == SparseArrays.getcolptr(A)

i, j, v = findnz(A)
d_i, d_j, d_v = findnz(d_A)
@test Array(d_i) == i && Array(d_j) == j && Array(d_v) == v

i = unique(sort(rand(1:n, 10)))
vals = rand(length(i))
d_i = ROCArray(i)
d_vals = ROCArray(vals)
v = sparsevec(i, vals, n)
d_v = sparsevec(d_i, d_vals, n)
@test Array(d_v.iPtr) == v.nzind
@test Array(d_v.nzVal) == v.nzval
@test d_v.len == v.n
end
end

2 comments on commit a52cbbe

@pxl-th
Copy link
Member Author

@pxl-th pxl-th commented on a52cbbe Dec 1, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/120494

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v1.1.3 -m "<description of version>" a52cbbe5ce8e0a7c672751924b7828cc41e9c26a
git push origin v1.1.3

Please sign in to comment.