Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revamp memory management, and add USM support. #264

Open
wants to merge 33 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 7 commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
14ddd21
Add support for USM.
VarLad Jan 21, 2025
43f596a
Reorganize.
maleadt Jan 21, 2025
c11ccd0
Re-use existing wrappers.
maleadt Jan 21, 2025
a615b95
Don't silently ignore INVALID_VALUE on USM memcpy.
maleadt Jan 21, 2025
87aabe0
Simplify.
maleadt Jan 21, 2025
51f365b
Fix USM check.
maleadt Jan 23, 2025
dbe9f1d
More clean-ups.
maleadt Jan 23, 2025
05f97e2
Tighten some pointer conversions.
maleadt Jan 23, 2025
2eaca08
Fix.
maleadt Jan 23, 2025
873a0c3
Fixes.
maleadt Jan 23, 2025
d8e1466
More careful pointer derivations.
maleadt Jan 23, 2025
1f4b55c
Fix freeing.
maleadt Jan 23, 2025
4f5204b
Fix example.
maleadt Jan 23, 2025
575e34b
Rename buf->mem, and fix device switching on memcpy.
maleadt Jan 23, 2025
bad893e
Switch platforms when switching devices.
maleadt Jan 23, 2025
9c69528
More fixes.
maleadt Jan 23, 2025
d3be396
Revert fill to previous implementation.
maleadt Jan 23, 2025
301f1d7
Fix synchronization.
maleadt Jan 23, 2025
5568373
Switch global state to keying everything on the context.
maleadt Jan 23, 2025
72f295a
More simplifications.
maleadt Jan 23, 2025
8a4ddab
Remove unused flag.
maleadt Jan 23, 2025
74ce2ba
More simplifications.
maleadt Jan 23, 2025
ecff68f
Remove unused buffer type.
maleadt Jan 23, 2025
25a9e6e
Avoid some unneeded TLS clears.
maleadt Jan 23, 2025
8f25095
Fix pointer passing.
maleadt Jan 24, 2025
7596fbf
Improve memory object tracking with at-opencl.
maleadt Jan 24, 2025
55fe4cd
Rely less on global state when it can be derived.
maleadt Jan 24, 2025
308e4f4
Accurately keep track of memory passed to clcall.
maleadt Jan 24, 2025
132d260
Don't configure null pointers.
maleadt Jan 24, 2025
4cec57f
Don't iterate devices on free.
maleadt Jan 24, 2025
1c7c11e
Simplify free.
maleadt Jan 24, 2025
4be457c
Show memory back-end in versioninfo.
maleadt Jan 24, 2025
a17c34b
Avoid atexit errors by at least checking if the queue is still valid.
maleadt Jan 24, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Adapt = "4"
GPUArrays = "11.2.1"
GPUCompiler = "0.27, 1"
KernelAbstractions = "0.9.1"
KernelAbstractions = "0.9.2"
LLVM = "9.1"
LinearAlgebra = "1"
OpenCL_jll = "=2024.5.8"
Expand Down
6 changes: 4 additions & 2 deletions lib/cl/CL.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
module cl

using Printf

include("pointer.jl")
include("api.jl")

# OpenCL wrapper objects are expected to have an `id` field containing a handle pointer
Expand All @@ -15,9 +18,8 @@ include("device.jl")
include("context.jl")
include("cmdqueue.jl")
include("event.jl")
include("memory.jl")
include("memory/memory.jl")
include("buffer.jl")
include("svm.jl")
include("program.jl")
include("kernel.jl")

Expand Down
16 changes: 16 additions & 0 deletions lib/cl/api.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,22 @@ function retry_reclaim(f, isfailed)
ret
end

