Skip to content

Commit

Permalink
all tests but zygote pass on cuda devices
Browse files Browse the repository at this point in the history
  • Loading branch information
leios committed Oct 24, 2024
1 parent a0324bc commit 7d90d7f
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 34 deletions.
43 changes: 13 additions & 30 deletions ext/MollyCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ using Molly
using CUDA
using ChainRulesCore
using Atomix
using KernelAbstractions

CUDA.Const(nl::Molly.NoNeighborList) = nl

Expand Down Expand Up @@ -38,12 +39,12 @@ function cuda_threads_blocks_specific(n_inters)
return n_threads_gpu, n_blocks
end

function pairwise_force_gpu(coords::CuArray{SVector{D, C}}, atoms, boundary,
function Molly.pairwise_force_gpu(coords::CuArray{SVector{D, C}}, atoms, boundary,
pairwise_inters, nbs, force_units, ::Val{T}) where {D, C, T}
fs_mat = CUDA.zeros(T, D, length(atoms))

if typeof(nbs) == NoNeighborList
kernel = @cuda launch=false pairwise_force_kernel_nonl!(
if typeof(nbs) == Molly.NoNeighborList
kernel = @cuda launch=false cuda_pairwise_force_kernel_nonl!(
fs_mat, coords, atoms, boundary, pairwise_inters, Val(D), Val(force_units))
conf = launch_configuration(kernel.fun)
threads_basic = parse(Int, get(ENV, "MOLLY_GPUNTHREADS_PAIRWISE", "512"))
Expand All @@ -54,33 +55,15 @@ function pairwise_force_gpu(coords::CuArray{SVector{D, C}}, atoms, boundary,
kernel(fs_mat, coords, atoms, boundary, pairwise_inters, Val(D), Val(force_units);
threads=nthreads, blocks=(n_blocks_i, n_blocks_j))
else
n_threads_gpu, n_blocks = cuda_threads_blocks_pairwise(length(nbs))
CUDA.@sync @cuda threads=n_threads_gpu blocks=n_blocks pairwise_force_kernel_nl!(
fs_mat, coords, atoms, boundary, pairwise_inters, nbs, Val(D), Val(force_units))
backend = get_backend(coords)
n_threads_gpu = Molly.gpu_threads_blocks_pairwise(length(nbs))
kernel! = Molly.pairwise_force_kernel!(backend, n_threads_gpu)
kernel!(fs_mat, coords, atoms, boundary, pairwise_inters, nbs,
Val(D), Val(force_units), ndrange = length(nbs))
end
return fs_mat
end

# Neighbor-list pairwise force kernel.
# Each thread handles at most one entry of `neighbors_var`: it evaluates
# `sum_pairwise_forces` for that pair and atomically accumulates the result into
# `forces`, subtracting it from column `i` and adding it to column `j`.
# `D` is the number of force components written per atom; `F` is forwarded to
# `sum_pairwise_forces` as `Val(F)`.
function pairwise_force_kernel_nl!(forces, coords_var, atoms_var, boundary, inters,
                                   neighbors_var, ::Val{D}, ::Val{F}) where {D, F}
    # Wrap the inputs in `CUDA.Const` before indexing into them.
    coords = CUDA.Const(coords_var)
    atoms = CUDA.Const(atoms_var)
    neighbors = CUDA.Const(neighbors_var)

    # 1-based global index of this thread along x.
    pair_idx = (blockIdx().x - 1) * blockDim().x + threadIdx().x

    # Threads whose index falls past the end of the list have no pair to process.
    if pair_idx > length(neighbors)
        return nothing
    end

    @inbounds begin
        i, j, special = neighbors[pair_idx]
        coord_i, coord_j = coords[i], coords[j]
        atom_i, atom_j = atoms[i], atoms[j]
        f = sum_pairwise_forces(inters, coord_i, coord_j, atom_i, atom_j, boundary,
                                special, Val(F))
        for dim in 1:D
            fval = ustrip(f[dim])
            # Several threads may update the same atom, hence the atomic accumulation.
            Atomix.@atomic :monotonic forces[dim, i] += -fval
            Atomix.@atomic :monotonic forces[dim, j] += fval
        end
    end
    return nothing
end

#=
**The No-neighborlist pairwise force summation kernel**: This kernel calculates all the pairwise forces in a system of
`n_atoms` atoms. This is done by dividing the complete matrix of `n_atoms`×`n_atoms` interactions into small tiles. Most
Expand Down Expand Up @@ -121,7 +104,7 @@ That's why the calculations are done in the following order:
h | 1 2 3 4 5 6
```
=#
function pairwise_force_kernel_nonl!(forces::CuArray{T}, coords_var, atoms_var, boundary, inters,
function cuda_pairwise_force_kernel_nonl!(forces::AbstractArray{T}, coords_var, atoms_var, boundary, inters,
::Val{D}, ::Val{F}) where {T, D, F}
coords = CUDA.Const(coords_var)
atoms = CUDA.Const(atoms_var)
Expand All @@ -147,7 +130,7 @@ function pairwise_force_kernel_nonl!(forces::CuArray{T}, coords_var, atoms_var,
j = j_0_tile + del_j
if i != j
atom_j, coord_j = atoms[j], coords[j]
f = sum_pairwise_forces(inters, coord_i, coord_j, atom_i, atom_j, boundary, false, Val(F))
f = Molly.sum_pairwise_forces(inters, coord_i, coord_j, atom_i, atom_j, boundary, false, Val(F))
for dim in 1:D
forces_shmem[dim, tidx] += -ustrip(f[dim])
end
Expand All @@ -171,9 +154,9 @@ function pairwise_force_kernel_nonl!(forces::CuArray{T}, coords_var, atoms_var,
@inbounds for _ in 1:tilesteps
sync_warp()
atom_j = atoms[j]
f = sum_pairwise_forces(inters, coord_i, coord_j, atom_i, atom_j, boundary, false, Val(F))
f = Molly.sum_pairwise_forces(inters, coord_i, coord_j, atom_i, atom_j, boundary, false, Val(F))
for dim in 1:D
forces_shmem[dim, tidx] += -ustrip(f[dim])
forces_shmem[dim, tidx] += -Molly.ustrip(f[dim])
end
@shfl_multiple_sync(FULL_MASK, laneid() + 1, warpsize(), j, coord_j)
end
Expand Down
2 changes: 1 addition & 1 deletion src/chain_rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,7 @@ function ChainRulesCore.rrule(::typeof(pairwise_force_gpu), coords::AbstractArra

function pairwise_force_gpu_pullback(d_fs_mat)
backend = get_backend(coords)
ArrayType = find_array_type(coords)
ArrayType = get_array_type(coords)
n_atoms = length(atoms)
z = zero(T)
fs_mat = KernelAbstractions.zeros(backend, T, D, n_atoms)
Expand Down
2 changes: 1 addition & 1 deletion src/interactions/implicit_solvent.jl
Original file line number Diff line number Diff line change
Expand Up @@ -565,7 +565,7 @@ function ImplicitSolventGBN2(atoms::AbstractArray{Atom{T, M, D, E}},
end

if isa(atoms, AbstractGPUArray)
ArrayType = fine_array_type(atoms)
ArrayType = get_array_type(atoms)
or = ArrayType(offset_radii)
sor = ArrayType(scaled_offset_radii)
is, js = ArrayType(inds_i), ArrayType(inds_j)
Expand Down
4 changes: 2 additions & 2 deletions src/kernels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ end
f = sum_pairwise_forces(inters, coords[i], coords[j], atoms[i], atoms[j], boundary, special, Val(F))
for dim in 1:D
fval = ustrip(f[dim])
Atomix.@atomic forces[dim, i] = forces[dim, i] - fval
Atomix.@atomic forces[dim, j] = forces[dim, j] + fval
Atomix.@atomic :monotonic forces[dim, i] -= fval
Atomix.@atomic :monotonic forces[dim, j] += fval
end
end
end
Expand Down

0 comments on commit 7d90d7f

Please sign in to comment.