Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[hip][cuda] Merged pending_queue_actions implementations. #18220

Merged
merged 9 commits into from
Aug 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions runtime/src/iree/hal/drivers/cuda/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,6 @@ iree_runtime_cc_library(
"nccl_channel.h",
"nop_executable_cache.c",
"nop_executable_cache.h",
"pending_queue_actions.c",
"pending_queue_actions.h",
"pipeline_layout.c",
"pipeline_layout.h",
"stream_command_buffer.c",
Expand Down Expand Up @@ -66,6 +64,7 @@ iree_runtime_cc_library(
"//runtime/src/iree/hal",
"//runtime/src/iree/hal/utils:collective_batch",
"//runtime/src/iree/hal/utils:deferred_command_buffer",
"//runtime/src/iree/hal/utils:deferred_work_queue",
"//runtime/src/iree/hal/utils:file_transfer",
"//runtime/src/iree/hal/utils:memory_file",
"//runtime/src/iree/hal/utils:resource_set",
Expand Down
3 changes: 1 addition & 2 deletions runtime/src/iree/hal/drivers/cuda/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,6 @@ iree_cc_library(
"nccl_channel.h"
"nop_executable_cache.c"
"nop_executable_cache.h"
"pending_queue_actions.c"
"pending_queue_actions.h"
"pipeline_layout.c"
"pipeline_layout.h"
"stream_command_buffer.c"
Expand All @@ -63,6 +61,7 @@ iree_cc_library(
iree::hal
iree::hal::utils::collective_batch
iree::hal::utils::deferred_command_buffer
iree::hal::utils::deferred_work_queue
iree::hal::utils::file_transfer
iree::hal::utils::memory_file
iree::hal::utils::resource_set
Expand Down
253 changes: 234 additions & 19 deletions runtime/src/iree/hal/drivers/cuda/cuda_device.c
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,12 @@
#include "iree/hal/drivers/cuda/nccl_channel.h"
#include "iree/hal/drivers/cuda/nccl_dynamic_symbols.h"
#include "iree/hal/drivers/cuda/nop_executable_cache.h"
#include "iree/hal/drivers/cuda/pending_queue_actions.h"
#include "iree/hal/drivers/cuda/pipeline_layout.h"
#include "iree/hal/drivers/cuda/stream_command_buffer.h"
#include "iree/hal/drivers/cuda/timepoint_pool.h"
#include "iree/hal/drivers/cuda/tracing.h"
#include "iree/hal/utils/deferred_command_buffer.h"
#include "iree/hal/utils/deferred_work_queue.h"
#include "iree/hal/utils/file_transfer.h"
#include "iree/hal/utils/memory_file.h"

Expand Down Expand Up @@ -76,7 +76,7 @@ typedef struct iree_hal_cuda_device_t {
// are met. It buffers submissions and allocations internally before they
// are ready. This queue couples with HAL semaphores backed by iree_event_t
// and CUevent objects.
iree_hal_cuda_pending_queue_actions_t* pending_queue_actions;
iree_hal_deferred_work_queue_t* work_queue;

// Device memory pools and allocators.
bool supports_memory_pools;
Expand All @@ -88,6 +88,176 @@ typedef struct iree_hal_cuda_device_t {
} iree_hal_cuda_device_t;

static const iree_hal_device_vtable_t iree_hal_cuda_device_vtable;
static const iree_hal_deferred_work_queue_device_interface_vtable_t
iree_hal_cuda_deferred_work_queue_device_interface_vtable;

// We put a CUEvent into a iree_hal_deferred_work_queue_native_event_t.
static_assert(sizeof(CUevent) <=
sizeof(iree_hal_deferred_work_queue_native_event_t),
"Unexpected event size");
typedef struct iree_hal_cuda_deferred_work_queue_device_interface_t {
iree_hal_deferred_work_queue_device_interface_t base;
iree_hal_device_t* device;
CUdevice cu_device;
CUcontext cu_context;
CUstream dispatch_cu_stream;
iree_allocator_t host_allocator;
const iree_hal_cuda_dynamic_symbols_t* cuda_symbols;
} iree_hal_cuda_deferred_work_queue_device_interface_t;

static void iree_hal_cuda_deferred_work_queue_device_interface_destroy(
iree_hal_deferred_work_queue_device_interface_t* base_device_interface) {
iree_hal_cuda_deferred_work_queue_device_interface_t* device_interface =
(iree_hal_cuda_deferred_work_queue_device_interface_t*)(base_device_interface);
iree_allocator_free(device_interface->host_allocator, device_interface);
}

static iree_status_t
iree_hal_cuda_deferred_work_queue_device_interface_bind_to_thread(
iree_hal_deferred_work_queue_device_interface_t* base_device_interface) {
iree_hal_cuda_deferred_work_queue_device_interface_t* device_interface =
(iree_hal_cuda_deferred_work_queue_device_interface_t*)(base_device_interface);
return IREE_CURESULT_TO_STATUS(device_interface->cuda_symbols,
cuCtxSetCurrent(device_interface->cu_context),
"cuCtxSetCurrent");
}

static iree_status_t
iree_hal_cuda_deferred_work_queue_device_interface_wait_native_event(
iree_hal_deferred_work_queue_device_interface_t* base_device_interface,
iree_hal_deferred_work_queue_native_event_t event) {
iree_hal_cuda_deferred_work_queue_device_interface_t* device_interface =
(iree_hal_cuda_deferred_work_queue_device_interface_t*)(base_device_interface);
return IREE_CURESULT_TO_STATUS(
device_interface->cuda_symbols,
cuStreamWaitEvent(device_interface->dispatch_cu_stream, (CUevent)event,
CU_EVENT_WAIT_DEFAULT),
"cuStreamWaitEvent");
}

static iree_status_t
iree_hal_cuda_deferred_work_queue_device_interface_create_native_event(
iree_hal_deferred_work_queue_device_interface_t* base_device_interface,
iree_hal_deferred_work_queue_native_event_t* out_event) {
iree_hal_cuda_deferred_work_queue_device_interface_t* device_interface =
(iree_hal_cuda_deferred_work_queue_device_interface_t*)(base_device_interface);
return IREE_CURESULT_TO_STATUS(
device_interface->cuda_symbols,
cuEventCreate((CUevent*)out_event, CU_EVENT_WAIT_DEFAULT),
"cuEventCreate");
}

static iree_status_t
iree_hal_cuda_deferred_work_queue_device_interface_record_native_event(
iree_hal_deferred_work_queue_device_interface_t* base_device_interface,
iree_hal_deferred_work_queue_native_event_t event) {
iree_hal_cuda_deferred_work_queue_device_interface_t* device_interface =
(iree_hal_cuda_deferred_work_queue_device_interface_t*)(base_device_interface);
return IREE_CURESULT_TO_STATUS(
device_interface->cuda_symbols,
cuEventRecord((CUevent)event, device_interface->dispatch_cu_stream),
"cuEventCreate");
}

static iree_status_t
iree_hal_cuda_deferred_work_queue_device_interface_synchronize_native_event(
iree_hal_deferred_work_queue_device_interface_t* base_device_interface,
iree_hal_deferred_work_queue_native_event_t event) {
iree_hal_cuda_deferred_work_queue_device_interface_t* device_interface =
(iree_hal_cuda_deferred_work_queue_device_interface_t*)(base_device_interface);
return IREE_CURESULT_TO_STATUS(device_interface->cuda_symbols,
cuEventSynchronize((CUevent)event));
}

static iree_status_t
iree_hal_cuda_deferred_work_queue_device_interface_destroy_native_event(
iree_hal_deferred_work_queue_device_interface_t* base_device_interface,
iree_hal_deferred_work_queue_native_event_t event) {
iree_hal_cuda_deferred_work_queue_device_interface_t* device_interface =
(iree_hal_cuda_deferred_work_queue_device_interface_t*)(base_device_interface);
return IREE_CURESULT_TO_STATUS(device_interface->cuda_symbols,
cuEventDestroy((CUevent)event));
}

static iree_status_t
iree_hal_cuda_deferred_work_queue_device_interface_semaphore_acquire_timepoint_device_signal_native_event(
iree_hal_deferred_work_queue_device_interface_t* base_device_interface,
struct iree_hal_semaphore_t* semaphore, uint64_t value,
iree_hal_deferred_work_queue_native_event_t* out_event) {
return iree_hal_cuda_event_semaphore_acquire_timepoint_device_signal(
semaphore, value, (CUevent*)out_event);
}

static bool
iree_hal_cuda_deferred_work_queue_device_interface_acquire_host_wait_event(
iree_hal_deferred_work_queue_device_interface_t* base_device_interface,
struct iree_hal_semaphore_t* semaphore, uint64_t value,
iree_hal_deferred_work_queue_host_device_event_t* out_event) {
return iree_hal_cuda_semaphore_acquire_event_host_wait(
semaphore, value, (iree_hal_cuda_event_t**)out_event);
}

static void
iree_hal_cuda_deferred_work_queue_device_interface_release_wait_event(
iree_hal_deferred_work_queue_device_interface_t* base_device_interface,
iree_hal_deferred_work_queue_host_device_event_t wait_event) {
iree_hal_cuda_event_release(wait_event);
}

static iree_status_t
iree_hal_cuda_deferred_work_queue_device_interface_device_wait_on_host_event(
iree_hal_deferred_work_queue_device_interface_t* base_device_interface,
iree_hal_deferred_work_queue_host_device_event_t wait_event) {
iree_hal_cuda_deferred_work_queue_device_interface_t* device_interface =
(iree_hal_cuda_deferred_work_queue_device_interface_t*)(base_device_interface);
return IREE_CURESULT_TO_STATUS(
device_interface->cuda_symbols,
cuStreamWaitEvent(
device_interface->dispatch_cu_stream,
iree_hal_cuda_event_handle((iree_hal_cuda_event_t*)wait_event), 0),
"cuStreamWaitEvent");
}

static void*
iree_hal_cuda_deferred_work_queue_device_interface_native_event_from_wait_event(
iree_hal_deferred_work_queue_device_interface_t* base_device_interface,
iree_hal_deferred_work_queue_host_device_event_t event) {
return iree_hal_cuda_event_handle((iree_hal_cuda_event_t*)event);
}

static iree_status_t
iree_hal_cuda_deferred_work_queue_device_interface_create_stream_command_buffer(
iree_hal_deferred_work_queue_device_interface_t* base_device_interface,
iree_hal_command_buffer_mode_t mode, iree_hal_command_category_t categories,
iree_hal_command_buffer_t** out) {
iree_hal_cuda_deferred_work_queue_device_interface_t* device_interface =
(iree_hal_cuda_deferred_work_queue_device_interface_t*)(base_device_interface);
return iree_hal_cuda_device_create_stream_command_buffer(
device_interface->device, mode, categories, 0, out);
}

static iree_status_t
iree_hal_cuda_deferred_work_queue_device_interface_submit_command_buffer(
iree_hal_deferred_work_queue_device_interface_t* base_device_interface,
iree_hal_command_buffer_t* command_buffer) {
iree_hal_cuda_deferred_work_queue_device_interface_t* device_interface =
(iree_hal_cuda_deferred_work_queue_device_interface_t*)(base_device_interface);
iree_status_t status = iree_ok_status();
if (iree_hal_cuda_stream_command_buffer_isa(command_buffer)) {
// Stream command buffer so nothing to do but notify it was submitted.
iree_hal_cuda_stream_notify_submitted_commands(command_buffer);
} else {
CUgraphExec exec =
iree_hal_cuda_graph_command_buffer_handle(command_buffer);
status = IREE_CURESULT_TO_STATUS(
device_interface->cuda_symbols,
cuGraphLaunch(exec, device_interface->dispatch_cu_stream));
if (IREE_LIKELY(iree_status_is_ok(status))) {
iree_hal_cuda_graph_tracing_notify_submitted_commands(command_buffer);
}
}
return status;
}

static iree_hal_cuda_device_t* iree_hal_cuda_device_cast(
iree_hal_device_t* base_value) {
Expand Down Expand Up @@ -152,9 +322,27 @@ static iree_status_t iree_hal_cuda_device_create_internal(
device->dispatch_cu_stream = dispatch_stream;
device->host_allocator = host_allocator;

iree_status_t status = iree_hal_cuda_pending_queue_actions_create(
cuda_symbols, cu_device, context, &device->block_pool, host_allocator,
&device->pending_queue_actions);
iree_hal_cuda_deferred_work_queue_device_interface_t* device_interface;
iree_status_t status = iree_allocator_malloc(
host_allocator,
sizeof(iree_hal_cuda_deferred_work_queue_device_interface_t),
(void**)&device_interface);
if (IREE_UNLIKELY(!iree_status_is_ok(status))) {
iree_hal_device_release((iree_hal_device_t*)device);
return status;
}
device_interface->base.vtable =
&iree_hal_cuda_deferred_work_queue_device_interface_vtable;
device_interface->cu_context = context;
device_interface->cuda_symbols = cuda_symbols;
device_interface->cu_device = cu_device;
device_interface->device = (iree_hal_device_t*)device;
device_interface->dispatch_cu_stream = dispatch_stream;
device_interface->host_allocator = host_allocator;

status = iree_hal_deferred_work_queue_create(
(iree_hal_deferred_work_queue_device_interface_t*)device_interface,
&device->block_pool, host_allocator, &device->work_queue);

// Enable tracing for the (currently only) stream - no-op if disabled.
if (iree_status_is_ok(status) && device->params.stream_tracing) {
Expand Down Expand Up @@ -297,8 +485,7 @@ static void iree_hal_cuda_device_destroy(iree_hal_device_t* base_device) {
IREE_TRACE_ZONE_BEGIN(z0);

// Destroy the pending workload queue.
iree_hal_cuda_pending_queue_actions_destroy(
(iree_hal_resource_t*)device->pending_queue_actions);
iree_hal_deferred_work_queue_destroy(device->work_queue);

// There should be no more buffers live that use the allocator.
iree_hal_allocator_release(device->device_allocator);
Expand Down Expand Up @@ -620,7 +807,7 @@ static iree_status_t iree_hal_cuda_device_create_semaphore(
iree_hal_cuda_device_t* device = iree_hal_cuda_device_cast(base_device);
return iree_hal_cuda_event_semaphore_create(
initial_value, device->cuda_symbols, device->timepoint_pool,
device->pending_queue_actions, device->host_allocator, out_semaphore);
device->work_queue, device->host_allocator, out_semaphore);
}

static iree_hal_semaphore_compatibility_t
Expand Down Expand Up @@ -765,15 +952,13 @@ static iree_status_t iree_hal_cuda_device_queue_execute(
iree_hal_cuda_device_t* device = iree_hal_cuda_device_cast(base_device);
IREE_TRACE_ZONE_BEGIN(z0);

iree_status_t status = iree_hal_cuda_pending_queue_actions_enqueue_execution(
base_device, device->dispatch_cu_stream, device->pending_queue_actions,
iree_hal_cuda_device_collect_tracing_context, device->tracing_context,
wait_semaphore_list, signal_semaphore_list, command_buffer_count,
command_buffers, binding_tables);
iree_status_t status = iree_hal_deferred_work_queue_enque(
device->work_queue, iree_hal_cuda_device_collect_tracing_context,
device->tracing_context, wait_semaphore_list, signal_semaphore_list,
command_buffer_count, command_buffers, binding_tables);
if (iree_status_is_ok(status)) {
// Try to advance the pending workload queue.
status = iree_hal_cuda_pending_queue_actions_issue(
device->pending_queue_actions);
// Try to advance the deferred work queue.
status = iree_hal_deferred_work_queue_issue(device->work_queue);
}

IREE_TRACE_ZONE_END(z0);
Expand All @@ -784,9 +969,8 @@ static iree_status_t iree_hal_cuda_device_queue_flush(
iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity) {
iree_hal_cuda_device_t* device = iree_hal_cuda_device_cast(base_device);
IREE_TRACE_ZONE_BEGIN(z0);
// Try to advance the pending workload queue.
iree_status_t status =
iree_hal_cuda_pending_queue_actions_issue(device->pending_queue_actions);
// Try to advance the deferred work queue.
iree_status_t status = iree_hal_deferred_work_queue_issue(device->work_queue);
IREE_TRACE_ZONE_END(z0);
return status;
}
Expand Down Expand Up @@ -850,3 +1034,34 @@ static const iree_hal_device_vtable_t iree_hal_cuda_device_vtable = {
.profiling_flush = iree_hal_cuda_device_profiling_flush,
.profiling_end = iree_hal_cuda_device_profiling_end,
};

static const iree_hal_deferred_work_queue_device_interface_vtable_t
iree_hal_cuda_deferred_work_queue_device_interface_vtable = {
.destroy = iree_hal_cuda_deferred_work_queue_device_interface_destroy,
.bind_to_thread =
iree_hal_cuda_deferred_work_queue_device_interface_bind_to_thread,
.wait_native_event =
iree_hal_cuda_deferred_work_queue_device_interface_wait_native_event,
.create_native_event =
iree_hal_cuda_deferred_work_queue_device_interface_create_native_event,
.record_native_event =
iree_hal_cuda_deferred_work_queue_device_interface_record_native_event,
.synchronize_native_event =
iree_hal_cuda_deferred_work_queue_device_interface_synchronize_native_event,
.destroy_native_event =
iree_hal_cuda_deferred_work_queue_device_interface_destroy_native_event,
.semaphore_acquire_timepoint_device_signal_native_event =
iree_hal_cuda_deferred_work_queue_device_interface_semaphore_acquire_timepoint_device_signal_native_event,
.acquire_host_wait_event =
iree_hal_cuda_deferred_work_queue_device_interface_acquire_host_wait_event,
.device_wait_on_host_event =
iree_hal_cuda_deferred_work_queue_device_interface_device_wait_on_host_event,
.release_wait_event =
iree_hal_cuda_deferred_work_queue_device_interface_release_wait_event,
.native_event_from_wait_event =
iree_hal_cuda_deferred_work_queue_device_interface_native_event_from_wait_event,
.create_stream_command_buffer =
iree_hal_cuda_deferred_work_queue_device_interface_create_stream_command_buffer,
.submit_command_buffer =
iree_hal_cuda_deferred_work_queue_device_interface_submit_command_buffer,
};
Loading
Loading