# `@ext_ccall lib.fn(args...)::ret`
#
# Call an OpenCL extension function through a function pointer that is looked up at run
# time. The expression must have the shape of a `@ccall` with a `lib.fn` target; only the
# function name `fn` is used, the library part of the target is discarded. The lookup goes
# through `clGetExtensionFunctionAddressForPlatform`, using whatever `platform()` resolves
# to at the call site. An error is raised when the lookup yields a NULL address instead of
# calling through it.
macro ext_ccall(ex)
    # decode the expression: `target(args...)::ret`
    @assert Meta.isexpr(ex, :(::))
    call, ret = ex.args
    @assert Meta.isexpr(call, :call)
    target, argexprs... = call.args
    @assert Meta.isexpr(target, :(.))
    # `target` is `lib.fn`; keep only the (quoted) function name
    _, fn = target.args

    @gensym fptr
    esc(quote
        $fptr = $clGetExtensionFunctionAddressForPlatform(platform(), $fn)
        # a NULL address cannot be called; fail with a readable error instead of crashing
        if $fptr == C_NULL
            error("OpenCL extension function `", $fn, "` is not available on the current platform")
        end
        # `$fptr` inside `@ccall` requests a call through the function pointer
        @ccall $(Expr(:($), fptr))($(argexprs...))::$ret
    end)
end

include("libopencl.jl")

@static if Sys.iswindows()
Expand Down
75 changes: 74 additions & 1 deletion lib/cl/buffer.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,79 @@
# OpenCL Memory Object

abstract type AbstractMemoryObject <: CLObject end

# Every concrete subtype is expected to carry the `cl_mem` handle in an `id` field:
#
#   mutable struct MemoryType <: AbstractMemoryObject
#       id::cl_mem
#       ...
#   end

# passing a memory object to an OpenCL API call: hand over the underlying handle
function Base.unsafe_convert(::Type{cl_mem}, mem::AbstractMemoryObject)
    return mem.id
end

# passing a memory object to a kernel: keep the object itself, `cl.set_arg!` handles it
function Base.unsafe_convert(::Type{<:Ptr}, mem::AbstractMemoryObject)
    return mem
end

# size of the memory object, as reported by its `size` property
function Base.sizeof(mem::AbstractMemoryObject)
    return mem.size
end

# context the memory object belongs to, as reported by its `context` property
function context(mem::AbstractMemoryObject)
    return mem.context
end

# Query a single fixed-size value of type `T` for `param` through `clGetMemObjectInfo`.
function _query_mem_object_info(mem::AbstractMemoryObject, param, ::Type{T}) where {T}
    value = Ref{T}()
    clGetMemObjectInfo(mem, param, sizeof(T), value, C_NULL)
    return value[]
end

# Property access for memory objects. The names `:context`, `:mem_type`, `:mem_flags`,
# `:size`, `:reference_count` and `:map_count` are answered by querying
# `clGetMemObjectInfo`; any other name is read as a plain field of `mem`.
function Base.getproperty(mem::AbstractMemoryObject, s::Symbol)
    if s === :context
        handle = _query_mem_object_info(mem, CL_MEM_CONTEXT, cl_context)
        return Context(handle, retain = true)
    elseif s === :mem_type
        return _query_mem_object_info(mem, CL_MEM_TYPE, cl_mem_object_type)
    elseif s === :mem_flags
        bits = _query_mem_object_info(mem, CL_MEM_FLAGS, cl_mem_flags)
        # flag bit => reported symbol, in the order the symbols are emitted
        known = (
            CL_MEM_READ_WRITE => :rw,
            CL_MEM_WRITE_ONLY => :w,
            CL_MEM_READ_ONLY => :r,
            CL_MEM_USE_HOST_PTR => :use,
            CL_MEM_ALLOC_HOST_PTR => :alloc,
            CL_MEM_COPY_HOST_PTR => :copy,
        )
        flags = Symbol[name for (flag, name) in known if (bits & flag) != 0]
        return (flags...,)
    elseif s === :size
        return _query_mem_object_info(mem, CL_MEM_SIZE, Csize_t)
    elseif s === :reference_count
        return Int(_query_mem_object_info(mem, CL_MEM_REFERENCE_COUNT, Cuint))
    elseif s === :map_count
        return Int(_query_mem_object_info(mem, CL_MEM_MAP_COUNT, Cuint))
    end
    # not a queried property: plain field access
    return getfield(mem, s)
end

#TODO: enqueue_migrate_mem_objects(queue, mem_objects, flags=0, wait_for=None)
#TODO: enqueue_migrate_mem_objects_ext(queue, mem_objects, flags=0, wait_for=None)

# OpenCL.Buffer

mutable struct Buffer{T} <: AbstractMemory
mutable struct Buffer{T} <: AbstractMemoryObject
const id::cl_mem
const len::Int

Expand Down
53 changes: 53 additions & 0 deletions lib/cl/device.jl
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,59 @@ function exec_capabilities(d::Device)
)
end

# Whether device `d` lists the `cl_intel_unified_shared_memory` extension.
function usm_supported(d::Device)
    return "cl_intel_unified_shared_memory" in d.extensions
end

# Report the Unified Shared Memory capabilities of device `d`.
#
# Returns a named tuple with the entries `host`, `device`, `single_device`, `shared` and
# `cross_device`. Each entry is itself a named tuple of booleans (`access`,
# `atomic_access`, `concurrent_access`, `concurrent_atomic_access`) decoded from the
# corresponding capability bitfield reported by `clGetDeviceInfo`.
#
# Throws an `ArgumentError` when `usm_supported(d)` is false.
function usm_capabilities(d::Device)
    if !usm_supported(d)
        throw(ArgumentError("Unified Shared Memory not supported on this device"))
    end

    # query one capability bitfield and expand it into named booleans
    function capabilities(param)
        mask = Ref{cl_device_unified_shared_memory_capabilities_intel}()
        clGetDeviceInfo(
            d, param,
            sizeof(cl_device_unified_shared_memory_capabilities_intel), mask, C_NULL
        )
        bits = mask[]
        return (;
            access = (bits & CL_UNIFIED_SHARED_MEMORY_ACCESS_INTEL) != 0,
            atomic_access = (bits & CL_UNIFIED_SHARED_MEMORY_ATOMIC_ACCESS_INTEL) != 0,
            concurrent_access = (bits & CL_UNIFIED_SHARED_MEMORY_CONCURRENT_ACCESS_INTEL) != 0,
            concurrent_atomic_access = (bits & CL_UNIFIED_SHARED_MEMORY_CONCURRENT_ATOMIC_ACCESS_INTEL) != 0,
        )
    end

    # queries are issued in this order: host, device, single_device, shared, cross_device
    return (;
        host = capabilities(CL_DEVICE_HOST_MEM_CAPABILITIES_INTEL),
        device = capabilities(CL_DEVICE_DEVICE_MEM_CAPABILITIES_INTEL),
        single_device = capabilities(CL_DEVICE_SINGLE_DEVICE_SHARED_MEM_CAPABILITIES_INTEL),
        shared = capabilities(CL_DEVICE_SHARED_SYSTEM_MEM_CAPABILITIES_INTEL),
        cross_device = capabilities(CL_DEVICE_CROSS_DEVICE_SHARED_MEM_CAPABILITIES_INTEL),
    )
end

function svm_capabilities(d::Device)
result = Ref{cl_device_svm_capabilities}()
clGetDeviceInfo(d, CL_DEVICE_SVM_CAPABILITIES,
Expand Down
45 changes: 29 additions & 16 deletions lib/cl/kernel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,47 +51,55 @@ Base.length(l::LocalMem{T}) where {T} = Int(l.nbytes ÷ sizeof(T))

# preserve the LocalMem; it will be handled by set_arg!
# XXX: do we want set_arg!(C_NULL::Ptr) to just call clSetKernelArg?
Base.unsafe_convert(::Type{Ptr{T}}, l::LocalMem{T}) where {T} = l
Base.unsafe_convert(::Type{CLPtr{T}}, l::LocalMem{T}) where {T} = l

# `nothing` as a kernel argument: set the argument to a NULL value of `cl_mem` size.
# `idx` is 1-based on the Julia side and converted to a 0-based index for the API call.
function set_arg!(k::Kernel, idx::Integer, arg::Nothing)
    @assert idx > 0
    position = cl_uint(idx - 1)
    clSetKernelArg(k, position, sizeof(cl_mem), C_NULL)
    return k
end

# SVMBuffers
# raw memory
## when passing using `cl.call`
function set_arg!(k::Kernel, idx::Integer, arg::SVMBuffer)
clSetKernelArgSVMPointer(k, cl_uint(idx-1), arg.ptr)
function set_arg!(k::Kernel, idx::Integer, arg::AbstractMemory)
if cl.memory_backend() == cl.SVMBackend()
clSetKernelArgSVMPointer(k, cl_uint(idx - 1), arg.ptr)
else
clSetKernelArgMemPointerINTEL(k, cl_uint(idx - 1), arg.ptr)
end
return k
end
## when passing with `clcall`, which has pre-converted the buffer
function set_arg!(k::Kernel, idx::Integer, arg::Union{Ptr,Core.LLVMPtr})
function set_arg!(k::Kernel, idx::Integer, arg::CLPtr{T}) where {T}
arg = reinterpret(Ptr{Cvoid}, arg)
if arg != C_NULL
# XXX: this assumes that the receiving argument is pointer-typed, which is not the
# case with Julia's `Ptr` ABI. Instead, one should reinterpret the pointer as a
# `Core.LLVMPtr`, which _is_ pointer-valued. We retain this handling for `Ptr`
# for users passing pointers to OpenCL C, and because `Ptr` is pointer-valued
# starting with Julia 1.12.
clSetKernelArgSVMPointer(k, cl_uint(idx-1), arg)
if cl.memory_backend() == cl.SVMBackend()
clSetKernelArgSVMPointer(k, cl_uint(idx - 1), arg)
else
clSetKernelArgMemPointerINTEL(k, cl_uint(idx - 1), arg)
end
end
return k
end

# regular buffers
function set_arg!(k::Kernel, idx::Integer, arg::AbstractMemory)
# memory objects
# Memory objects are passed by handle: box the `cl_mem` id so that a pointer to it can be
# given to `clSetKernelArg`. `idx` is 1-based and converted to a 0-based index.
function set_arg!(k::Kernel, idx::Integer, arg::AbstractMemoryObject)
    handle = Ref(arg.id)
    position = cl_uint(idx - 1)
    clSetKernelArg(k, position, sizeof(cl_mem), handle)
    return k
end

# `LocalMem` arguments carry no data: only the byte size `arg.nbytes` is passed, together
# with a NULL value pointer. `idx` is 1-based and converted to a 0-based index.
# The argument is set exactly once (the previous text repeated the same call twice).
function set_arg!(k::Kernel, idx::Integer, arg::LocalMem)
    clSetKernelArg(k, cl_uint(idx - 1), arg.nbytes, C_NULL)
    return k
end

function set_arg!(k::Kernel, idx::Integer, arg::T) where T
function set_arg!(k::Kernel, idx::Integer, arg::T) where {T}
ref = Ref(arg)
tsize = sizeof(ref)
err = unchecked_clSetKernelArg(k, cl_uint(idx - 1), tsize, ref)
Expand Down Expand Up @@ -175,13 +183,18 @@ function enqueue_kernel(k::Kernel, global_work_size, local_work_size=nothing;
return Event(ret_event[], retain=false)
end

function call(k::Kernel, args...; global_size=(1,), local_size=nothing,
global_work_offset=nothing, wait_on::Vector{Event}=Event[],
svm_pointers::Vector{Ptr{Cvoid}}=Ptr{Cvoid}[])
function call(
k::Kernel, args...; global_size = (1,), local_size = nothing,
global_work_offset = nothing, wait_on::Vector{Event} = Event[],
pointers::Vector{CLPtr} = CLPtr[]
)
set_args!(k, args...)
if !isempty(svm_pointers)
clSetKernelExecInfo(k, CL_KERNEL_EXEC_INFO_SVM_PTRS,
sizeof(svm_pointers), svm_pointers)
flag = cl.memory_backend() == cl.SVMBackend() ? CL_KERNEL_EXEC_INFO_SVM_PTRS : CL_KERNEL_EXEC_INFO_USM_PTRS_INTEL
if !isempty(pointers)
clSetKernelExecInfo(
k, flag,
sizeof(pointers), pointers
)
end
enqueue_kernel(k, global_size, local_size; global_work_offset, wait_on)
end
Expand Down
Loading
Loading