diff --git a/runtime/src/iree/hal/drivers/cuda/BUILD.bazel b/runtime/src/iree/hal/drivers/cuda/BUILD.bazel
index 89fbe0ae088e..142d0e18bd9e 100644
--- a/runtime/src/iree/hal/drivers/cuda/BUILD.bazel
+++ b/runtime/src/iree/hal/drivers/cuda/BUILD.bazel
@@ -37,8 +37,6 @@ iree_runtime_cc_library(
         "nccl_channel.h",
         "nop_executable_cache.c",
         "nop_executable_cache.h",
-        "pending_queue_actions.c",
-        "pending_queue_actions.h",
         "pipeline_layout.c",
         "pipeline_layout.h",
         "stream_command_buffer.c",
@@ -66,6 +64,7 @@ iree_runtime_cc_library(
         "//runtime/src/iree/hal",
         "//runtime/src/iree/hal/utils:collective_batch",
         "//runtime/src/iree/hal/utils:deferred_command_buffer",
+        "//runtime/src/iree/hal/utils:deferred_work_queue",
         "//runtime/src/iree/hal/utils:file_transfer",
         "//runtime/src/iree/hal/utils:memory_file",
         "//runtime/src/iree/hal/utils:resource_set",
diff --git a/runtime/src/iree/hal/drivers/cuda/CMakeLists.txt b/runtime/src/iree/hal/drivers/cuda/CMakeLists.txt
index ebf5386ad7f0..f7c9afd564ab 100644
--- a/runtime/src/iree/hal/drivers/cuda/CMakeLists.txt
+++ b/runtime/src/iree/hal/drivers/cuda/CMakeLists.txt
@@ -38,8 +38,6 @@ iree_cc_library(
     "nccl_channel.h"
     "nop_executable_cache.c"
     "nop_executable_cache.h"
-    "pending_queue_actions.c"
-    "pending_queue_actions.h"
     "pipeline_layout.c"
    "pipeline_layout.h"
     "stream_command_buffer.c"
@@ -63,6 +61,7 @@ iree_cc_library(
     iree::hal
     iree::hal::utils::collective_batch
     iree::hal::utils::deferred_command_buffer
+    iree::hal::utils::deferred_work_queue
     iree::hal::utils::file_transfer
     iree::hal::utils::memory_file
     iree::hal::utils::resource_set
diff --git a/runtime/src/iree/hal/drivers/cuda/cuda_device.c b/runtime/src/iree/hal/drivers/cuda/cuda_device.c
index 23363e1e2a7e..a53f3818e610 100644
--- a/runtime/src/iree/hal/drivers/cuda/cuda_device.c
+++ b/runtime/src/iree/hal/drivers/cuda/cuda_device.c
@@ -23,12 +23,12 @@
 #include "iree/hal/drivers/cuda/nccl_channel.h"
 #include "iree/hal/drivers/cuda/nccl_dynamic_symbols.h"
 #include "iree/hal/drivers/cuda/nop_executable_cache.h"
-#include "iree/hal/drivers/cuda/pending_queue_actions.h"
 #include "iree/hal/drivers/cuda/pipeline_layout.h"
 #include "iree/hal/drivers/cuda/stream_command_buffer.h"
 #include "iree/hal/drivers/cuda/timepoint_pool.h"
 #include "iree/hal/drivers/cuda/tracing.h"
 #include "iree/hal/utils/deferred_command_buffer.h"
+#include "iree/hal/utils/deferred_work_queue.h"
 #include "iree/hal/utils/file_transfer.h"
 #include "iree/hal/utils/memory_file.h"
@@ -76,7 +76,7 @@ typedef struct iree_hal_cuda_device_t {
   // are met. It buffers submissions and allocations internally before they
   // are ready. This queue couples with HAL semaphores backed by iree_event_t
   // and CUevent objects.
-  iree_hal_cuda_pending_queue_actions_t* pending_queue_actions;
+  iree_hal_deferred_work_queue_t* work_queue;
 
   // Device memory pools and allocators.
   bool supports_memory_pools;
@@ -88,6 +88,176 @@ typedef struct iree_hal_cuda_device_t {
 } iree_hal_cuda_device_t;
 
 static const iree_hal_device_vtable_t iree_hal_cuda_device_vtable;
+static const iree_hal_deferred_work_queue_device_interface_vtable_t
+    iree_hal_cuda_deferred_work_queue_device_interface_vtable;
+
+// We put a CUevent into an iree_hal_deferred_work_queue_native_event_t.
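+// CUevent is itself a pointer to an opaque driver struct (CUevent_st*), so it
+// round-trips losslessly through the pointer-sized opaque handle, e.g.:
+//   iree_hal_deferred_work_queue_native_event_t opaque =
+//       (iree_hal_deferred_work_queue_native_event_t)event;
+//   CUevent restored = (CUevent)opaque;
+// The static_assert below checks this size assumption at compile time.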
+static_assert(sizeof(CUevent) <=
+                  sizeof(iree_hal_deferred_work_queue_native_event_t),
+              "Unexpected event size");
+typedef struct iree_hal_cuda_deferred_work_queue_device_interface_t {
+  iree_hal_deferred_work_queue_device_interface_t base;
+  iree_hal_device_t* device;
+  CUdevice cu_device;
+  CUcontext cu_context;
+  CUstream dispatch_cu_stream;
+  iree_allocator_t host_allocator;
+  const iree_hal_cuda_dynamic_symbols_t* cuda_symbols;
+} iree_hal_cuda_deferred_work_queue_device_interface_t;
+
+static void iree_hal_cuda_deferred_work_queue_device_interface_destroy(
+    iree_hal_deferred_work_queue_device_interface_t* base_device_interface) {
+  iree_hal_cuda_deferred_work_queue_device_interface_t* device_interface =
+      (iree_hal_cuda_deferred_work_queue_device_interface_t*)(base_device_interface);
+  iree_allocator_free(device_interface->host_allocator, device_interface);
+}
+
+static iree_status_t
+iree_hal_cuda_deferred_work_queue_device_interface_bind_to_thread(
+    iree_hal_deferred_work_queue_device_interface_t* base_device_interface) {
+  iree_hal_cuda_deferred_work_queue_device_interface_t* device_interface =
+      (iree_hal_cuda_deferred_work_queue_device_interface_t*)(base_device_interface);
+  return IREE_CURESULT_TO_STATUS(device_interface->cuda_symbols,
+                                 cuCtxSetCurrent(device_interface->cu_context),
+                                 "cuCtxSetCurrent");
+}
+
+static iree_status_t
+iree_hal_cuda_deferred_work_queue_device_interface_wait_native_event(
+    iree_hal_deferred_work_queue_device_interface_t* base_device_interface,
+    iree_hal_deferred_work_queue_native_event_t event) {
+  iree_hal_cuda_deferred_work_queue_device_interface_t* device_interface =
+      (iree_hal_cuda_deferred_work_queue_device_interface_t*)(base_device_interface);
+  return IREE_CURESULT_TO_STATUS(
+      device_interface->cuda_symbols,
+      cuStreamWaitEvent(device_interface->dispatch_cu_stream, (CUevent)event,
+                        CU_EVENT_WAIT_DEFAULT),
+      "cuStreamWaitEvent");
+}
+
+static iree_status_t
+iree_hal_cuda_deferred_work_queue_device_interface_create_native_event(
+    iree_hal_deferred_work_queue_device_interface_t* base_device_interface,
+    iree_hal_deferred_work_queue_native_event_t* out_event) {
+  iree_hal_cuda_deferred_work_queue_device_interface_t* device_interface =
+      (iree_hal_cuda_deferred_work_queue_device_interface_t*)(base_device_interface);
+  return IREE_CURESULT_TO_STATUS(
+      device_interface->cuda_symbols,
+      cuEventCreate((CUevent*)out_event, CU_EVENT_DEFAULT),
+      "cuEventCreate");
+}
+
+static iree_status_t
+iree_hal_cuda_deferred_work_queue_device_interface_record_native_event(
+    iree_hal_deferred_work_queue_device_interface_t* base_device_interface,
+    iree_hal_deferred_work_queue_native_event_t event) {
+  iree_hal_cuda_deferred_work_queue_device_interface_t* device_interface =
+      (iree_hal_cuda_deferred_work_queue_device_interface_t*)(base_device_interface);
+  return IREE_CURESULT_TO_STATUS(
+      device_interface->cuda_symbols,
+      cuEventRecord((CUevent)event, device_interface->dispatch_cu_stream),
+      "cuEventRecord");
+}
+
+static iree_status_t
+iree_hal_cuda_deferred_work_queue_device_interface_synchronize_native_event(
+    iree_hal_deferred_work_queue_device_interface_t* base_device_interface,
+    iree_hal_deferred_work_queue_native_event_t event) {
+  iree_hal_cuda_deferred_work_queue_device_interface_t* device_interface =
+      (iree_hal_cuda_deferred_work_queue_device_interface_t*)(base_device_interface);
+  return IREE_CURESULT_TO_STATUS(device_interface->cuda_symbols,
+                                 cuEventSynchronize((CUevent)event));
+}
+
+static iree_status_t
+iree_hal_cuda_deferred_work_queue_device_interface_destroy_native_event( + iree_hal_deferred_work_queue_device_interface_t* base_device_interface, + iree_hal_deferred_work_queue_native_event_t event) { + iree_hal_cuda_deferred_work_queue_device_interface_t* device_interface = + (iree_hal_cuda_deferred_work_queue_device_interface_t*)(base_device_interface); + return IREE_CURESULT_TO_STATUS(device_interface->cuda_symbols, + cuEventDestroy((CUevent)event)); +} + +static iree_status_t +iree_hal_cuda_deferred_work_queue_device_interface_semaphore_acquire_timepoint_device_signal_native_event( + iree_hal_deferred_work_queue_device_interface_t* base_device_interface, + struct iree_hal_semaphore_t* semaphore, uint64_t value, + iree_hal_deferred_work_queue_native_event_t* out_event) { + return iree_hal_cuda_event_semaphore_acquire_timepoint_device_signal( + semaphore, value, (CUevent*)out_event); +} + +static bool +iree_hal_cuda_deferred_work_queue_device_interface_acquire_host_wait_event( + iree_hal_deferred_work_queue_device_interface_t* base_device_interface, + struct iree_hal_semaphore_t* semaphore, uint64_t value, + iree_hal_deferred_work_queue_host_device_event_t* out_event) { + return iree_hal_cuda_semaphore_acquire_event_host_wait( + semaphore, value, (iree_hal_cuda_event_t**)out_event); +} + +static void +iree_hal_cuda_deferred_work_queue_device_interface_release_wait_event( + iree_hal_deferred_work_queue_device_interface_t* base_device_interface, + iree_hal_deferred_work_queue_host_device_event_t wait_event) { + iree_hal_cuda_event_release(wait_event); +} + +static iree_status_t +iree_hal_cuda_deferred_work_queue_device_interface_device_wait_on_host_event( + iree_hal_deferred_work_queue_device_interface_t* base_device_interface, + iree_hal_deferred_work_queue_host_device_event_t wait_event) { + iree_hal_cuda_deferred_work_queue_device_interface_t* device_interface = + (iree_hal_cuda_deferred_work_queue_device_interface_t*)(base_device_interface); + return IREE_CURESULT_TO_STATUS( + device_interface->cuda_symbols, + cuStreamWaitEvent( + device_interface->dispatch_cu_stream, + iree_hal_cuda_event_handle((iree_hal_cuda_event_t*)wait_event), 0), + "cuStreamWaitEvent"); +} + +static void* +iree_hal_cuda_deferred_work_queue_device_interface_native_event_from_wait_event( + iree_hal_deferred_work_queue_device_interface_t* base_device_interface, + iree_hal_deferred_work_queue_host_device_event_t event) { + return iree_hal_cuda_event_handle((iree_hal_cuda_event_t*)event); +} + +static iree_status_t +iree_hal_cuda_deferred_work_queue_device_interface_create_stream_command_buffer( + iree_hal_deferred_work_queue_device_interface_t* base_device_interface, + iree_hal_command_buffer_mode_t mode, iree_hal_command_category_t categories, + iree_hal_command_buffer_t** out) { + iree_hal_cuda_deferred_work_queue_device_interface_t* device_interface = + (iree_hal_cuda_deferred_work_queue_device_interface_t*)(base_device_interface); + return iree_hal_cuda_device_create_stream_command_buffer( + device_interface->device, mode, categories, 0, out); +} + +static iree_status_t +iree_hal_cuda_deferred_work_queue_device_interface_submit_command_buffer( + iree_hal_deferred_work_queue_device_interface_t* base_device_interface, + iree_hal_command_buffer_t* command_buffer) { + iree_hal_cuda_deferred_work_queue_device_interface_t* device_interface = + (iree_hal_cuda_deferred_work_queue_device_interface_t*)(base_device_interface); + iree_status_t status = iree_ok_status(); + if 
(iree_hal_cuda_stream_command_buffer_isa(command_buffer)) { + // Stream command buffer so nothing to do but notify it was submitted. + iree_hal_cuda_stream_notify_submitted_commands(command_buffer); + } else { + CUgraphExec exec = + iree_hal_cuda_graph_command_buffer_handle(command_buffer); + status = IREE_CURESULT_TO_STATUS( + device_interface->cuda_symbols, + cuGraphLaunch(exec, device_interface->dispatch_cu_stream)); + if (IREE_LIKELY(iree_status_is_ok(status))) { + iree_hal_cuda_graph_tracing_notify_submitted_commands(command_buffer); + } + } + return status; +} static iree_hal_cuda_device_t* iree_hal_cuda_device_cast( iree_hal_device_t* base_value) { @@ -152,9 +322,27 @@ static iree_status_t iree_hal_cuda_device_create_internal( device->dispatch_cu_stream = dispatch_stream; device->host_allocator = host_allocator; - iree_status_t status = iree_hal_cuda_pending_queue_actions_create( - cuda_symbols, cu_device, context, &device->block_pool, host_allocator, - &device->pending_queue_actions); + iree_hal_cuda_deferred_work_queue_device_interface_t* device_interface; + iree_status_t status = iree_allocator_malloc( + host_allocator, + sizeof(iree_hal_cuda_deferred_work_queue_device_interface_t), + (void**)&device_interface); + if (IREE_UNLIKELY(!iree_status_is_ok(status))) { + iree_hal_device_release((iree_hal_device_t*)device); + return status; + } + device_interface->base.vtable = + &iree_hal_cuda_deferred_work_queue_device_interface_vtable; + device_interface->cu_context = context; + device_interface->cuda_symbols = cuda_symbols; + device_interface->cu_device = cu_device; + device_interface->device = (iree_hal_device_t*)device; + device_interface->dispatch_cu_stream = dispatch_stream; + device_interface->host_allocator = host_allocator; + + status = iree_hal_deferred_work_queue_create( + (iree_hal_deferred_work_queue_device_interface_t*)device_interface, + &device->block_pool, host_allocator, &device->work_queue); // Enable tracing for the (currently only) stream - no-op if disabled. if (iree_status_is_ok(status) && device->params.stream_tracing) { @@ -297,8 +485,7 @@ static void iree_hal_cuda_device_destroy(iree_hal_device_t* base_device) { IREE_TRACE_ZONE_BEGIN(z0); // Destroy the pending workload queue. - iree_hal_cuda_pending_queue_actions_destroy( - (iree_hal_resource_t*)device->pending_queue_actions); + iree_hal_deferred_work_queue_destroy(device->work_queue); // There should be no more buffers live that use the allocator. 
 iree_hal_allocator_release(device->device_allocator);
@@ -620,7 +807,7 @@ static iree_status_t iree_hal_cuda_device_create_semaphore(
   iree_hal_cuda_device_t* device = iree_hal_cuda_device_cast(base_device);
   return iree_hal_cuda_event_semaphore_create(
       initial_value, device->cuda_symbols, device->timepoint_pool,
-      device->pending_queue_actions, device->host_allocator, out_semaphore);
+      device->work_queue, device->host_allocator, out_semaphore);
 }
 
 static iree_hal_semaphore_compatibility_t
@@ -765,15 +952,13 @@ static iree_status_t iree_hal_cuda_device_queue_execute(
   iree_hal_cuda_device_t* device = iree_hal_cuda_device_cast(base_device);
   IREE_TRACE_ZONE_BEGIN(z0);
 
-  iree_status_t status = iree_hal_cuda_pending_queue_actions_enqueue_execution(
-      base_device, device->dispatch_cu_stream, device->pending_queue_actions,
-      iree_hal_cuda_device_collect_tracing_context, device->tracing_context,
-      wait_semaphore_list, signal_semaphore_list, command_buffer_count,
-      command_buffers, binding_tables);
+  iree_status_t status = iree_hal_deferred_work_queue_enqueue(
+      device->work_queue, iree_hal_cuda_device_collect_tracing_context,
+      device->tracing_context, wait_semaphore_list, signal_semaphore_list,
+      command_buffer_count, command_buffers, binding_tables);
   if (iree_status_is_ok(status)) {
-    // Try to advance the pending workload queue.
-    status = iree_hal_cuda_pending_queue_actions_issue(
-        device->pending_queue_actions);
+    // Try to advance the deferred work queue.
+    status = iree_hal_deferred_work_queue_issue(device->work_queue);
   }
 
   IREE_TRACE_ZONE_END(z0);
@@ -784,9 +969,8 @@ static iree_status_t iree_hal_cuda_device_queue_flush(
     iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity) {
   iree_hal_cuda_device_t* device = iree_hal_cuda_device_cast(base_device);
   IREE_TRACE_ZONE_BEGIN(z0);
-  // Try to advance the pending workload queue.
-  iree_status_t status =
-      iree_hal_cuda_pending_queue_actions_issue(device->pending_queue_actions);
+  // Try to advance the deferred work queue.
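+  // Note that issuing is non-blocking: it only hands actions whose wait
+  // semaphores are already satisfied over to the worker thread, so flushing
+  // never stalls the caller on GPU progress.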
+ iree_status_t status = iree_hal_deferred_work_queue_issue(device->work_queue); IREE_TRACE_ZONE_END(z0); return status; } @@ -850,3 +1034,34 @@ static const iree_hal_device_vtable_t iree_hal_cuda_device_vtable = { .profiling_flush = iree_hal_cuda_device_profiling_flush, .profiling_end = iree_hal_cuda_device_profiling_end, }; + +static const iree_hal_deferred_work_queue_device_interface_vtable_t + iree_hal_cuda_deferred_work_queue_device_interface_vtable = { + .destroy = iree_hal_cuda_deferred_work_queue_device_interface_destroy, + .bind_to_thread = + iree_hal_cuda_deferred_work_queue_device_interface_bind_to_thread, + .wait_native_event = + iree_hal_cuda_deferred_work_queue_device_interface_wait_native_event, + .create_native_event = + iree_hal_cuda_deferred_work_queue_device_interface_create_native_event, + .record_native_event = + iree_hal_cuda_deferred_work_queue_device_interface_record_native_event, + .synchronize_native_event = + iree_hal_cuda_deferred_work_queue_device_interface_synchronize_native_event, + .destroy_native_event = + iree_hal_cuda_deferred_work_queue_device_interface_destroy_native_event, + .semaphore_acquire_timepoint_device_signal_native_event = + iree_hal_cuda_deferred_work_queue_device_interface_semaphore_acquire_timepoint_device_signal_native_event, + .acquire_host_wait_event = + iree_hal_cuda_deferred_work_queue_device_interface_acquire_host_wait_event, + .device_wait_on_host_event = + iree_hal_cuda_deferred_work_queue_device_interface_device_wait_on_host_event, + .release_wait_event = + iree_hal_cuda_deferred_work_queue_device_interface_release_wait_event, + .native_event_from_wait_event = + iree_hal_cuda_deferred_work_queue_device_interface_native_event_from_wait_event, + .create_stream_command_buffer = + iree_hal_cuda_deferred_work_queue_device_interface_create_stream_command_buffer, + .submit_command_buffer = + iree_hal_cuda_deferred_work_queue_device_interface_submit_command_buffer, +}; diff --git a/runtime/src/iree/hal/drivers/cuda/event_semaphore.c b/runtime/src/iree/hal/drivers/cuda/event_semaphore.c index a2bc05f7725c..fb86efe7e815 100644 --- a/runtime/src/iree/hal/drivers/cuda/event_semaphore.c +++ b/runtime/src/iree/hal/drivers/cuda/event_semaphore.c @@ -11,6 +11,7 @@ #include "iree/hal/drivers/cuda/cuda_dynamic_symbols.h" #include "iree/hal/drivers/cuda/cuda_status_util.h" #include "iree/hal/drivers/cuda/timepoint_pool.h" +#include "iree/hal/utils/deferred_work_queue.h" #include "iree/hal/utils/semaphore_base.h" typedef struct iree_hal_cuda_semaphore_t { @@ -28,7 +29,7 @@ typedef struct iree_hal_cuda_semaphore_t { // The list of pending queue actions that this semaphore need to advance on // new signaled values. - iree_hal_cuda_pending_queue_actions_t* pending_queue_actions; + iree_hal_deferred_work_queue_t* work_queue; // Guards value and status. 
We expect low contention on semaphores and since
  // iree_slim_mutex_t is (effectively) just a CAS this keeps things simpler
@@ -57,11 +58,11 @@ static iree_hal_cuda_semaphore_t* iree_hal_cuda_semaphore_cast(
 iree_status_t iree_hal_cuda_event_semaphore_create(
     uint64_t initial_value, const iree_hal_cuda_dynamic_symbols_t* symbols,
     iree_hal_cuda_timepoint_pool_t* timepoint_pool,
-    iree_hal_cuda_pending_queue_actions_t* pending_queue_actions,
-    iree_allocator_t host_allocator, iree_hal_semaphore_t** out_semaphore) {
+    iree_hal_deferred_work_queue_t* work_queue, iree_allocator_t host_allocator,
+    iree_hal_semaphore_t** out_semaphore) {
   IREE_ASSERT_ARGUMENT(symbols);
   IREE_ASSERT_ARGUMENT(timepoint_pool);
-  IREE_ASSERT_ARGUMENT(pending_queue_actions);
+  IREE_ASSERT_ARGUMENT(work_queue);
   IREE_ASSERT_ARGUMENT(out_semaphore);
   IREE_TRACE_ZONE_BEGIN(z0);
@@ -75,7 +76,7 @@ iree_status_t iree_hal_cuda_event_semaphore_create(
   semaphore->host_allocator = host_allocator;
   semaphore->symbols = symbols;
   semaphore->timepoint_pool = timepoint_pool;
-  semaphore->pending_queue_actions = pending_queue_actions;
+  semaphore->work_queue = work_queue;
   iree_slim_mutex_initialize(&semaphore->mutex);
   semaphore->current_value = initial_value;
   semaphore->failure_status = iree_ok_status();
@@ -149,10 +150,10 @@ static iree_status_t iree_hal_cuda_semaphore_signal(
   // Notify timepoints - note that this must happen outside the lock.
   iree_hal_semaphore_notify(&semaphore->base, new_value, IREE_STATUS_OK);
 
-  // Advance the pending queue actions if possible. This also must happen
+  // Advance the deferred work queue if possible. This also must happen
   // outside the lock to avoid nesting.
-  iree_status_t status = iree_hal_cuda_pending_queue_actions_issue(
-      semaphore->pending_queue_actions);
+  iree_status_t status =
+      iree_hal_deferred_work_queue_issue(semaphore->work_queue);
 
   IREE_TRACE_ZONE_END(z0);
   return status;
@@ -188,10 +189,9 @@ static void iree_hal_cuda_semaphore_fail(iree_hal_semaphore_t* base_semaphore,
   iree_hal_semaphore_notify(&semaphore->base, IREE_HAL_SEMAPHORE_FAILURE_VALUE,
                             status_code);
 
-  // Advance the pending queue actions if possible. This also must happen
+  // Advance the deferred work queue if possible. This also must happen
   // outside the lock to avoid nesting.
-  status = iree_hal_cuda_pending_queue_actions_issue(
-      semaphore->pending_queue_actions);
+  status = iree_hal_deferred_work_queue_issue(semaphore->work_queue);
   iree_status_ignore(status);
 
   IREE_TRACE_ZONE_END(z0);
diff --git a/runtime/src/iree/hal/drivers/cuda/event_semaphore.h b/runtime/src/iree/hal/drivers/cuda/event_semaphore.h
index 74a7ffcf1f9d..e67d55fcfe54 100644
--- a/runtime/src/iree/hal/drivers/cuda/event_semaphore.h
+++ b/runtime/src/iree/hal/drivers/cuda/event_semaphore.h
@@ -12,8 +12,8 @@
 #include "iree/base/api.h"
 #include "iree/hal/api.h"
 #include "iree/hal/drivers/cuda/cuda_dynamic_symbols.h"
-#include "iree/hal/drivers/cuda/pending_queue_actions.h"
 #include "iree/hal/drivers/cuda/timepoint_pool.h"
+#include "iree/hal/utils/deferred_work_queue.h"
 
 #ifdef __cplusplus
 extern "C" {
@@ -26,14 +26,14 @@ extern "C" {
 // allocated from the |timepoint_pool|.
 //
-// This semaphore is meant to be used together with a pending queue actions; it
-// may advance the given |pending_queue_actions| if new values are signaled.
+// This semaphore is meant to be used together with a deferred work queue; it
+// may advance the given |work_queue| if new values are signaled.
 //
 // Thread-safe; multiple threads may signal/wait values on the same semaphore.
 iree_status_t iree_hal_cuda_event_semaphore_create(
     uint64_t initial_value, const iree_hal_cuda_dynamic_symbols_t* symbols,
     iree_hal_cuda_timepoint_pool_t* timepoint_pool,
-    iree_hal_cuda_pending_queue_actions_t* pending_queue_actions,
-    iree_allocator_t host_allocator, iree_hal_semaphore_t** out_semaphore);
+    iree_hal_deferred_work_queue_t* work_queue, iree_allocator_t host_allocator,
+    iree_hal_semaphore_t** out_semaphore);
 
 // Acquires a timepoint to signal the timeline to the given |to_value| from the
 // device. The underlying CUDA event is written into |out_event| for interacting
diff --git a/runtime/src/iree/hal/drivers/cuda/pending_queue_actions.c b/runtime/src/iree/hal/drivers/cuda/pending_queue_actions.c
deleted file mode 100644
index 17af9684c179..000000000000
--- a/runtime/src/iree/hal/drivers/cuda/pending_queue_actions.c
+++ /dev/null
@@ -1,1494 +0,0 @@
-// Copyright 2023 The IREE Authors
-//
-// Licensed under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-
-#include "iree/hal/drivers/cuda/pending_queue_actions.h"
-
-#include <stdbool.h>
-#include <string.h>
-
-#include "iree/base/api.h"
-#include "iree/base/internal/arena.h"
-#include "iree/base/internal/atomic_slist.h"
-#include "iree/base/internal/atomics.h"
-#include "iree/base/internal/synchronization.h"
-#include "iree/base/internal/threading.h"
-#include "iree/hal/api.h"
-#include "iree/hal/drivers/cuda/cuda_device.h"
-#include "iree/hal/drivers/cuda/cuda_dynamic_symbols.h"
-#include "iree/hal/drivers/cuda/cuda_status_util.h"
-#include "iree/hal/drivers/cuda/event_pool.h"
-#include "iree/hal/drivers/cuda/event_semaphore.h"
-#include "iree/hal/drivers/cuda/graph_command_buffer.h"
-#include "iree/hal/drivers/cuda/stream_command_buffer.h"
-#include "iree/hal/utils/deferred_command_buffer.h"
-#include "iree/hal/utils/resource_set.h"
-
-// The maximum number of CUevent objects a command buffer can wait on.
-#define IREE_HAL_CUDA_MAX_WAIT_EVENT_COUNT 32
-
-//===----------------------------------------------------------------------===//
-// Queue action
-//===----------------------------------------------------------------------===//
-
-typedef enum iree_hal_cuda_queue_action_kind_e {
-  IREE_HAL_CUDA_QUEUE_ACTION_TYPE_EXECUTION,
-  // TODO: Add support for queue alloca and dealloca.
-} iree_hal_cuda_queue_action_kind_t;
-
-typedef enum iree_hal_cuda_queue_action_state_e {
-  // The current action is alive: waiting for execution or being executed.
-  IREE_HAL_CUDA_QUEUE_ACTION_STATE_ALIVE,
-  // The current action has finished execution and is waiting for destruction.
-  IREE_HAL_CUDA_QUEUE_ACTION_STATE_ZOMBIE,
-} iree_hal_cuda_queue_action_state_t;
-
-// How many work items must complete in order for an action to complete.
-// We keep track of the remaining work for an action so we don't exit worker
-// threads prematurely.
-// +1 for issuing an execution of an action.
-// +1 for cleaning up a zombie action.
-static const iree_host_size_t total_work_items_to_complete_an_action = 2;
-
-// A pending queue action.
-//
-// Note that this struct does not have internal synchronization; it's expected
-// to work together with the pending action queue, which synchronizes accesses.
-typedef struct iree_hal_cuda_queue_action_t {
-  // Intrusive doubly-linked list next entry pointer.
-  struct iree_hal_cuda_queue_action_t* next;
-  // Intrusive doubly-linked list previous entry pointer.
-  struct iree_hal_cuda_queue_action_t* prev;
-
-  // The owning pending actions queue. We use its allocators and pools.
-  // Retained to make sure it outlives the current action.
-  iree_hal_cuda_pending_queue_actions_t* owning_actions;
-
-  // The current state of this action. When an action is initially created it
-  // will be alive and enqueued, waiting to be released to the GPU. After it
-  // finishes execution, it is flipped into the zombie state and enqueued again
-  // for destruction.
-  iree_hal_cuda_queue_action_state_t state;
-  // The callback to run after completing this action and before freeing
-  // all resources. Can be NULL.
-  iree_hal_cuda_pending_action_cleanup_callback_t cleanup_callback;
-  // User data to pass into the callback.
-  void* callback_user_data;
-
-  iree_hal_cuda_queue_action_kind_t kind;
-  union {
-    struct {
-      iree_host_size_t count;
-      iree_hal_command_buffer_t** command_buffers;
-      iree_hal_buffer_binding_table_t* binding_tables;
-    } execution;
-  } payload;
-
-  // The device from which to allocate CUDA stream-based command buffers for
-  // applying deferred command buffers.
-  iree_hal_device_t* device;
-
-  // The stream on which to launch the main GPU workload.
-  CUstream dispatch_cu_stream;
-
-  // Resource set to retain all resources associated with the payload.
-  iree_hal_resource_set_t* resource_set;
-
-  // Semaphore list to wait on for the payload to start on the GPU.
-  iree_hal_semaphore_list_t wait_semaphore_list;
-  // Semaphore list to signal after the payload completes on the GPU.
-  iree_hal_semaphore_list_t signal_semaphore_list;
-
-  // Scratch fields for analyzing whether actions are ready to issue.
-  iree_hal_cuda_event_t* events[IREE_HAL_CUDA_MAX_WAIT_EVENT_COUNT];
-  iree_host_size_t event_count;
-  // Whether the current action is not yet ready to be released to the GPU.
-  bool is_pending;
-} iree_hal_cuda_queue_action_t;
-
-static void iree_hal_cuda_queue_action_fail_locked(
-    iree_hal_cuda_queue_action_t* action, iree_status_t status);
-
-static void iree_hal_cuda_queue_action_clear_events(
-    iree_hal_cuda_queue_action_t* action) {
-  for (iree_host_size_t i = 0; i < action->event_count; ++i) {
-    iree_hal_cuda_event_release(action->events[i]);
-  }
-  action->event_count = 0;
-}
-
-static void iree_hal_cuda_queue_action_destroy(
-    iree_hal_cuda_queue_action_t* action);
-
-//===----------------------------------------------------------------------===//
-// Queue action list
-//===----------------------------------------------------------------------===//
-
-typedef struct iree_hal_cuda_queue_action_list_t {
-  iree_hal_cuda_queue_action_t* head;
-  iree_hal_cuda_queue_action_t* tail;
-} iree_hal_cuda_queue_action_list_t;
-
-// Returns true if the action list is empty.
-static inline bool iree_hal_cuda_queue_action_list_is_empty(
-    const iree_hal_cuda_queue_action_list_t* list) {
-  return list->head == NULL;
-}
-
-static iree_hal_cuda_queue_action_t* iree_hal_cuda_queue_action_list_pop_front(
-    iree_hal_cuda_queue_action_list_t* list) {
-  IREE_ASSERT(list->head && list->tail);
-
-  iree_hal_cuda_queue_action_t* action = list->head;
-  IREE_ASSERT(!action->prev);
-  list->head = action->next;
-  if (action->next) {
-    action->next->prev = NULL;
-    action->next = NULL;
-  }
-  if (list->tail == action) {
-    list->tail = NULL;
-  }
-
-  return action;
-}
-
-// Pushes |action| onto the end of the given action |list|.
-static void iree_hal_cuda_queue_action_list_push_back(
-    iree_hal_cuda_queue_action_list_t* list,
-    iree_hal_cuda_queue_action_t* action) {
-  IREE_ASSERT(!action->next && !action->prev);
-  if (list->tail) {
-    list->tail->next = action;
-  } else {
-    list->head = action;
-  }
-  action->prev = list->tail;
-  list->tail = action;
-}
-
-// Takes all actions from |available_list| and moves them into |ready_list|.
-static void iree_hal_cuda_queue_action_list_take_all(
-    iree_hal_cuda_queue_action_list_t* available_list,
-    iree_hal_cuda_queue_action_list_t* ready_list) {
-  IREE_ASSERT_NE(available_list, ready_list);
-  ready_list->head = available_list->head;
-  ready_list->tail = available_list->tail;
-  available_list->head = NULL;
-  available_list->tail = NULL;
-}
-
-static void iree_hal_cuda_queue_action_list_destroy(
-    iree_hal_cuda_queue_action_t* head_action) {
-  while (head_action) {
-    iree_hal_cuda_queue_action_t* next_action = head_action->next;
-    iree_hal_cuda_queue_action_destroy(head_action);
-    head_action = next_action;
-  }
-}
-
-//===----------------------------------------------------------------------===//
-// Ready-list processing
-//===----------------------------------------------------------------------===//
-
-// Ready action atomic slist entry struct.
-typedef struct iree_hal_cuda_entry_list_node_t {
-  iree_hal_cuda_queue_action_t* ready_list_head;
-  struct iree_hal_cuda_entry_list_node_t* next;
-} iree_hal_cuda_entry_list_node_t;
-
-typedef struct iree_hal_cuda_entry_list_t {
-  iree_slim_mutex_t guard_mutex;
-
-  iree_hal_cuda_entry_list_node_t* head IREE_GUARDED_BY(guard_mutex);
-  iree_hal_cuda_entry_list_node_t* tail IREE_GUARDED_BY(guard_mutex);
-} iree_hal_cuda_entry_list_t;
-
-static iree_hal_cuda_entry_list_node_t* iree_hal_cuda_entry_list_pop(
-    iree_hal_cuda_entry_list_t* list) {
-  iree_hal_cuda_entry_list_node_t* out = NULL;
-  iree_slim_mutex_lock(&list->guard_mutex);
-  if (list->head) {
-    out = list->head;
-    list->head = list->head->next;
-    if (out == list->tail) {
-      list->tail = NULL;
-    }
-  }
-  iree_slim_mutex_unlock(&list->guard_mutex);
-  return out;
-}
-
-void iree_hal_cuda_entry_list_push(iree_hal_cuda_entry_list_t* list,
-                                   iree_hal_cuda_entry_list_node_t* next) {
-  iree_slim_mutex_lock(&list->guard_mutex);
-  next->next = NULL;
-  if (list->tail) {
-    list->tail->next = next;
-    list->tail = next;
-  } else {
-    list->head = next;
-    list->tail = next;
-  }
-  iree_slim_mutex_unlock(&list->guard_mutex);
-}
-
-static void iree_hal_cuda_ready_action_list_deinitialize(
-    iree_hal_cuda_entry_list_t* list, iree_allocator_t host_allocator) {
-  while (list->head) {
-    iree_hal_cuda_entry_list_node_t* head = list->head;
-    iree_hal_cuda_queue_action_list_destroy(head->ready_list_head);
-    list->head = head->next;
-    iree_allocator_free(host_allocator, head);
-  }
-  iree_slim_mutex_deinitialize(&list->guard_mutex);
-}
-
-static void iree_hal_cuda_ready_action_list_initialize(
-    iree_hal_cuda_entry_list_t* list) {
-  list->head = NULL;
-  list->tail = NULL;
-  iree_slim_mutex_initialize(&list->guard_mutex);
-}
-
-// Completion list entry struct.
-typedef struct iree_hal_cuda_completion_list_node_t {
-  // The callback and user data for that callback. To be called
-  // when the associated event has completed.
-  iree_status_t (*callback)(iree_status_t, void* user_data);
-  void* user_data;
-  // The event to wait for on the completion thread.
-  CUevent event;
-  // If this event was created just for the completion thread, and therefore
-  // needs to be cleaned up.
-  bool created_event;
-  struct iree_hal_cuda_completion_list_node_t* next;
-} iree_hal_cuda_completion_list_node_t;
-
-typedef struct iree_hal_cuda_completion_list_t {
-  iree_slim_mutex_t guard_mutex;
-  iree_hal_cuda_completion_list_node_t* head IREE_GUARDED_BY(guard_mutex);
-  iree_hal_cuda_completion_list_node_t* tail IREE_GUARDED_BY(guard_mutex);
-} iree_hal_cuda_completion_list_t;
-
-static iree_hal_cuda_completion_list_node_t* iree_hal_cuda_completion_list_pop(
-    iree_hal_cuda_completion_list_t* list) {
-  iree_hal_cuda_completion_list_node_t* out = NULL;
-  iree_slim_mutex_lock(&list->guard_mutex);
-  if (list->head) {
-    out = list->head;
-    list->head = list->head->next;
-    if (out == list->tail) {
-      list->tail = NULL;
-    }
-  }
-  iree_slim_mutex_unlock(&list->guard_mutex);
-  return out;
-}
-
-void iree_hal_cuda_completion_list_push(
-    iree_hal_cuda_completion_list_t* list,
-    iree_hal_cuda_completion_list_node_t* next) {
-  iree_slim_mutex_lock(&list->guard_mutex);
-  next->next = NULL;
-  if (list->tail) {
-    list->tail->next = next;
-    list->tail = next;
-  } else {
-    list->head = next;
-    list->tail = next;
-  }
-  iree_slim_mutex_unlock(&list->guard_mutex);
-}
-
-static void iree_hal_cuda_completion_list_initialize(
-    iree_hal_cuda_completion_list_t* list) {
-  list->head = NULL;
-  list->tail = NULL;
-  iree_slim_mutex_initialize(&list->guard_mutex);
-}
-
-static void iree_hal_cuda_completion_list_deinitialize(
-    iree_hal_cuda_completion_list_t* list,
-    const iree_hal_cuda_dynamic_symbols_t* symbols,
-    iree_allocator_t host_allocator) {
-  while (list->head) {
-    iree_hal_cuda_completion_list_node_t* head = list->head;
-    if (head->created_event) {
-      IREE_CUDA_IGNORE_ERROR(symbols, cuEventDestroy(head->event));
-    }
-    list->head = head->next;
-    iree_allocator_free(host_allocator, head);
-  }
-  iree_slim_mutex_deinitialize(&list->guard_mutex);
-}
-
-static iree_hal_cuda_queue_action_t* iree_hal_cuda_entry_list_node_pop_front(
-    iree_hal_cuda_entry_list_node_t* list) {
-  IREE_ASSERT(list->ready_list_head);
-
-  iree_hal_cuda_queue_action_t* action = list->ready_list_head;
-  IREE_ASSERT(!action->prev);
-  list->ready_list_head = action->next;
-  if (action->next) {
-    action->next->prev = NULL;
-    action->next = NULL;
-  }
-
-  return action;
-}
-
-static void iree_hal_cuda_entry_list_node_push_front(
-    iree_hal_cuda_entry_list_node_t* entry,
-    iree_hal_cuda_queue_action_t* action) {
-  IREE_ASSERT(!action->next && !action->prev);
-
-  iree_hal_cuda_queue_action_t* head = entry->ready_list_head;
-  entry->ready_list_head = action;
-  if (head) {
-    action->next = head;
-    head->prev = action;
-  }
-}
-
-// The ready-list processing worker's working/exiting state.
-//
-// States in the list have increasing priority--meaning ones appearing earlier
-// can normally overwrite ones appearing later without checking, but not the
-// reverse.
-typedef enum iree_hal_cuda_worker_state_e {
-  IREE_HAL_CUDA_WORKER_STATE_IDLE_WAITING = 0,      // Worker to any thread
-  IREE_HAL_CUDA_WORKER_STATE_WORKLOAD_PENDING = 1,  // Any to worker thread
-} iree_hal_cuda_worker_state_t;
-
-// The data structure needed by a ready-list processing worker thread to issue
-// ready actions to the GPU.
-//
-// This data structure is shared between the parent thread, which owns the
-// whole pending actions queue, and the worker thread; so proper
-// synchronization is needed to touch it from both sides.
-//
-// The parent thread should push a list of ready actions to ready_worklist,
-// update worker_state, and post state_notification accordingly.
-// The worker thread waits on the state_notification and checks worker_state,
-// and pops from the ready_worklist to process. The worker thread also monitors
-// worker_state and stops processing if requested by the parent thread.
-typedef struct iree_hal_cuda_working_area_t {
-  // Notification from the parent thread to request worker state changes.
-  iree_notification_t state_notification;
-  iree_hal_cuda_entry_list_t ready_worklist;  // atomic
-  iree_atomic_int32_t worker_state;           // atomic
-} iree_hal_cuda_working_area_t;
-
-// This data structure is shared with the parent thread. It is responsible
-// for dispatching callbacks when work items complete.
-
-// This replaces the use of cuLaunchHostFunc, which causes the stream to block
-// and wait for the CPU work to complete. It also picks up completed
-// events with significantly less latency than cuLaunchHostFunc.
-typedef struct iree_hal_cuda_completion_area_t {
-  // Notification from the parent thread to request completion state changes.
-  iree_notification_t state_notification;
-  iree_hal_cuda_completion_list_t completion_list;  // atomic
-  iree_atomic_int32_t worker_state;                 // atomic
-} iree_hal_cuda_completion_area_t;
-
-static void iree_hal_cuda_working_area_initialize(
-    iree_allocator_t host_allocator, CUdevice device,
-    const iree_hal_cuda_dynamic_symbols_t* symbols,
-    iree_hal_cuda_working_area_t* working_area) {
-  iree_notification_initialize(&working_area->state_notification);
-  iree_hal_cuda_ready_action_list_initialize(&working_area->ready_worklist);
-  iree_atomic_store_int32(&working_area->worker_state,
-                          IREE_HAL_CUDA_WORKER_STATE_IDLE_WAITING,
-                          iree_memory_order_release);
-}
-
-static void iree_hal_cuda_working_area_deinitialize(
-    iree_hal_cuda_working_area_t* working_area,
-    iree_allocator_t host_allocator) {
-  iree_hal_cuda_ready_action_list_deinitialize(&working_area->ready_worklist,
-                                               host_allocator);
-  iree_notification_deinitialize(&working_area->state_notification);
-}
-
-static void iree_hal_cuda_completion_area_initialize(
-    iree_allocator_t host_allocator, CUdevice device,
-    const iree_hal_cuda_dynamic_symbols_t* symbols,
-    iree_hal_cuda_completion_area_t* completion_area) {
-  iree_notification_initialize(&completion_area->state_notification);
-  iree_hal_cuda_completion_list_initialize(&completion_area->completion_list);
-  iree_atomic_store_int32(&completion_area->worker_state,
-                          IREE_HAL_CUDA_WORKER_STATE_IDLE_WAITING,
-                          iree_memory_order_release);
-}
-
-static void iree_hal_cuda_completion_area_deinitialize(
-    iree_hal_cuda_completion_area_t* completion_area,
-    const iree_hal_cuda_dynamic_symbols_t* symbols,
-    iree_allocator_t host_allocator) {
-  iree_hal_cuda_completion_list_deinitialize(&completion_area->completion_list,
-                                             symbols, host_allocator);
-  iree_notification_deinitialize(&completion_area->state_notification);
-}
-
-// The main function for the ready-list processing worker thread.
-static int iree_hal_cuda_worker_execute(
-    iree_hal_cuda_pending_queue_actions_t* actions);
-
-static int iree_hal_cuda_completion_execute(
-    iree_hal_cuda_pending_queue_actions_t* actions);
-
-//===----------------------------------------------------------------------===//
-// Pending queue actions
-//===----------------------------------------------------------------------===//
-
-struct iree_hal_cuda_pending_queue_actions_t {
-  // Abstract resource used for injecting reference counting and vtable;
-  // must be at offset 0.
- iree_hal_resource_t resource; - - // The allocator used to create the timepoint pool. - iree_allocator_t host_allocator; - // The block pool to allocate resource sets from. - iree_arena_block_pool_t* block_pool; - - // The symbols used to create and destroy CUevent objects. - const iree_hal_cuda_dynamic_symbols_t* symbols; - - // Non-recursive mutex guarding access. - iree_slim_mutex_t action_mutex; - - // The double-linked list of pending actions. - iree_hal_cuda_queue_action_list_t action_list IREE_GUARDED_BY(action_mutex); - - // The worker thread that monitors incoming requests and issues ready actions - // to the GPU. - iree_thread_t* worker_thread; - - // Worker thread to wait on completion events instead of running - // synchronous completion callbacks - iree_thread_t* completion_thread; - - // The worker's working area; data exchange place with the parent thread. - iree_hal_cuda_working_area_t working_area; - - // Completion thread's working area. - iree_hal_cuda_completion_area_t completion_area; - - // Atomic of type iree_status_t. It is a sticky error. - // Once set with an error, all subsequent actions that have not completed - // will fail with this error. - iree_status_t status IREE_GUARDED_BY(action_mutex); - - // The associated cuda device. - CUdevice device; - CUcontext cuda_context; - - // The number of asynchronous work items that are scheduled and not - // complete. - // These are - // * the number of actions issued. - // * the number of pending action cleanups. - // The work and completion threads can exit only when there are no more - // pending work items. - iree_host_size_t pending_work_items_count IREE_GUARDED_BY(action_mutex); - - // The owner can request an exit of the worker threads. - // Once all pending enqueued work is complete the threads will exit. - // No actions can be enqueued after requesting an exit. - bool exit_requested IREE_GUARDED_BY(action_mutex); -}; - -iree_status_t iree_hal_cuda_pending_queue_actions_create( - const iree_hal_cuda_dynamic_symbols_t* symbols, CUdevice device, - CUcontext context, iree_arena_block_pool_t* block_pool, - iree_allocator_t host_allocator, - iree_hal_cuda_pending_queue_actions_t** out_actions) { - IREE_ASSERT_ARGUMENT(symbols); - IREE_ASSERT_ARGUMENT(block_pool); - IREE_ASSERT_ARGUMENT(out_actions); - IREE_TRACE_ZONE_BEGIN(z0); - - iree_hal_cuda_pending_queue_actions_t* actions = NULL; - IREE_RETURN_AND_END_ZONE_IF_ERROR( - z0, iree_allocator_malloc(host_allocator, sizeof(*actions), - (void**)&actions)); - actions->host_allocator = host_allocator; - actions->block_pool = block_pool; - actions->symbols = symbols; - actions->device = device; - actions->cuda_context = context; - iree_slim_mutex_initialize(&actions->action_mutex); - memset(&actions->action_list, 0, sizeof(actions->action_list)); - - // Initialize the working area for the ready-list processing worker. - iree_hal_cuda_working_area_t* working_area = &actions->working_area; - iree_hal_cuda_working_area_initialize(host_allocator, device, symbols, - working_area); - - iree_hal_cuda_completion_area_t* completion_area = &actions->completion_area; - iree_hal_cuda_completion_area_initialize(host_allocator, device, symbols, - completion_area); - - // Create the ready-list processing worker itself. 
-  iree_thread_create_params_t params;
-  memset(&params, 0, sizeof(params));
-  params.name = IREE_SV("iree-cuda-queue-worker");
-  params.create_suspended = false;
-  iree_status_t status = iree_thread_create(
-      (iree_thread_entry_t)iree_hal_cuda_worker_execute, actions, params,
-      actions->host_allocator, &actions->worker_thread);
-
-  params.name = IREE_SV("iree-cuda-queue-completion");
-  params.create_suspended = false;
-  if (iree_status_is_ok(status)) {
-    status = iree_thread_create(
-        (iree_thread_entry_t)iree_hal_cuda_completion_execute, actions, params,
-        actions->host_allocator, &actions->completion_thread);
-  }
-
-  if (iree_status_is_ok(status)) {
-    *out_actions = actions;
-  } else {
-    iree_hal_cuda_pending_queue_actions_destroy((iree_hal_resource_t*)actions);
-  }
-
-  IREE_TRACE_ZONE_END(z0);
-  return status;
-}
-
-static iree_hal_cuda_pending_queue_actions_t*
-iree_hal_cuda_pending_queue_actions_cast(iree_hal_resource_t* base_value) {
-  return (iree_hal_cuda_pending_queue_actions_t*)base_value;
-}
-
-static void iree_hal_cuda_pending_queue_actions_notify_worker_thread(
-    iree_hal_cuda_working_area_t* working_area) {
-  iree_atomic_store_int32(&working_area->worker_state,
-                          IREE_HAL_CUDA_WORKER_STATE_WORKLOAD_PENDING,
-                          iree_memory_order_release);
-  iree_notification_post(&working_area->state_notification, IREE_ALL_WAITERS);
-}
-
-static void iree_hal_cuda_pending_queue_actions_notify_completion_thread(
-    iree_hal_cuda_completion_area_t* completion_area) {
-  iree_atomic_store_int32(&completion_area->worker_state,
-                          IREE_HAL_CUDA_WORKER_STATE_WORKLOAD_PENDING,
-                          iree_memory_order_release);
-  iree_notification_post(&completion_area->state_notification,
-                         IREE_ALL_WAITERS);
-}
-
-// Notifies worker and completion threads that there is work available to
-// process.
-static void iree_hal_cuda_pending_queue_actions_notify_threads(
-    iree_hal_cuda_pending_queue_actions_t* actions) {
-  iree_hal_cuda_pending_queue_actions_notify_worker_thread(
-      &actions->working_area);
-  iree_hal_cuda_pending_queue_actions_notify_completion_thread(
-      &actions->completion_area);
-}
-
-static void iree_hal_cuda_pending_queue_actions_request_exit(
-    iree_hal_cuda_pending_queue_actions_t* actions) {
-  iree_slim_mutex_lock(&actions->action_mutex);
-  actions->exit_requested = true;
-  iree_slim_mutex_unlock(&actions->action_mutex);
-
-  iree_hal_cuda_pending_queue_actions_notify_threads(actions);
-}
-
-void iree_hal_cuda_pending_queue_actions_destroy(
-    iree_hal_resource_t* base_actions) {
-  IREE_TRACE_ZONE_BEGIN(z0);
-  iree_hal_cuda_pending_queue_actions_t* actions =
-      iree_hal_cuda_pending_queue_actions_cast(base_actions);
-  iree_allocator_t host_allocator = actions->host_allocator;
-
-  // Request the workers to exit.
- iree_hal_cuda_pending_queue_actions_request_exit(actions); - - iree_thread_join(actions->worker_thread); - iree_thread_release(actions->worker_thread); - - iree_thread_join(actions->completion_thread); - iree_thread_release(actions->completion_thread); - - iree_hal_cuda_working_area_deinitialize(&actions->working_area, - actions->host_allocator); - iree_hal_cuda_completion_area_deinitialize( - &actions->completion_area, actions->symbols, actions->host_allocator); - - iree_slim_mutex_deinitialize(&actions->action_mutex); - iree_hal_cuda_queue_action_list_destroy(actions->action_list.head); - iree_allocator_free(host_allocator, actions); - - IREE_TRACE_ZONE_END(z0); -} - -static void iree_hal_cuda_queue_action_destroy( - iree_hal_cuda_queue_action_t* action) { - IREE_TRACE_ZONE_BEGIN(z0); - iree_hal_cuda_pending_queue_actions_t* actions = action->owning_actions; - iree_allocator_t host_allocator = actions->host_allocator; - - // Call user provided callback before releasing any resource. - if (action->cleanup_callback) { - action->cleanup_callback(action->callback_user_data); - } - - // Only release resources after callbacks have been issued. - iree_hal_resource_set_free(action->resource_set); - - iree_hal_cuda_queue_action_clear_events(action); - - iree_hal_resource_release(actions); - - iree_allocator_free(host_allocator, action); - - IREE_TRACE_ZONE_END(z0); -} - -static void iree_hal_cuda_queue_decrement_work_items_count( - iree_hal_cuda_pending_queue_actions_t* actions) { - iree_slim_mutex_lock(&actions->action_mutex); - --actions->pending_work_items_count; - iree_slim_mutex_unlock(&actions->action_mutex); -} - -iree_status_t iree_hal_cuda_pending_queue_actions_enqueue_execution( - iree_hal_device_t* device, CUstream dispatch_stream, - iree_hal_cuda_pending_queue_actions_t* actions, - iree_hal_cuda_pending_action_cleanup_callback_t cleanup_callback, - void* callback_user_data, - const iree_hal_semaphore_list_t wait_semaphore_list, - const iree_hal_semaphore_list_t signal_semaphore_list, - iree_host_size_t command_buffer_count, - iree_hal_command_buffer_t* const* command_buffers, - iree_hal_buffer_binding_table_t const* binding_tables) { - IREE_ASSERT_ARGUMENT(actions); - IREE_ASSERT_ARGUMENT(command_buffer_count == 0 || command_buffers); - IREE_TRACE_ZONE_BEGIN(z0); - - // Embed captured tables in the action allocation. 
- iree_hal_cuda_queue_action_t* action = NULL; - const iree_host_size_t wait_semaphore_list_size = - wait_semaphore_list.count * sizeof(*wait_semaphore_list.semaphores) + - wait_semaphore_list.count * sizeof(*wait_semaphore_list.payload_values); - const iree_host_size_t signal_semaphore_list_size = - signal_semaphore_list.count * sizeof(*signal_semaphore_list.semaphores) + - signal_semaphore_list.count * - sizeof(*signal_semaphore_list.payload_values); - const iree_host_size_t command_buffers_size = - command_buffer_count * sizeof(*action->payload.execution.command_buffers); - iree_host_size_t binding_tables_size = 0; - iree_host_size_t binding_table_elements_size = 0; - if (binding_tables) { - binding_tables_size = command_buffer_count * sizeof(*binding_tables); - for (iree_host_size_t i = 0; i < command_buffer_count; ++i) { - binding_table_elements_size += - binding_tables[i].count * sizeof(*binding_tables[i].bindings); - } - } - const iree_host_size_t payload_size = - command_buffers_size + binding_tables_size + binding_table_elements_size; - const iree_host_size_t total_action_size = - sizeof(*action) + wait_semaphore_list_size + signal_semaphore_list_size + - payload_size; - IREE_RETURN_AND_END_ZONE_IF_ERROR( - z0, iree_allocator_malloc(actions->host_allocator, total_action_size, - (void**)&action)); - uint8_t* action_ptr = (uint8_t*)action + sizeof(*action); - - action->owning_actions = actions; - action->state = IREE_HAL_CUDA_QUEUE_ACTION_STATE_ALIVE; - action->cleanup_callback = cleanup_callback; - action->callback_user_data = callback_user_data; - action->kind = IREE_HAL_CUDA_QUEUE_ACTION_TYPE_EXECUTION; - action->device = device; - action->dispatch_cu_stream = dispatch_stream; - - // Initialize scratch fields. - action->event_count = 0; - action->is_pending = true; - - // Copy wait list for later access. - action->wait_semaphore_list.count = wait_semaphore_list.count; - action->wait_semaphore_list.semaphores = (iree_hal_semaphore_t**)action_ptr; - memcpy(action->wait_semaphore_list.semaphores, wait_semaphore_list.semaphores, - wait_semaphore_list.count * sizeof(*wait_semaphore_list.semaphores)); - action->wait_semaphore_list.payload_values = - (uint64_t*)(action_ptr + wait_semaphore_list.count * - sizeof(*wait_semaphore_list.semaphores)); - memcpy( - action->wait_semaphore_list.payload_values, - wait_semaphore_list.payload_values, - wait_semaphore_list.count * sizeof(*wait_semaphore_list.payload_values)); - action_ptr += wait_semaphore_list_size; - - // Copy signal list for later access. - action->signal_semaphore_list.count = signal_semaphore_list.count; - action->signal_semaphore_list.semaphores = (iree_hal_semaphore_t**)action_ptr; - memcpy( - action->signal_semaphore_list.semaphores, - signal_semaphore_list.semaphores, - signal_semaphore_list.count * sizeof(*signal_semaphore_list.semaphores)); - action->signal_semaphore_list.payload_values = - (uint64_t*)(action_ptr + signal_semaphore_list.count * - sizeof(*signal_semaphore_list.semaphores)); - memcpy(action->signal_semaphore_list.payload_values, - signal_semaphore_list.payload_values, - signal_semaphore_list.count * - sizeof(*signal_semaphore_list.payload_values)); - action_ptr += signal_semaphore_list_size; - - // Copy the execution resources for later access. 
- action->payload.execution.count = command_buffer_count; - action->payload.execution.command_buffers = - (iree_hal_command_buffer_t**)action_ptr; - memcpy(action->payload.execution.command_buffers, command_buffers, - command_buffers_size); - action_ptr += command_buffers_size; - - // Retain all command buffers and semaphores. - iree_status_t status = iree_hal_resource_set_allocate(actions->block_pool, - &action->resource_set); - if (iree_status_is_ok(status)) { - status = iree_hal_resource_set_insert(action->resource_set, - wait_semaphore_list.count, - wait_semaphore_list.semaphores); - } - if (iree_status_is_ok(status)) { - status = iree_hal_resource_set_insert(action->resource_set, - signal_semaphore_list.count, - signal_semaphore_list.semaphores); - } - if (iree_status_is_ok(status)) { - status = iree_hal_resource_set_insert( - action->resource_set, command_buffer_count, command_buffers); - } - - // Copy binding tables and retain all bindings. - if (iree_status_is_ok(status) && binding_table_elements_size > 0) { - action->payload.execution.binding_tables = - (iree_hal_buffer_binding_table_t*)action_ptr; - action_ptr += binding_tables_size; - iree_hal_buffer_binding_t* binding_element_ptr = - (iree_hal_buffer_binding_t*)action_ptr; - for (iree_host_size_t i = 0; i < command_buffer_count; ++i) { - iree_host_size_t element_count = binding_tables[i].count; - iree_hal_buffer_binding_table_t* target_table = - &action->payload.execution.binding_tables[i]; - target_table->count = element_count; - target_table->bindings = binding_element_ptr; - memcpy((void*)target_table->bindings, binding_tables[i].bindings, - element_count * sizeof(*binding_element_ptr)); - binding_element_ptr += element_count; - - // Bulk insert all bindings into the resource set. This will keep the - // referenced buffers live until the action has completed. Note that if we - // fail here we need to clean up the resource set below before returning. - status = iree_hal_resource_set_insert_strided( - action->resource_set, element_count, target_table->bindings, - offsetof(iree_hal_buffer_binding_t, buffer), - sizeof(iree_hal_buffer_binding_t)); - if (!iree_status_is_ok(status)) break; - } - } else { - action->payload.execution.binding_tables = NULL; - } - - if (iree_status_is_ok(status)) { - // Now everything is okay and we can enqueue the action. - iree_slim_mutex_lock(&actions->action_mutex); - if (actions->exit_requested) { - status = iree_make_status( - IREE_STATUS_ABORTED, - "can not issue more executions, exit already requested"); - iree_hal_cuda_queue_action_fail_locked(action, status); - } else { - iree_hal_cuda_queue_action_list_push_back(&actions->action_list, action); - // One work item is the callback that makes it across from the - // completion thread. - // The other is the cleanup of the action. - actions->pending_work_items_count += - total_work_items_to_complete_an_action; - } - iree_slim_mutex_unlock(&actions->action_mutex); - } else { - iree_hal_resource_set_free(action->resource_set); - iree_allocator_free(actions->host_allocator, action); - } - - IREE_TRACE_ZONE_END(z0); - return status; -} - -// Does not consume |status|. -static void iree_hal_cuda_pending_queue_actions_fail_status_locked( - iree_hal_cuda_pending_queue_actions_t* actions, iree_status_t status) { - if (iree_status_is_ok(actions->status) && status != actions->status) { - actions->status = iree_status_clone(status); - } -} - -// Fails and destroys the action. -// Does not consume |status|. 
-// Decrements pending work items count accordingly based on the unfulfilled -// number of work items. -static void iree_hal_cuda_queue_action_fail_locked( - iree_hal_cuda_queue_action_t* action, iree_status_t status) { - IREE_ASSERT(!iree_status_is_ok(status)); - iree_hal_cuda_pending_queue_actions_t* actions = action->owning_actions; - - // Unlock since failing the semaphore will use |actions|. - iree_slim_mutex_unlock(&actions->action_mutex); - iree_hal_semaphore_list_fail(action->signal_semaphore_list, - iree_status_clone(status)); - - iree_host_size_t work_items_remaining = 0; - switch (action->state) { - case IREE_HAL_CUDA_QUEUE_ACTION_STATE_ALIVE: - work_items_remaining = total_work_items_to_complete_an_action; - break; - case IREE_HAL_CUDA_QUEUE_ACTION_STATE_ZOMBIE: - work_items_remaining = 1; - break; - default: - // Someone forgot to handle all enum values? - iree_abort(); - } - iree_slim_mutex_lock(&actions->action_mutex); - action->owning_actions->pending_work_items_count -= work_items_remaining; - iree_hal_cuda_pending_queue_actions_fail_status_locked(actions, status); - iree_hal_cuda_queue_action_destroy(action); -} - -// Fails and destroys all actions. -// Does not consume |status|. -static void iree_hal_cuda_queue_action_fail( - iree_hal_cuda_queue_action_t* action, iree_status_t status) { - iree_hal_cuda_pending_queue_actions_t* actions = action->owning_actions; - iree_slim_mutex_lock(&actions->action_mutex); - iree_hal_cuda_queue_action_fail_locked(action, status); - iree_slim_mutex_unlock(&actions->action_mutex); -} - -// Fails and destroys all actions. -// Does not consume |status|. -static void iree_hal_cuda_queue_action_raw_list_fail_locked( - iree_hal_cuda_queue_action_t* head_action, iree_status_t status) { - while (head_action) { - iree_hal_cuda_queue_action_t* next_action = head_action->next; - iree_hal_cuda_queue_action_fail_locked(head_action, status); - head_action = next_action; - } -} - -// Fails and destroys all actions. -// Does not consume |status|. -static void iree_hal_cuda_ready_action_list_fail_locked( - iree_hal_cuda_entry_list_t* list, iree_status_t status) { - iree_hal_cuda_entry_list_node_t* entry = iree_hal_cuda_entry_list_pop(list); - while (entry) { - iree_hal_cuda_queue_action_raw_list_fail_locked(entry->ready_list_head, - status); - entry = iree_hal_cuda_entry_list_pop(list); - } -} - -// Fails and destroys all actions. -// Does not consume |status|. -static void iree_hal_cuda_queue_action_list_fail_locked( - iree_hal_cuda_queue_action_list_t* list, iree_status_t status) { - iree_hal_cuda_queue_action_t* action; - if (iree_hal_cuda_queue_action_list_is_empty(list)) { - return; - } - do { - action = iree_hal_cuda_queue_action_list_pop_front(list); - iree_hal_cuda_queue_action_fail_locked(action, status); - } while (action); -} - -// Fails and destroys all actions and sets status of |actions|. -// Does not consume |status|. -// Assumes the caller is holding the action_mutex. -static void iree_hal_cuda_pending_queue_actions_fail_locked( - iree_hal_cuda_pending_queue_actions_t* actions, iree_status_t status) { - iree_hal_cuda_pending_queue_actions_fail_status_locked(actions, status); - iree_hal_cuda_queue_action_list_fail_locked(&actions->action_list, status); - iree_hal_cuda_ready_action_list_fail_locked( - &actions->working_area.ready_worklist, status); -} - -// Does not consume |status|. 
-static void iree_hal_cuda_pending_queue_actions_fail( - iree_hal_cuda_pending_queue_actions_t* actions, iree_status_t status) { - iree_slim_mutex_lock(&actions->action_mutex); - iree_hal_cuda_pending_queue_actions_fail_locked(actions, status); - iree_slim_mutex_unlock(&actions->action_mutex); -} - -// Releases resources after action completion on the GPU and advances timeline -// and pending actions queue. -static iree_status_t iree_hal_cuda_execution_device_signal_host_callback( - iree_status_t status, void* user_data) { - IREE_TRACE_ZONE_BEGIN(z0); - iree_hal_cuda_queue_action_t* action = - (iree_hal_cuda_queue_action_t*)user_data; - IREE_ASSERT_EQ(action->kind, IREE_HAL_CUDA_QUEUE_ACTION_TYPE_EXECUTION); - IREE_ASSERT_EQ(action->state, IREE_HAL_CUDA_QUEUE_ACTION_STATE_ALIVE); - iree_hal_cuda_pending_queue_actions_t* actions = action->owning_actions; - - if (IREE_UNLIKELY(!iree_status_is_ok(status))) { - iree_hal_cuda_queue_action_fail(action, status); - IREE_TRACE_ZONE_END(z0); - return iree_ok_status(); - } - - // Need to signal the list before zombifying the action, because in the mean - // time someone else may issue the pending queue actions. - // If we push first to the pending actions list, the cleanup of this action - // may run while we are still using the semaphore list, causing a crash. - status = iree_hal_semaphore_list_signal(action->signal_semaphore_list); - if (IREE_UNLIKELY(!iree_status_is_ok(status))) { - iree_hal_cuda_queue_action_fail(action, status); - IREE_TRACE_ZONE_END(z0); - return status; - } - - // Flip the action state to zombie and enqueue it again so that we can let - // the worker thread clean it up. Note that this is necessary because cleanup - // may involve GPU API calls like buffer releasing or unregistering, so we can - // not inline it here. - action->state = IREE_HAL_CUDA_QUEUE_ACTION_STATE_ZOMBIE; - iree_slim_mutex_lock(&actions->action_mutex); - iree_hal_cuda_queue_action_list_push_back(&actions->action_list, action); - // The callback (work item) is complete. - --actions->pending_work_items_count; - iree_slim_mutex_unlock(&actions->action_mutex); - - // We need to trigger execution of this action again, so it gets cleaned up. - status = iree_hal_cuda_pending_queue_actions_issue(actions); - - IREE_TRACE_ZONE_END(z0); - return status; -} - -// Issues the given kernel dispatch |action| to the GPU. -static iree_status_t iree_hal_cuda_pending_queue_actions_issue_execution( - iree_hal_cuda_queue_action_t* action) { - IREE_ASSERT_EQ(action->kind, IREE_HAL_CUDA_QUEUE_ACTION_TYPE_EXECUTION); - IREE_ASSERT_EQ(action->is_pending, false); - iree_hal_cuda_pending_queue_actions_t* actions = action->owning_actions; - const iree_hal_cuda_dynamic_symbols_t* symbols = actions->symbols; - IREE_TRACE_ZONE_BEGIN(z0); - - // No need to lock given that this action is already detched from the pending - // actions list; so only this thread is seeing it now. - - // First wait all the device CUevent in the dispatch stream. - for (iree_host_size_t i = 0; i < action->event_count; ++i) { - IREE_CUDA_RETURN_AND_END_ZONE_IF_ERROR( - z0, symbols, - cuStreamWaitEvent(action->dispatch_cu_stream, - iree_hal_cuda_event_handle(action->events[i]), - CU_EVENT_WAIT_DEFAULT), - "cuStreamWaitEvent"); - } - - // Then launch all command buffers to the dispatch stream. 
- IREE_TRACE_ZONE_BEGIN(z_dispatch_command_buffers); - IREE_TRACE_ZONE_APPEND_TEXT(z_dispatch_command_buffers, - "dispatch_command_buffers"); - for (iree_host_size_t i = 0; i < action->payload.execution.count; ++i) { - iree_hal_command_buffer_t* command_buffer = - action->payload.execution.command_buffers[i]; - iree_hal_buffer_binding_table_t binding_table = - action->payload.execution.binding_tables - ? action->payload.execution.binding_tables[i] - : iree_hal_buffer_binding_table_empty(); - if (iree_hal_cuda_stream_command_buffer_isa(command_buffer)) { - // Notify that the commands were "submitted" so we can - // make sure to clean up our trace events. - iree_hal_cuda_stream_notify_submitted_commands(command_buffer); - } else if (iree_hal_cuda_graph_command_buffer_isa(command_buffer)) { - CUgraphExec exec = - iree_hal_cuda_graph_command_buffer_handle(command_buffer); - IREE_CUDA_RETURN_AND_END_ZONE_IF_ERROR( - z0, symbols, cuGraphLaunch(exec, action->dispatch_cu_stream), - "cuGraphLaunch"); - iree_hal_cuda_graph_tracing_notify_submitted_commands(command_buffer); - } else { - iree_hal_command_buffer_t* stream_command_buffer = NULL; - iree_hal_command_buffer_mode_t mode = - iree_hal_command_buffer_mode(command_buffer) | - IREE_HAL_COMMAND_BUFFER_MODE_ONE_SHOT | - IREE_HAL_COMMAND_BUFFER_MODE_ALLOW_INLINE_EXECUTION | - // NOTE: we need to validate if a binding table is provided as the - // bindings were not known when it was originally recorded. - (iree_hal_buffer_binding_table_is_empty(binding_table) - ? IREE_HAL_COMMAND_BUFFER_MODE_UNVALIDATED - : 0); - IREE_RETURN_AND_END_ZONE_IF_ERROR( - z0, iree_hal_cuda_device_create_stream_command_buffer( - action->device, mode, IREE_HAL_COMMAND_CATEGORY_ANY, - /*binding_capacity=*/0, &stream_command_buffer)); - IREE_RETURN_AND_END_ZONE_IF_ERROR( - z0, iree_hal_resource_set_insert(action->resource_set, 1, - &stream_command_buffer)); - IREE_RETURN_AND_END_ZONE_IF_ERROR( - z0, iree_hal_deferred_command_buffer_apply( - command_buffer, stream_command_buffer, binding_table)); - iree_hal_cuda_stream_notify_submitted_commands(stream_command_buffer); - // The stream_command_buffer is going to be retained by - // the action->resource_set and deleted after the action - // completes. - iree_hal_resource_release(stream_command_buffer); - } - } - IREE_TRACE_ZONE_END(z_dispatch_command_buffers); - - CUevent completion_event = NULL; - // Last record CUevent signals in the dispatch stream. - for (iree_host_size_t i = 0; i < action->signal_semaphore_list.count; ++i) { - // Grab a CUevent for this semaphore value signaling. - CUevent event = NULL; - IREE_RETURN_AND_END_ZONE_IF_ERROR( - z0, iree_hal_cuda_event_semaphore_acquire_timepoint_device_signal( - action->signal_semaphore_list.semaphores[i], - action->signal_semaphore_list.payload_values[i], &event)); - - // Record the event signaling in the dispatch stream. - IREE_CUDA_RETURN_AND_END_ZONE_IF_ERROR( - z0, symbols, cuEventRecord(event, action->dispatch_cu_stream), - "cuEventRecord"); - completion_event = event; - } - - bool created_event = false; - // In the case where we issue an execution and there are signal semaphores - // we can re-use those as a wait event. However if there are no signals - // then we create one. In my testing this is not a common case. 
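The reuse-or-create choice the comment above describes (implemented by the branch just below) is independent of CUDA; here is a simplified sketch with hypothetical callback names, where the caller must destroy the event only when *out_created is set:

#include <stddef.h>

typedef struct native_event_t native_event_t;  // opaque event handle

// Hypothetical stand-ins for cuEventCreate/cuEventRecord.
typedef struct {
  native_event_t* (*create_event)(void);
  void (*record_event)(native_event_t* event);  // records on dispatch stream
} event_api_t;

// Returns the event marking completion of the submission. If the signal
// list already recorded events, the last one doubles as the completion
// marker; otherwise a dedicated event is created, and *out_created tells
// the caller it owns the event and must destroy it after synchronizing.
static native_event_t* acquire_completion_event(const event_api_t* api,
                                                native_event_t* last_signal,
                                                int* out_created) {
  if (last_signal) {
    *out_created = 0;
    return last_signal;
  }
  native_event_t* event = api->create_event();
  api->record_event(event);
  *out_created = 1;
  return event;
}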
- if (IREE_UNLIKELY(!completion_event)) { - IREE_CUDA_RETURN_AND_END_ZONE_IF_ERROR( - z0, symbols, cuEventCreate(&completion_event, CU_EVENT_DISABLE_TIMING), - "cuEventCreate"); - created_event = true; - } - - IREE_CUDA_RETURN_AND_END_ZONE_IF_ERROR( - z0, symbols, cuEventRecord(completion_event, action->dispatch_cu_stream), - "cuEventRecord"); - - iree_hal_cuda_completion_list_node_t* entry = NULL; - // TODO: avoid host allocator malloc; use some pool for the allocation. - iree_status_t status = iree_allocator_malloc(actions->host_allocator, - sizeof(*entry), (void**)&entry); - - if (IREE_UNLIKELY(!iree_status_is_ok(status))) { - IREE_TRACE_ZONE_END(z0); - return status; - } - - // Now push the ready list to the worker and have it to issue the actions to - // the GPU. - entry->event = completion_event; - entry->created_event = created_event; - entry->callback = iree_hal_cuda_execution_device_signal_host_callback; - entry->user_data = action; - iree_hal_cuda_completion_list_push(&actions->completion_area.completion_list, - entry); - - iree_hal_cuda_pending_queue_actions_notify_completion_thread( - &actions->completion_area); - - IREE_TRACE_ZONE_END(z0); - return iree_ok_status(); -} - -// Performs the given cleanup |action| on the CPU. -static void iree_hal_cuda_pending_queue_actions_issue_cleanup( - iree_hal_cuda_queue_action_t* action) { - iree_hal_cuda_pending_queue_actions_t* actions = action->owning_actions; - IREE_TRACE_ZONE_BEGIN(z0); - - iree_hal_cuda_queue_action_destroy(action); - - // Now we fully executed and cleaned up this action. Decrease the work items - // counter. - iree_hal_cuda_queue_decrement_work_items_count(actions); - - IREE_TRACE_ZONE_END(z0); -} - -iree_status_t iree_hal_cuda_pending_queue_actions_issue( - iree_hal_cuda_pending_queue_actions_t* actions) { - IREE_TRACE_ZONE_BEGIN(z0); - - iree_hal_cuda_queue_action_list_t pending_list = {NULL, NULL}; - iree_hal_cuda_queue_action_list_t ready_list = {NULL, NULL}; - - iree_slim_mutex_lock(&actions->action_mutex); - - if (iree_hal_cuda_queue_action_list_is_empty(&actions->action_list)) { - iree_slim_mutex_unlock(&actions->action_mutex); - IREE_TRACE_ZONE_END(z0); - return iree_ok_status(); - } - - if (IREE_UNLIKELY(!iree_status_is_ok(actions->status))) { - iree_hal_cuda_queue_action_list_fail_locked(&actions->action_list, - actions->status); - iree_slim_mutex_unlock(&actions->action_mutex); - IREE_TRACE_ZONE_END(z0); - return iree_ok_status(); - } - - iree_status_t status = iree_ok_status(); - // Scan through the list and categorize actions into pending and ready lists. - while (!iree_hal_cuda_queue_action_list_is_empty(&actions->action_list)) { - iree_hal_cuda_queue_action_t* action = - iree_hal_cuda_queue_action_list_pop_front(&actions->action_list); - - iree_hal_semaphore_t** semaphores = action->wait_semaphore_list.semaphores; - uint64_t* values = action->wait_semaphore_list.payload_values; - - action->is_pending = false; - bool action_failed = false; - - // Cleanup actions are immediately ready to release. Otherwise, look at all - // wait semaphores to make sure that they are either already ready or we can - // wait on a device event. - if (action->state == IREE_HAL_CUDA_QUEUE_ACTION_STATE_ALIVE) { - for (iree_host_size_t i = 0; i < action->wait_semaphore_list.count; ++i) { - // If this semaphore has already signaled past the desired value, we can - // just ignore it. 
- uint64_t value = 0; - iree_status_t semaphore_status = - iree_hal_semaphore_query(semaphores[i], &value); - if (IREE_UNLIKELY(!iree_status_is_ok(semaphore_status))) { - iree_hal_cuda_queue_action_fail_locked(action, semaphore_status); - iree_status_ignore(semaphore_status); - action_failed = true; - break; - } - if (value >= values[i]) { - // No need to wait on this timepoint as it has already occurred and - // we can remove it from the wait list. - iree_hal_semaphore_list_erase(&action->wait_semaphore_list, i); - --i; - continue; - } - - // Try to acquire a CUDA event from an existing device signal timepoint. - // If so, we can use that event to wait on the device. - // Otherwise, this action is still not ready for execution. - // Before issuing recording on a stream, an event represents an empty - // set of work so waiting on it will just return success. - // Here we must guarantee the CUDA event is indeed recorded, which means - // it's associated with some already present device signal timepoint on - // the semaphore timeline. - iree_hal_cuda_event_t* wait_event = NULL; - if (!iree_hal_cuda_semaphore_acquire_event_host_wait( - semaphores[i], values[i], &wait_event)) { - action->is_pending = true; - break; - } - if (IREE_UNLIKELY(action->event_count >= - IREE_HAL_CUDA_MAX_WAIT_EVENT_COUNT)) { - status = iree_make_status( - IREE_STATUS_RESOURCE_EXHAUSTED, - "exceeded maximum queue action wait event limit"); - iree_hal_cuda_event_release(wait_event); - if (iree_status_is_ok(actions->status)) { - actions->status = status; - } - iree_hal_cuda_queue_action_fail_locked(action, status); - break; - } - action->events[action->event_count++] = wait_event; - - // Remove the wait timepoint as we have a corresponding event that we - // will wait on. - iree_hal_semaphore_list_erase(&action->wait_semaphore_list, i); - --i; - } - } - - if (IREE_UNLIKELY(!iree_status_is_ok(actions->status))) { - if (!action_failed) { - iree_hal_cuda_queue_action_fail_locked(action, actions->status); - } - iree_hal_cuda_queue_action_list_fail_locked(&actions->action_list, - actions->status); - break; - } - - if (action_failed) { - break; - } - - if (action->is_pending) { - iree_hal_cuda_queue_action_list_push_back(&pending_list, action); - } else { - iree_hal_cuda_queue_action_list_push_back(&ready_list, action); - } - } - - // Preserve pending timepoints. - actions->action_list = pending_list; - - iree_slim_mutex_unlock(&actions->action_mutex); - - if (ready_list.head == NULL) { - // Nothing ready yet. Just return. - IREE_TRACE_ZONE_END(z0); - return status; - } - - iree_hal_cuda_entry_list_node_t* entry = NULL; - // TODO: avoid host allocator malloc; use some pool for the allocation. - if (iree_status_is_ok(status)) { - status = iree_allocator_malloc(actions->host_allocator, sizeof(*entry), - (void**)&entry); - } - - if (IREE_UNLIKELY(!iree_status_is_ok(status))) { - iree_slim_mutex_lock(&actions->action_mutex); - iree_hal_cuda_pending_queue_actions_fail_status_locked(actions, status); - iree_hal_cuda_queue_action_list_fail_locked(&ready_list, status); - iree_slim_mutex_unlock(&actions->action_mutex); - IREE_TRACE_ZONE_END(z0); - return status; - } - - // Now push the ready list to the worker and have it to issue the actions to - // the GPU. 
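Before the ready list is pushed to the worker below, the scan above applies a simple readiness rule per wait entry; a sketch of that rule over a hypothetical, flattened view of the wait list:

#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>

// Hypothetical, simplified view of one wait-list entry.
typedef struct {
  uint64_t current_value;   // latest signaled payload of the semaphore
  uint64_t required_value;  // payload value this action waits for
  bool has_device_event;    // a recorded device event covers the timepoint
} wait_entry_t;

// An action is ready when every wait is either already satisfied on the
// host or covered by a recorded device event the dispatch stream can wait
// on; otherwise it stays pending until another semaphore signal arrives.
static bool action_is_ready(const wait_entry_t* waits, size_t count) {
  for (size_t i = 0; i < count; ++i) {
    if (waits[i].current_value >= waits[i].required_value) continue;
    if (waits[i].has_device_event) continue;
    return false;
  }
  return true;
}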
- entry->ready_list_head = ready_list.head; - iree_hal_cuda_entry_list_push(&actions->working_area.ready_worklist, entry); - - iree_hal_cuda_pending_queue_actions_notify_worker_thread( - &actions->working_area); - - IREE_TRACE_ZONE_END(z0); - return status; -} - -//===----------------------------------------------------------------------===// -// Worker routines -//===----------------------------------------------------------------------===// - -static bool iree_hal_cuda_worker_has_incoming_request( - iree_hal_cuda_working_area_t* working_area) { - iree_hal_cuda_worker_state_t value = iree_atomic_load_int32( - &working_area->worker_state, iree_memory_order_acquire); - return value == IREE_HAL_CUDA_WORKER_STATE_WORKLOAD_PENDING; -} - -static bool iree_hal_cuda_completion_has_incoming_request( - iree_hal_cuda_completion_area_t* completion_area) { - iree_hal_cuda_worker_state_t value = iree_atomic_load_int32( - &completion_area->worker_state, iree_memory_order_acquire); - return value == IREE_HAL_CUDA_WORKER_STATE_WORKLOAD_PENDING; -} - -// Processes all ready actions in the given |worklist|. -static void iree_hal_cuda_worker_process_ready_list( - iree_hal_cuda_pending_queue_actions_t* actions) { - IREE_TRACE_ZONE_BEGIN(z0); - - iree_slim_mutex_lock(&actions->action_mutex); - iree_status_t status = actions->status; - if (IREE_UNLIKELY(!iree_status_is_ok(status))) { - iree_hal_cuda_ready_action_list_fail_locked( - &actions->working_area.ready_worklist, status); - iree_slim_mutex_unlock(&actions->action_mutex); - iree_status_ignore(status); - return; - } - iree_slim_mutex_unlock(&actions->action_mutex); - - while (true) { - iree_hal_cuda_entry_list_node_t* entry = - iree_hal_cuda_entry_list_pop(&actions->working_area.ready_worklist); - if (!entry) break; - - // Process the current batch of ready actions. 
- while (entry->ready_list_head) { - iree_hal_cuda_queue_action_t* action = - iree_hal_cuda_entry_list_node_pop_front(entry); - switch (action->state) { - case IREE_HAL_CUDA_QUEUE_ACTION_STATE_ALIVE: - status = iree_hal_cuda_pending_queue_actions_issue_execution(action); - break; - case IREE_HAL_CUDA_QUEUE_ACTION_STATE_ZOMBIE: - iree_hal_cuda_pending_queue_actions_issue_cleanup(action); - break; - } - - if (IREE_UNLIKELY(!iree_status_is_ok(status))) { - iree_hal_cuda_entry_list_node_push_front(entry, action); - iree_hal_cuda_entry_list_push(&actions->working_area.ready_worklist, - entry); - break; - } - } - - if (IREE_UNLIKELY(!iree_status_is_ok(status))) { - break; - } - - iree_allocator_free(actions->host_allocator, entry); - } - - if (IREE_UNLIKELY(!iree_status_is_ok(status))) { - iree_hal_cuda_pending_queue_actions_fail(actions, status); - iree_status_ignore(status); - } - - IREE_TRACE_ZONE_END(z0); -} - -static void iree_hal_cuda_worker_process_completion( - iree_hal_cuda_pending_queue_actions_t* actions) { - IREE_TRACE_ZONE_BEGIN(z0); - iree_hal_cuda_completion_list_t* worklist = - &actions->completion_area.completion_list; - iree_slim_mutex_lock(&actions->action_mutex); - iree_status_t status = iree_status_clone(actions->status); - iree_slim_mutex_unlock(&actions->action_mutex); - - while (true) { - iree_hal_cuda_completion_list_node_t* entry = - iree_hal_cuda_completion_list_pop(worklist); - if (!entry) break; - - if (IREE_LIKELY(iree_status_is_ok(status))) { - IREE_TRACE_ZONE_BEGIN_NAMED(z1, "cuEventSynchronize"); - status = IREE_CURESULT_TO_STATUS(actions->symbols, - cuEventSynchronize(entry->event)); - IREE_TRACE_ZONE_END(z1); - } - - status = - iree_status_join(status, entry->callback(status, entry->user_data)); - - if (IREE_UNLIKELY(entry->created_event)) { - status = iree_status_join( - status, IREE_CURESULT_TO_STATUS(actions->symbols, - cuEventDestroy(entry->event))); - } - iree_allocator_free(actions->host_allocator, entry); - } - - if (IREE_UNLIKELY(!iree_status_is_ok(status))) { - iree_hal_cuda_pending_queue_actions_fail(actions, status); - iree_status_ignore(status); - } - - IREE_TRACE_ZONE_END(z0); -} - -// The main function for the completion worker thread. -static int iree_hal_cuda_completion_execute( - iree_hal_cuda_pending_queue_actions_t* actions) { - iree_hal_cuda_completion_area_t* completion_area = &actions->completion_area; - - iree_status_t status = IREE_CURESULT_TO_STATUS( - actions->symbols, cuCtxSetCurrent(actions->cuda_context), - "cuCtxSetCurrent"); - if (IREE_UNLIKELY(!iree_status_is_ok(status))) { - iree_hal_cuda_pending_queue_actions_fail(actions, status); - iree_status_ignore(status); - } - - while (true) { - iree_notification_await( - &completion_area->state_notification, - (iree_condition_fn_t)iree_hal_cuda_completion_has_incoming_request, - completion_area, iree_infinite_timeout()); - - // Immediately flip the state to idle waiting if and only if the previous - // state is workload pending. We do it before processing ready list to make - // sure that we don't accidentally ignore new workload pushed after done - // ready list processing but before overwriting the state from this worker - // thread. 
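The atomic store just below is the load-bearing half of a classic flip-before-drain pattern: the worker resets its state to idle first, so any producer that publishes during the drain re-raises the pending flag and the next wait observes it. A compact model of both sides, with hypothetical names:

#include <stdatomic.h>

enum { WORKER_STATE_IDLE_WAITING = 0, WORKER_STATE_WORKLOAD_PENDING = 1 };

// Producer: publish work first, then raise the flag and wake the worker.
static void producer_publish(atomic_int* worker_state) {
  // ... push an entry onto the shared worklist here ...
  atomic_store_explicit(worker_state, WORKER_STATE_WORKLOAD_PENDING,
                        memory_order_release);
  // ... post the worker's notification here ...
}

// Worker: flip back to idle BEFORE draining. Flipping after the drain
// could overwrite a WORKLOAD_PENDING raised mid-drain and lose a wakeup.
static void worker_iteration(atomic_int* worker_state) {
  atomic_store_explicit(worker_state, WORKER_STATE_IDLE_WAITING,
                        memory_order_release);
  // ... drain the shared worklist here ...
}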
- iree_atomic_store_int32(&completion_area->worker_state, - IREE_HAL_CUDA_WORKER_STATE_IDLE_WAITING, - iree_memory_order_release); - iree_hal_cuda_worker_process_completion(actions); - - iree_slim_mutex_lock(&actions->action_mutex); - if (IREE_UNLIKELY(actions->exit_requested && - !actions->pending_work_items_count)) { - iree_slim_mutex_unlock(&actions->action_mutex); - return 0; - } - iree_slim_mutex_unlock(&actions->action_mutex); - } - - return 0; -} - -// The main function for the ready-list processing worker thread. -static int iree_hal_cuda_worker_execute( - iree_hal_cuda_pending_queue_actions_t* actions) { - iree_hal_cuda_working_area_t* working_area = &actions->working_area; - - // Cuda stores thread-local data based on the device. Some cuda commands pull - // the device from there, and it defaults to device 0 (e.g. cuEventCreate), - // this will cause failures when using it with other devices (or streams from - // other devices). Force the correct device onto this thread. - iree_status_t status = IREE_CURESULT_TO_STATUS( - actions->symbols, cuCtxSetCurrent(actions->cuda_context), - "cuCtxSetCurrent"); - if (IREE_UNLIKELY(!iree_status_is_ok(status))) { - iree_hal_cuda_pending_queue_actions_fail(actions, status); - iree_status_ignore(status); - // We can safely exit here because there are no actions in flight yet. - return -1; - } - - while (true) { - // Block waiting for incoming requests. - // - // TODO: When exit is requested with - // IREE_HAL_CUDA_WORKER_STATE_EXIT_REQUESTED - // we will return immediately causing a busy wait and hogging the CPU. - // We need to properly wait for action cleanups to be scheduled from the - // host stream callbacks. - iree_notification_await( - &working_area->state_notification, - (iree_condition_fn_t)iree_hal_cuda_worker_has_incoming_request, - working_area, iree_infinite_timeout()); - - // Immediately flip the state to idle waiting if and only if the previous - // state is workload pending. We do it before processing ready list to make - // sure that we don't accidentally ignore new workload pushed after done - // ready list processing but before overwriting the state from this worker - // thread. - iree_atomic_store_int32(&working_area->worker_state, - IREE_HAL_CUDA_WORKER_STATE_IDLE_WAITING, - iree_memory_order_release); - - iree_hal_cuda_worker_process_ready_list(actions); - - iree_slim_mutex_lock(&actions->action_mutex); - if (IREE_UNLIKELY(actions->exit_requested && - !actions->pending_work_items_count)) { - iree_slim_mutex_unlock(&actions->action_mutex); - iree_hal_cuda_pending_queue_actions_notify_completion_thread( - &actions->completion_area); - return 0; - } - iree_slim_mutex_unlock(&actions->action_mutex); - } - return 0; -} diff --git a/runtime/src/iree/hal/drivers/cuda/pending_queue_actions.h b/runtime/src/iree/hal/drivers/cuda/pending_queue_actions.h deleted file mode 100644 index fa16e1fce956..000000000000 --- a/runtime/src/iree/hal/drivers/cuda/pending_queue_actions.h +++ /dev/null @@ -1,80 +0,0 @@ -// Copyright 2023 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. 
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#ifndef IREE_HAL_DRIVERS_CUDA_PENDING_QUEUE_ACTIONS_H_ -#define IREE_HAL_DRIVERS_CUDA_PENDING_QUEUE_ACTIONS_H_ - -#include "iree/base/api.h" -#include "iree/base/internal/arena.h" -#include "iree/hal/api.h" -#include "iree/hal/drivers/cuda/cuda_dynamic_symbols.h" - -#ifdef __cplusplus -extern "C" { -#endif // __cplusplus - -// A data structure to manage pending queue actions (kernel launches and async -// allocations). -// -// This is needed in order to satisfy queue action dependencies. IREE uses HAL -// semaphore as the unified mechanism for synchronization directions including -// host to host, host to device, devie to device, and device to host. Plus, it -// allows wait before signal. These flexible capabilities are not all supported -// by CUevent objects. Therefore, we need supporting data structures to -// implement them on top of CUevent objects. Thus this pending queue actions. -// -// This buffers pending queue actions and their associated resources. It -// provides an API to advance the wait list on demand--queue actions are -// released to the GPU when all their wait semaphores are signaled past the -// desired value, or we can have a CUevent already recorded to some CUDA -// stream to wait on. -// -// Thread-safe; multiple threads may enqueue workloads. -typedef struct iree_hal_cuda_pending_queue_actions_t - iree_hal_cuda_pending_queue_actions_t; - -// Creates a pending actions queue. -iree_status_t iree_hal_cuda_pending_queue_actions_create( - const iree_hal_cuda_dynamic_symbols_t* symbols, CUdevice device, - CUcontext context, iree_arena_block_pool_t* block_pool, - iree_allocator_t host_allocator, - iree_hal_cuda_pending_queue_actions_t** out_actions); - -// Destroys the pending |actions| queue. -void iree_hal_cuda_pending_queue_actions_destroy(iree_hal_resource_t* actions); - -// Callback to execute user code after action completion but before resource -// releasing. -// -// Data behind |user_data| must remain alive before the action is released. -typedef void(IREE_API_PTR* iree_hal_cuda_pending_action_cleanup_callback_t)( - void* user_data); - -// Enqueues the given list of |command_buffers| that waits on -// |wait_semaphore_list| and signals |signal_semaphore_lsit|. -// -// |cleanup_callback|, if not NULL, will run after the action completes but -// before releasing all retained resources. -iree_status_t iree_hal_cuda_pending_queue_actions_enqueue_execution( - iree_hal_device_t* device, CUstream dispatch_stream, - iree_hal_cuda_pending_queue_actions_t* actions, - iree_hal_cuda_pending_action_cleanup_callback_t cleanup_callback, - void* callback_user_data, - const iree_hal_semaphore_list_t wait_semaphore_list, - const iree_hal_semaphore_list_t signal_semaphore_list, - iree_host_size_t command_buffer_count, - iree_hal_command_buffer_t* const* command_buffers, - iree_hal_buffer_binding_table_t const* binding_tables); - -// Tries to scan the pending actions and release ready ones to the GPU. 
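The create/enqueue/issue surface above implies a simple non-blocking contract: enqueue only buffers, and issue() (declared below) is what actually releases ready work, so it is called right after enqueueing and again from every semaphore signal. A runnable toy model of that contract, all names hypothetical:

#include <stdio.h>

// Hypothetical stand-in for the pending-actions queue.
typedef struct {
  int queued;  // actions still waiting on their semaphores
} toy_queue_t;

static int toy_enqueue_execution(toy_queue_t* queue) {
  ++queue->queued;  // the real queue also retains buffers and semaphores
  return 0;
}

static int toy_issue(toy_queue_t* queue) {
  // The real issue() scans wait lists and launches only the ready subset;
  // modeled here as draining everything.
  printf("issuing %d action(s)\n", queue->queued);
  queue->queued = 0;
  return 0;
}

int main(void) {
  toy_queue_t queue = {0};
  if (toy_enqueue_execution(&queue)) return 1;
  return toy_issue(&queue);  // also re-run on every semaphore signal
}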
-iree_status_t iree_hal_cuda_pending_queue_actions_issue( - iree_hal_cuda_pending_queue_actions_t* actions); - -#ifdef __cplusplus -} // extern "C" -#endif // __cplusplus - -#endif // IREE_HAL_DRIVERS_CUDA_PENDING_QUEUE_ACTIONS_H_ diff --git a/runtime/src/iree/hal/drivers/hip/CMakeLists.txt b/runtime/src/iree/hal/drivers/hip/CMakeLists.txt index d20e777d14e8..f48d3e1ce69c 100644 --- a/runtime/src/iree/hal/drivers/hip/CMakeLists.txt +++ b/runtime/src/iree/hal/drivers/hip/CMakeLists.txt @@ -40,8 +40,6 @@ iree_cc_library( "native_executable.h" "nop_executable_cache.c" "nop_executable_cache.h" - "pending_queue_actions.c" - "pending_queue_actions.h" "pipeline_layout.c" "pipeline_layout.h" "rccl_channel.c" @@ -69,6 +67,7 @@ iree_cc_library( iree::hal iree::hal::utils::collective_batch iree::hal::utils::deferred_command_buffer + iree::hal::utils::deferred_work_queue iree::hal::utils::file_transfer iree::hal::utils::memory_file iree::hal::utils::resource_set diff --git a/runtime/src/iree/hal/drivers/hip/event_semaphore.c b/runtime/src/iree/hal/drivers/hip/event_semaphore.c index 99705a602f4d..926eb54ce5f7 100644 --- a/runtime/src/iree/hal/drivers/hip/event_semaphore.c +++ b/runtime/src/iree/hal/drivers/hip/event_semaphore.c @@ -26,9 +26,9 @@ typedef struct iree_hal_hip_semaphore_t { // The timepoint pool to acquire timepoint objects. iree_hal_hip_timepoint_pool_t* timepoint_pool; - // The list of pending queue actions that this semaphore need to advance on + // The list of actions that this semaphore may need to advance on // new signaled values. - iree_hal_hip_pending_queue_actions_t* pending_queue_actions; + iree_hal_deferred_work_queue_t* work_queue; // Guards value and status. We expect low contention on semaphores and since // iree_slim_mutex_t is (effectively) just a CAS this keeps things simpler @@ -57,11 +57,11 @@ static iree_hal_hip_semaphore_t* iree_hal_hip_semaphore_cast( iree_status_t iree_hal_hip_event_semaphore_create( uint64_t initial_value, const iree_hal_hip_dynamic_symbols_t* symbols, iree_hal_hip_timepoint_pool_t* timepoint_pool, - iree_hal_hip_pending_queue_actions_t* pending_queue_actions, - iree_allocator_t host_allocator, iree_hal_semaphore_t** out_semaphore) { + iree_hal_deferred_work_queue_t* work_queue, iree_allocator_t host_allocator, + iree_hal_semaphore_t** out_semaphore) { IREE_ASSERT_ARGUMENT(symbols); IREE_ASSERT_ARGUMENT(timepoint_pool); - IREE_ASSERT_ARGUMENT(pending_queue_actions); + IREE_ASSERT_ARGUMENT(work_queue); IREE_ASSERT_ARGUMENT(out_semaphore); IREE_TRACE_ZONE_BEGIN(z0); @@ -75,7 +75,7 @@ iree_status_t iree_hal_hip_event_semaphore_create( semaphore->host_allocator = host_allocator; semaphore->symbols = symbols; semaphore->timepoint_pool = timepoint_pool; - semaphore->pending_queue_actions = pending_queue_actions; + semaphore->work_queue = work_queue; iree_slim_mutex_initialize(&semaphore->mutex); semaphore->current_value = initial_value; semaphore->failure_status = iree_ok_status(); @@ -149,10 +149,10 @@ static iree_status_t iree_hal_hip_semaphore_signal( // Notify timepoints - note that this must happen outside the lock. iree_hal_semaphore_notify(&semaphore->base, new_value, IREE_STATUS_OK); - // Advance the pending queue actions if possible. This also must happen + // Advance the deferred work queue if possible. This also must happen // outside the lock to avoid nesting. 
- iree_status_t status = iree_hal_hip_pending_queue_actions_issue( - semaphore->pending_queue_actions); + iree_status_t status = + iree_hal_deferred_work_queue_issue(semaphore->work_queue); IREE_TRACE_ZONE_END(z0); return status; @@ -171,7 +171,7 @@ static void iree_hal_hip_semaphore_fail(iree_hal_semaphore_t* base_semaphore, // Try to set our local status - we only preserve the first failure so only // do this if we are going from a valid semaphore to a failed one. if (!iree_status_is_ok(semaphore->failure_status)) { // Previous status was not OK; drop our new status. IREE_IGNORE_ERROR(status); iree_slim_mutex_unlock(&semaphore->mutex); IREE_TRACE_ZONE_END(z0); @@ -188,10 +188,9 @@ iree_hal_semaphore_notify(&semaphore->base, IREE_HAL_SEMAPHORE_FAILURE_VALUE, status_code); - // Advance the pending queue actions if possible. This also must happen + // Advance the deferred work queue if possible. This also must happen // outside the lock to avoid nesting. - status = iree_hal_hip_pending_queue_actions_issue( - semaphore->pending_queue_actions); + status = iree_hal_deferred_work_queue_issue(semaphore->work_queue); iree_status_ignore(status); IREE_TRACE_ZONE_END(z0); diff --git a/runtime/src/iree/hal/drivers/hip/event_semaphore.h b/runtime/src/iree/hal/drivers/hip/event_semaphore.h index 986068eede11..88a75e01c436 100644 --- a/runtime/src/iree/hal/drivers/hip/event_semaphore.h +++ b/runtime/src/iree/hal/drivers/hip/event_semaphore.h @@ -12,8 +12,8 @@ #include "iree/base/api.h" #include "iree/hal/api.h" #include "iree/hal/drivers/hip/dynamic_symbols.h" -#include "iree/hal/drivers/hip/pending_queue_actions.h" #include "iree/hal/drivers/hip/timepoint_pool.h" +#include "iree/hal/utils/deferred_work_queue.h" #ifdef __cplusplus extern "C" { @@ -26,14 +26,14 @@ extern "C" { // be allocated from the |timepoint_pool|. // // This semaphore is meant to be used together with a pending queue actions; it -// may advance the given |pending_queue_actions| if new values are signaled. +// may advance the given |work_queue| if new values are signaled. // // Thread-safe; multiple threads may signal/wait values on the same semaphore. iree_status_t iree_hal_hip_event_semaphore_create( uint64_t initial_value, const iree_hal_hip_dynamic_symbols_t* symbols, iree_hal_hip_timepoint_pool_t* timepoint_pool, - iree_hal_hip_pending_queue_actions_t* pending_queue_actions, - iree_allocator_t host_allocator, iree_hal_semaphore_t** out_semaphore); + iree_hal_deferred_work_queue_t* work_queue, iree_allocator_t host_allocator, + iree_hal_semaphore_t** out_semaphore); // Acquires a timepoint to signal the timeline to the given |to_value| from the // device.
The underlying HIP event is written into |out_event| for interacting diff --git a/runtime/src/iree/hal/drivers/hip/hip_device.c b/runtime/src/iree/hal/drivers/hip/hip_device.c index de67c4ddb2e4..133d3f5de4c2 100644 --- a/runtime/src/iree/hal/drivers/hip/hip_device.c +++ b/runtime/src/iree/hal/drivers/hip/hip_device.c @@ -21,7 +21,6 @@ #include "iree/hal/drivers/hip/hip_allocator.h" #include "iree/hal/drivers/hip/memory_pools.h" #include "iree/hal/drivers/hip/nop_executable_cache.h" -#include "iree/hal/drivers/hip/pending_queue_actions.h" #include "iree/hal/drivers/hip/pipeline_layout.h" #include "iree/hal/drivers/hip/rccl_channel.h" #include "iree/hal/drivers/hip/rccl_dynamic_symbols.h" @@ -30,6 +29,7 @@ #include "iree/hal/drivers/hip/timepoint_pool.h" #include "iree/hal/drivers/hip/tracing.h" #include "iree/hal/utils/deferred_command_buffer.h" +#include "iree/hal/utils/deferred_work_queue.h" #include "iree/hal/utils/file_transfer.h" #include "iree/hal/utils/memory_file.h" @@ -77,7 +77,7 @@ typedef struct iree_hal_hip_device_t { // are met. It buffers submissions and allocations internally before they // are ready. This queue couples with HAL semaphores backed by iree_event_t // and hipEvent_t objects. - iree_hal_hip_pending_queue_actions_t* pending_queue_actions; + iree_hal_deferred_work_queue_t* work_queue; // Device memory pools and allocators. bool supports_memory_pools; @@ -89,6 +89,173 @@ typedef struct iree_hal_hip_device_t { } iree_hal_hip_device_t; static const iree_hal_device_vtable_t iree_hal_hip_device_vtable; +static const iree_hal_deferred_work_queue_device_interface_vtable_t + iree_hal_hip_deferred_work_queue_device_interface_vtable; + +// We put a hipEvent_t into a iree_hal_deferred_work_queue_native_event_t. +static_assert(sizeof(hipEvent_t) <= + sizeof(iree_hal_deferred_work_queue_native_event_t), + "Unexpected event size"); +typedef struct iree_hal_hip_deferred_work_queue_device_interface_t { + iree_hal_deferred_work_queue_device_interface_t base; + iree_hal_device_t* device; + hipDevice_t hip_device; + hipCtx_t hip_context; + hipStream_t dispatch_hip_stream; + iree_allocator_t host_allocator; + const iree_hal_hip_dynamic_symbols_t* hip_symbols; +} iree_hal_hip_deferred_work_queue_device_interface_t; + +static void iree_hal_hip_deferred_work_queue_device_interface_destroy( + iree_hal_deferred_work_queue_device_interface_t* base_device_interface) { + iree_hal_hip_deferred_work_queue_device_interface_t* device_interface = + (iree_hal_hip_deferred_work_queue_device_interface_t*)(base_device_interface); + iree_allocator_free(device_interface->host_allocator, device_interface); +} + +static iree_status_t +iree_hal_hip_deferred_work_queue_device_interface_bind_to_thread( + iree_hal_deferred_work_queue_device_interface_t* base_device_interface) { + iree_hal_hip_deferred_work_queue_device_interface_t* device_interface = + (iree_hal_hip_deferred_work_queue_device_interface_t*)(base_device_interface); + return IREE_HIP_RESULT_TO_STATUS( + device_interface->hip_symbols, + hipCtxSetCurrent(device_interface->hip_context), "hipCtxSetCurrent"); +} + +static iree_status_t +iree_hal_hip_deferred_work_queue_device_interface_wait_native_event( + iree_hal_deferred_work_queue_device_interface_t* base_device_interface, + iree_hal_deferred_work_queue_native_event_t event) { + iree_hal_hip_deferred_work_queue_device_interface_t* device_interface = + (iree_hal_hip_deferred_work_queue_device_interface_t*)(base_device_interface); + return IREE_HIP_RESULT_TO_STATUS( + 
device_interface->hip_symbols, + hipStreamWaitEvent(device_interface->dispatch_hip_stream, + (hipEvent_t)event, 0), + "hipStreamWaitEvent"); +} + +static iree_status_t +iree_hal_hip_deferred_work_queue_device_interface_create_native_event( + iree_hal_deferred_work_queue_device_interface_t* base_device_interface, + iree_hal_deferred_work_queue_native_event_t* out_event) { + iree_hal_hip_deferred_work_queue_device_interface_t* device_interface = + (iree_hal_hip_deferred_work_queue_device_interface_t*)(base_device_interface); + return IREE_HIP_RESULT_TO_STATUS(device_interface->hip_symbols, + hipEventCreate((hipEvent_t*)out_event), + "hipEventCreate"); +} +static iree_status_t +iree_hal_hip_deferred_work_queue_device_interface_record_native_event( + iree_hal_deferred_work_queue_device_interface_t* base_device_interface, + iree_hal_deferred_work_queue_native_event_t event) { + iree_hal_hip_deferred_work_queue_device_interface_t* device_interface = + (iree_hal_hip_deferred_work_queue_device_interface_t*)(base_device_interface); + return IREE_HIP_RESULT_TO_STATUS( + device_interface->hip_symbols, + hipEventRecord((hipEvent_t)event, device_interface->dispatch_hip_stream), + "hipEventRecord"); +} + +static iree_status_t +iree_hal_hip_deferred_work_queue_device_interface_synchronize_native_event( + iree_hal_deferred_work_queue_device_interface_t* base_device_interface, + iree_hal_deferred_work_queue_native_event_t event) { + iree_hal_hip_deferred_work_queue_device_interface_t* device_interface = + (iree_hal_hip_deferred_work_queue_device_interface_t*)(base_device_interface); + return IREE_HIP_RESULT_TO_STATUS(device_interface->hip_symbols, + hipEventSynchronize((hipEvent_t)event)); +} +static iree_status_t +iree_hal_hip_deferred_work_queue_device_interface_destroy_native_event( + iree_hal_deferred_work_queue_device_interface_t* base_device_interface, + iree_hal_deferred_work_queue_native_event_t event) { + iree_hal_hip_deferred_work_queue_device_interface_t* device_interface = + (iree_hal_hip_deferred_work_queue_device_interface_t*)(base_device_interface); + return IREE_HIP_RESULT_TO_STATUS(device_interface->hip_symbols, + hipEventDestroy((hipEvent_t)event)); +} + +static iree_status_t +iree_hal_hip_deferred_work_queue_device_interface_semaphore_acquire_timepoint_device_signal_native_event( + iree_hal_deferred_work_queue_device_interface_t* device_interface, + struct iree_hal_semaphore_t* semaphore, uint64_t value, + iree_hal_deferred_work_queue_native_event_t* out_event) { + return iree_hal_hip_event_semaphore_acquire_timepoint_device_signal( + semaphore, value, (hipEvent_t*)out_event); +} + +static bool +iree_hal_hip_deferred_work_queue_device_interface_acquire_host_wait_event( + iree_hal_deferred_work_queue_device_interface_t* device_interface, + struct iree_hal_semaphore_t* semaphore, uint64_t value, + iree_hal_deferred_work_queue_host_device_event_t* out_event) { + return iree_hal_hip_semaphore_acquire_event_host_wait( + semaphore, value, (iree_hal_hip_event_t**)out_event); +} + +static iree_status_t +iree_hal_hip_deferred_work_queue_device_interface_device_wait_on_host_event( + iree_hal_deferred_work_queue_device_interface_t* base_device_interface, + iree_hal_deferred_work_queue_host_device_event_t wait_event) { + iree_hal_hip_deferred_work_queue_device_interface_t* device_interface = + (iree_hal_hip_deferred_work_queue_device_interface_t*)(base_device_interface); + return IREE_HIP_RESULT_TO_STATUS( + device_interface->hip_symbols, + hipStreamWaitEvent( + 
device_interface->dispatch_hip_stream, + iree_hal_hip_event_handle((iree_hal_hip_event_t*)wait_event), 0), + "hipStreamWaitEvent"); +} + +static void +iree_hal_hip_deferred_work_queue_device_interface_release_wait_event( + iree_hal_deferred_work_queue_device_interface_t* device_interface, + iree_hal_deferred_work_queue_host_device_event_t wait_event) { + iree_hal_hip_event_release(wait_event); +} + +static iree_hal_deferred_work_queue_native_event_t +iree_hal_hip_deferred_work_queue_device_interface_native_event_from_wait_event( + iree_hal_deferred_work_queue_device_interface_t* device_interface, + iree_hal_deferred_work_queue_host_device_event_t event) { + iree_hal_hip_event_t* wait_event = (iree_hal_hip_event_t*)event; + return iree_hal_hip_event_handle(wait_event); +} + +static iree_status_t +iree_hal_hip_deferred_work_queue_device_interface_create_stream_command_buffer( + iree_hal_deferred_work_queue_device_interface_t* base_device_interface, + iree_hal_command_buffer_mode_t mode, iree_hal_command_category_t categories, + iree_hal_command_buffer_t** out) { + iree_hal_hip_deferred_work_queue_device_interface_t* device_interface = + (iree_hal_hip_deferred_work_queue_device_interface_t*)(base_device_interface); + return iree_hal_hip_device_create_stream_command_buffer( + device_interface->device, mode, categories, 0, out); +} + +static iree_status_t +iree_hal_hip_deferred_work_queue_device_interface_submit_command_buffer( + iree_hal_deferred_work_queue_device_interface_t* device_interface, + iree_hal_command_buffer_t* command_buffer) { + iree_hal_hip_deferred_work_queue_device_interface_t* table = + (iree_hal_hip_deferred_work_queue_device_interface_t*)(device_interface); + iree_status_t status = iree_ok_status(); + if (iree_hal_hip_stream_command_buffer_isa(command_buffer)) { + // Stream command buffer so nothing to do but notify it was submitted. 
+ iree_hal_hip_stream_notify_submitted_commands(command_buffer); + } else { + hipGraphExec_t exec = + iree_hal_hip_graph_command_buffer_handle(command_buffer); + status = IREE_HIP_RESULT_TO_STATUS( + table->hip_symbols, hipGraphLaunch(exec, table->dispatch_hip_stream)); + if (IREE_LIKELY(iree_status_is_ok(status))) { + iree_hal_hip_graph_tracing_notify_submitted_commands(command_buffer); + } + } + return status; +} static iree_hal_hip_device_t* iree_hal_hip_device_cast( iree_hal_device_t* base_value) { @@ -154,9 +321,26 @@ static iree_status_t iree_hal_hip_device_create_internal( device->hip_dispatch_stream = dispatch_stream; device->host_allocator = host_allocator; - iree_status_t status = iree_hal_hip_pending_queue_actions_create( - symbols, hip_device, &device->block_pool, host_allocator, - &device->pending_queue_actions); + iree_hal_hip_deferred_work_queue_device_interface_t* device_interface; + iree_status_t status = iree_allocator_malloc( + host_allocator, + sizeof(iree_hal_hip_deferred_work_queue_device_interface_t), + (void**)&device_interface); + if (IREE_UNLIKELY(!iree_status_is_ok(status))) { + iree_hal_device_release((iree_hal_device_t*)device); + return status; + } + device_interface->base.vtable = + &iree_hal_hip_deferred_work_queue_device_interface_vtable; + device_interface->hip_context = context; + device_interface->hip_symbols = symbols; + device_interface->device = (iree_hal_device_t*)device; + device_interface->hip_device = hip_device; + device_interface->dispatch_hip_stream = dispatch_stream; + device_interface->host_allocator = host_allocator; + status = iree_hal_deferred_work_queue_create( + (iree_hal_deferred_work_queue_device_interface_t*)device_interface, + &device->block_pool, host_allocator, &device->work_queue); // Enable tracing for the (currently only) stream - no-op if disabled. if (iree_status_is_ok(status) && device->params.stream_tracing) { @@ -298,8 +482,7 @@ static void iree_hal_hip_device_destroy(iree_hal_device_t* base_device) { IREE_TRACE_ZONE_BEGIN(z0); // Destroy the pending workload queue. - iree_hal_hip_pending_queue_actions_destroy( - (iree_hal_resource_t*)device->pending_queue_actions); + iree_hal_deferred_work_queue_destroy(device->work_queue); // There should be no more buffers live that use the allocator. 
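The wiring above follows a common C vtable-adapter pattern: the backend struct embeds the abstract interface as its first member so pointers cast both ways, fills in a static vtable, and hands ownership to the queue, which later calls destroy through the vtable on teardown (the destroy path continues below). A self-contained sketch of the pattern with hypothetical names:

#include <stdlib.h>

typedef struct backend_interface_t backend_interface_t;

typedef struct {
  void (*destroy)(backend_interface_t* device_interface);
} backend_vtable_t;

struct backend_interface_t {
  const backend_vtable_t* vtable;
};

typedef struct {
  backend_interface_t base;  // first member: casts to backend_interface_t
  int device_ordinal;        // backend-specific state follows
} my_backend_interface_t;

static void my_backend_destroy(backend_interface_t* base) {
  my_backend_interface_t* self = (my_backend_interface_t*)base;
  free(self);
}

static const backend_vtable_t my_backend_vtable = {
    .destroy = my_backend_destroy,
};

static backend_interface_t* my_backend_create(int device_ordinal) {
  my_backend_interface_t* self = calloc(1, sizeof(*self));
  if (!self) return NULL;
  self->base.vtable = &my_backend_vtable;
  self->device_ordinal = device_ordinal;
  return &self->base;
}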
iree_hal_allocator_release(device->device_allocator); @@ -622,7 +805,7 @@ static iree_status_t iree_hal_hip_device_create_semaphore( iree_hal_hip_device_t* device = iree_hal_hip_device_cast(base_device); return iree_hal_hip_event_semaphore_create( initial_value, device->hip_symbols, device->timepoint_pool, - device->pending_queue_actions, device->host_allocator, out_semaphore); + device->work_queue, device->host_allocator, out_semaphore); } static iree_hal_semaphore_compatibility_t @@ -767,15 +950,13 @@ static iree_status_t iree_hal_hip_device_queue_execute( iree_hal_hip_device_t* device = iree_hal_hip_device_cast(base_device); IREE_TRACE_ZONE_BEGIN(z0); - iree_status_t status = iree_hal_hip_pending_queue_actions_enqueue_execution( - base_device, device->hip_dispatch_stream, device->pending_queue_actions, - iree_hal_hip_device_collect_tracing_context, device->tracing_context, - wait_semaphore_list, signal_semaphore_list, command_buffer_count, - command_buffers, binding_tables); + iree_status_t status = iree_hal_deferred_work_queue_enque( + device->work_queue, iree_hal_hip_device_collect_tracing_context, + device->tracing_context, wait_semaphore_list, signal_semaphore_list, + command_buffer_count, command_buffers, binding_tables); if (iree_status_is_ok(status)) { - // Try to advance the pending workload queue. - status = - iree_hal_hip_pending_queue_actions_issue(device->pending_queue_actions); + // Try to advance the deferred work queue. + status = iree_hal_deferred_work_queue_issue(device->work_queue); } IREE_TRACE_ZONE_END(z0); @@ -786,9 +967,8 @@ static iree_status_t iree_hal_hip_device_queue_flush( iree_hal_device_t* base_device, iree_hal_queue_affinity_t queue_affinity) { iree_hal_hip_device_t* device = iree_hal_hip_device_cast(base_device); IREE_TRACE_ZONE_BEGIN(z0); - // Try to advance the pending workload queue. - iree_status_t status = - iree_hal_hip_pending_queue_actions_issue(device->pending_queue_actions); + // Try to advance the deferred work queue. 
+ iree_status_t status = iree_hal_deferred_work_queue_issue(device->work_queue); IREE_TRACE_ZONE_END(z0); return status; } @@ -851,3 +1031,34 @@ static const iree_hal_device_vtable_t iree_hal_hip_device_vtable = { .profiling_flush = iree_hal_hip_device_profiling_flush, .profiling_end = iree_hal_hip_device_profiling_end, }; + +static const iree_hal_deferred_work_queue_device_interface_vtable_t + iree_hal_hip_deferred_work_queue_device_interface_vtable = { + .destroy = iree_hal_hip_deferred_work_queue_device_interface_destroy, + .bind_to_thread = + iree_hal_hip_deferred_work_queue_device_interface_bind_to_thread, + .wait_native_event = + iree_hal_hip_deferred_work_queue_device_interface_wait_native_event, + .create_native_event = + iree_hal_hip_deferred_work_queue_device_interface_create_native_event, + .record_native_event = + iree_hal_hip_deferred_work_queue_device_interface_record_native_event, + .synchronize_native_event = + iree_hal_hip_deferred_work_queue_device_interface_synchronize_native_event, + .destroy_native_event = + iree_hal_hip_deferred_work_queue_device_interface_destroy_native_event, + .semaphore_acquire_timepoint_device_signal_native_event = + iree_hal_hip_deferred_work_queue_device_interface_semaphore_acquire_timepoint_device_signal_native_event, + .acquire_host_wait_event = + iree_hal_hip_deferred_work_queue_device_interface_acquire_host_wait_event, + .device_wait_on_host_event = + iree_hal_hip_deferred_work_queue_device_interface_device_wait_on_host_event, + .release_wait_event = + iree_hal_hip_deferred_work_queue_device_interface_release_wait_event, + .native_event_from_wait_event = + iree_hal_hip_deferred_work_queue_device_interface_native_event_from_wait_event, + .create_stream_command_buffer = + iree_hal_hip_deferred_work_queue_device_interface_create_stream_command_buffer, + .submit_command_buffer = + iree_hal_hip_deferred_work_queue_device_interface_submit_command_buffer, +}; diff --git a/runtime/src/iree/hal/drivers/hip/pending_queue_actions.h b/runtime/src/iree/hal/drivers/hip/pending_queue_actions.h deleted file mode 100644 index d1253721cf3b..000000000000 --- a/runtime/src/iree/hal/drivers/hip/pending_queue_actions.h +++ /dev/null @@ -1,79 +0,0 @@ -// Copyright 2024 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#ifndef IREE_HAL_DRIVERS_HIP_PENDING_QUEUE_ACTIONS_H_ -#define IREE_HAL_DRIVERS_HIP_PENDING_QUEUE_ACTIONS_H_ - -#include "iree/base/api.h" -#include "iree/base/internal/arena.h" -#include "iree/hal/api.h" -#include "iree/hal/drivers/hip/dynamic_symbols.h" - -#ifdef __cplusplus -extern "C" { -#endif // __cplusplus - -// A data structure to manage pending queue actions (kernel launches and async -// allocations). -// -// This is needed in order to satisfy queue action dependencies. IREE uses HAL -// semaphore as the unified mechanism for synchronization directions including -// host to host, host to device, devie to device, and device to host. Plus, it -// allows wait before signal. These flexible capabilities are not all supported -// by hipEvent_t objects. Therefore, we need supporting data structures to -// implement them on top of hipEvent_t objects. Thus this pending queue actions. -// -// This buffers pending queue actions and their associated resources. 
It -// provides an API to advance the wait list on demand--queue actions are -// released to the GPU when all their wait semaphores are signaled past the -// desired value, or we can have a hipEvent_t already recorded to some HIP -// stream to wait on. -// -// Thread-safe; multiple threads may enqueue workloads. -typedef struct iree_hal_hip_pending_queue_actions_t - iree_hal_hip_pending_queue_actions_t; - -// Creates a pending actions queue. -iree_status_t iree_hal_hip_pending_queue_actions_create( - const iree_hal_hip_dynamic_symbols_t* symbols, hipDevice_t device, - iree_arena_block_pool_t* block_pool, iree_allocator_t host_allocator, - iree_hal_hip_pending_queue_actions_t** out_actions); - -// Destroys the pending |actions| queue. -void iree_hal_hip_pending_queue_actions_destroy(iree_hal_resource_t* actions); - -// Callback to execute user code after action completion but before resource -// releasing. -// -// Data behind |user_data| must remain alive before the action is released. -typedef void(IREE_API_PTR* iree_hal_hip_pending_action_cleanup_callback_t)( - void* user_data); - -// Enqueues the given list of |command_buffers| that waits on -// |wait_semaphore_list| and signals |signal_semaphore_lsit|. -// -// |cleanup_callback|, if not NULL, will run after the action completes but -// before releasing all retained resources. -iree_status_t iree_hal_hip_pending_queue_actions_enqueue_execution( - iree_hal_device_t* device, hipStream_t dispatch_stream, - iree_hal_hip_pending_queue_actions_t* actions, - iree_hal_hip_pending_action_cleanup_callback_t cleanup_callback, - void* callback_user_data, - const iree_hal_semaphore_list_t wait_semaphore_list, - const iree_hal_semaphore_list_t signal_semaphore_list, - iree_host_size_t command_buffer_count, - iree_hal_command_buffer_t* const* command_buffers, - iree_hal_buffer_binding_table_t const* binding_tables); - -// Tries to scan the pending actions and release ready ones to the GPU. 
-iree_status_t iree_hal_hip_pending_queue_actions_issue( - iree_hal_hip_pending_queue_actions_t* actions); - -#ifdef __cplusplus -} // extern "C" -#endif // __cplusplus - -#endif // IREE_HAL_DRIVERS_HIP_PENDING_QUEUE_ACTIONS_H_ diff --git a/runtime/src/iree/hal/utils/BUILD.bazel b/runtime/src/iree/hal/utils/BUILD.bazel index 395fc7f93858..38f170c5f159 100644 --- a/runtime/src/iree/hal/utils/BUILD.bazel +++ b/runtime/src/iree/hal/utils/BUILD.bazel @@ -201,3 +201,19 @@ iree_runtime_cc_test( "//runtime/src/iree/testing:gtest_main", ], ) + +iree_runtime_cc_library( + name = "deferred_work_queue", + srcs = ["deferred_work_queue.c"], + hdrs = ["deferred_work_queue.h"], + deps = [ + ":deferred_command_buffer", + ":resource_set", + ":semaphore_base", + "//runtime/src/iree/base", + "//runtime/src/iree/base/internal:arena", + "//runtime/src/iree/base/internal:synchronization", + "//runtime/src/iree/base/internal:threading", + "//runtime/src/iree/hal", + ], +) diff --git a/runtime/src/iree/hal/utils/CMakeLists.txt b/runtime/src/iree/hal/utils/CMakeLists.txt index 78263704dd2d..3da6140730dd 100644 --- a/runtime/src/iree/hal/utils/CMakeLists.txt +++ b/runtime/src/iree/hal/utils/CMakeLists.txt @@ -238,4 +238,23 @@ iree_cc_test( iree::testing::gtest_main ) +iree_cc_library( + NAME + deferred_work_queue + HDRS + "deferred_work_queue.h" + SRCS + "deferred_work_queue.c" + DEPS + ::deferred_command_buffer + ::resource_set + ::semaphore_base + iree::base + iree::base::internal::arena + iree::base::internal::synchronization + iree::base::internal::threading + iree::hal + PUBLIC +) + ### BAZEL_TO_CMAKE_PRESERVES_ALL_CONTENT_BELOW_THIS_LINE ### diff --git a/runtime/src/iree/hal/drivers/hip/pending_queue_actions.c b/runtime/src/iree/hal/utils/deferred_work_queue.c similarity index 57% rename from runtime/src/iree/hal/drivers/hip/pending_queue_actions.c rename to runtime/src/iree/hal/utils/deferred_work_queue.c index 88a2e830537d..76ed80f7414f 100644 --- a/runtime/src/iree/hal/drivers/hip/pending_queue_actions.c +++ b/runtime/src/iree/hal/utils/deferred_work_queue.c @@ -4,47 +4,37 @@ // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -#include "iree/hal/drivers/hip/pending_queue_actions.h" +#include "iree/hal/utils/deferred_work_queue.h" #include #include #include "iree/base/api.h" -#include "iree/base/assert.h" #include "iree/base/internal/arena.h" -#include "iree/base/internal/atomic_slist.h" -#include "iree/base/internal/atomics.h" #include "iree/base/internal/synchronization.h" #include "iree/base/internal/threading.h" #include "iree/hal/api.h" -#include "iree/hal/drivers/hip/dynamic_symbols.h" -#include "iree/hal/drivers/hip/event_pool.h" -#include "iree/hal/drivers/hip/event_semaphore.h" -#include "iree/hal/drivers/hip/graph_command_buffer.h" -#include "iree/hal/drivers/hip/hip_device.h" -#include "iree/hal/drivers/hip/status_util.h" -#include "iree/hal/drivers/hip/stream_command_buffer.h" #include "iree/hal/utils/deferred_command_buffer.h" #include "iree/hal/utils/resource_set.h" -// The maximal number of hipEvent_t objects a command buffer can wait. -#define IREE_HAL_HIP_MAX_WAIT_EVENT_COUNT 32 +// The maximal number of events a command buffer can wait on. 
+#define IREE_HAL_MAX_WAIT_EVENT_COUNT 32 //===----------------------------------------------------------------------===// // Queue action //===----------------------------------------------------------------------===// -typedef enum iree_hal_hip_queue_action_kind_e { - IREE_HAL_HIP_QUEUE_ACTION_TYPE_EXECUTION, +typedef enum iree_hal_deferred_work_queue_action_kind_e { + IREE_HAL_QUEUE_ACTION_TYPE_EXECUTION, // TODO: Add support for queue alloca and dealloca. -} iree_hal_hip_queue_action_kind_t; +} iree_hal_deferred_work_queue_action_kind_t; -typedef enum iree_hal_hip_queue_action_state_e { +typedef enum iree_hal_deferred_work_queue_action_state_e { // The current action is active as waiting for or under execution. - IREE_HAL_HIP_QUEUE_ACTION_STATE_ALIVE, + IREE_HAL_QUEUE_ACTION_STATE_ALIVE, // The current action is done execution and waiting for destruction. - IREE_HAL_HIP_QUEUE_ACTION_STATE_ZOMBIE, -} iree_hal_hip_queue_action_state_t; + IREE_HAL_QUEUE_ACTION_STATE_ZOMBIE, +} iree_hal_deferred_work_queue_action_state_t; // How many work items must complete in order for an action to complete. // We keep track of the remaining work for an action so we don't exit worker @@ -53,32 +43,33 @@ typedef enum iree_hal_hip_queue_action_state_e { // +1 for cleaning up a zombie action. static const iree_host_size_t total_work_items_to_complete_an_action = 2; -// A pending queue action. -// +// A work queue action. // Note that this struct does not have internal synchronization; it's expected -// to work together with the pending action queue, which synchronizes accesses. -typedef struct iree_hal_hip_queue_action_t { +// to work together with the deferred work queue, which synchronizes accesses. +typedef struct iree_hal_deferred_work_queue_action_t { // Intrusive doubly-linked list next entry pointer. - struct iree_hal_hip_queue_action_t* next; + struct iree_hal_deferred_work_queue_action_t* next; // Intrusive doubly-linked list previous entry pointer. - struct iree_hal_hip_queue_action_t* prev; + struct iree_hal_deferred_work_queue_action_t* prev; - // The owning pending actions queue. We use its allocators and pools. + // The owning deferred work queue. We use its allocators and pools. // Retained to make sure it outlives the current action. - iree_hal_hip_pending_queue_actions_t* owning_actions; + iree_hal_deferred_work_queue_t* owning_actions; // The current state of this action. When an action is initially created it // will be alive and enqueued to wait for releasing to the GPU. After done // execution, it will be flipped into zombie state and enqueued again for // destruction. - iree_hal_hip_queue_action_state_t state; + iree_hal_deferred_work_queue_action_state_t state; // The callback to run after completing this action and before freeing // all resources. Can be NULL. - iree_hal_hip_pending_action_cleanup_callback_t cleanup_callback; + iree_hal_deferred_work_queue_cleanup_callback_t cleanup_callback; // User data to pass into the callback. void* callback_user_data; - iree_hal_hip_queue_action_kind_t kind; + iree_hal_deferred_work_queue_device_interface_t* device_interface; + + iree_hal_deferred_work_queue_action_kind_t kind; union { struct { iree_host_size_t count; @@ -87,13 +78,6 @@ typedef struct iree_hal_hip_queue_action_t { } execution; } payload; - // The device from which to allocate HIP stream-based command buffers for - // applying deferred command buffers. - iree_hal_device_t* device; - - // The stream to launch main GPU workload. 
- hipStream_t dispatch_hip_stream; - // Resource set to retain all associated resources by the payload. iree_hal_resource_set_t* resource_set; @@ -103,46 +87,49 @@ typedef struct iree_hal_hip_queue_action_t { iree_hal_semaphore_list_t signal_semaphore_list; // Scratch fields for analyzing whether actions are ready to issue. - iree_hal_hip_event_t* events[IREE_HAL_HIP_MAX_WAIT_EVENT_COUNT]; + iree_hal_deferred_work_queue_host_device_event_t + wait_events[IREE_HAL_MAX_WAIT_EVENT_COUNT]; iree_host_size_t event_count; // Whether the current action is still not ready for releasing to the GPU. bool is_pending; -} iree_hal_hip_queue_action_t; +} iree_hal_deferred_work_queue_action_t; -static void iree_hal_hip_queue_action_fail_locked( - iree_hal_hip_queue_action_t* action, iree_status_t status); +static void iree_hal_deferred_work_queue_action_fail_locked( + iree_hal_deferred_work_queue_action_t* action, iree_status_t status); -static void iree_hal_hip_queue_action_clear_events( - iree_hal_hip_queue_action_t* action) { +static void iree_hal_deferred_work_queue_action_clear_events( + iree_hal_deferred_work_queue_action_t* action) { for (iree_host_size_t i = 0; i < action->event_count; ++i) { - iree_hal_hip_event_release(action->events[i]); + action->device_interface->vtable->release_wait_event( + action->device_interface, action->wait_events[i]); } action->event_count = 0; } -static void iree_hal_hip_queue_action_destroy( - iree_hal_hip_queue_action_t* action); +static void iree_hal_deferred_work_queue_action_destroy( + iree_hal_deferred_work_queue_action_t* action); //===----------------------------------------------------------------------===// // Queue action list //===----------------------------------------------------------------------===// -typedef struct iree_hal_hip_queue_action_list_t { - iree_hal_hip_queue_action_t* head; - iree_hal_hip_queue_action_t* tail; -} iree_hal_hip_queue_action_list_t; +typedef struct iree_hal_deferred_work_queue_action_list_t { + iree_hal_deferred_work_queue_action_t* head; + iree_hal_deferred_work_queue_action_t* tail; +} iree_hal_deferred_work_queue_action_list_t; // Returns true if the action list is empty. -static inline bool iree_hal_hip_queue_action_list_is_empty( - const iree_hal_hip_queue_action_list_t* list) { +static inline bool iree_hal_deferred_work_queue_action_list_is_empty( + const iree_hal_deferred_work_queue_action_list_t* list) { return list->head == NULL; } -static iree_hal_hip_queue_action_t* iree_hal_hip_queue_action_list_pop_front( - iree_hal_hip_queue_action_list_t* list) { +static iree_hal_deferred_work_queue_action_t* +iree_hal_deferred_work_queue_action_list_pop_front( + iree_hal_deferred_work_queue_action_list_t* list) { IREE_ASSERT(list->head && list->tail); - iree_hal_hip_queue_action_t* action = list->head; + iree_hal_deferred_work_queue_action_t* action = list->head; IREE_ASSERT(!action->prev); list->head = action->next; if (action->next) { @@ -157,9 +144,9 @@ static iree_hal_hip_queue_action_t* iree_hal_hip_queue_action_list_pop_front( } // Pushes |action| on to the end of the given action |list|. 
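The pop_front above and the push_back that follows operate on an intrusive doubly-linked list: the links live inside the action itself, so list manipulation never allocates and an action moves between the pending and ready lists by pointer surgery alone. An isolated sketch of the same two operations:

#include <assert.h>
#include <stddef.h>

// Hypothetical node: next/prev are embedded in the element itself.
typedef struct node_t {
  struct node_t* next;
  struct node_t* prev;
  int payload;
} node_t;

typedef struct {
  node_t* head;
  node_t* tail;
} list_t;

static void list_push_back(list_t* list, node_t* node) {
  assert(!node->next && !node->prev);
  if (list->tail) {
    list->tail->next = node;
    node->prev = list->tail;
  } else {
    list->head = node;
  }
  list->tail = node;
}

static node_t* list_pop_front(list_t* list) {
  assert(list->head && list->tail);
  node_t* node = list->head;
  list->head = node->next;
  if (node->next) node->next->prev = NULL;
  node->next = NULL;
  if (list->tail == node) list->tail = NULL;
  return node;
}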
-static void iree_hal_hip_queue_action_list_push_back( - iree_hal_hip_queue_action_list_t* list, - iree_hal_hip_queue_action_t* action) { +static void iree_hal_deferred_work_queue_action_list_push_back( + iree_hal_deferred_work_queue_action_list_t* list, + iree_hal_deferred_work_queue_action_t* action) { IREE_ASSERT(!action->next && !action->prev); if (list->tail) { list->tail->next = action; @@ -171,9 +158,9 @@ static void iree_hal_hip_queue_action_list_push_back( } // Takes all actions from |available_list| and moves them into |ready_list|. -static void iree_hal_hip_queue_action_list_take_all( - iree_hal_hip_queue_action_list_t* available_list, - iree_hal_hip_queue_action_list_t* ready_list) { +static void iree_hal_deferred_work_queue_action_list_take_all( + iree_hal_deferred_work_queue_action_list_t* available_list, + iree_hal_deferred_work_queue_action_list_t* ready_list) { IREE_ASSERT_NE(available_list, ready_list); ready_list->head = available_list->head; ready_list->tail = available_list->tail; @@ -181,11 +168,11 @@ static void iree_hal_hip_queue_action_list_take_all( available_list->tail = NULL; } -static void iree_hal_hip_queue_action_list_destroy( - iree_hal_hip_queue_action_t* head_action) { +static void iree_hal_deferred_work_queue_action_list_destroy( + iree_hal_deferred_work_queue_action_t* head_action) { while (head_action) { - iree_hal_hip_queue_action_t* next_action = head_action->next; - iree_hal_hip_queue_action_destroy(head_action); + iree_hal_deferred_work_queue_action_t* next_action = head_action->next; + iree_hal_deferred_work_queue_action_destroy(head_action); head_action = next_action; } } @@ -194,22 +181,25 @@ static void iree_hal_hip_queue_action_list_destroy( // Ready-list processing //===----------------------------------------------------------------------===// -// Ready action atomic slist entry struct. -typedef struct iree_hal_hip_entry_list_node_t { - iree_hal_hip_queue_action_t* ready_list_head; - struct iree_hal_hip_entry_list_node_t* next; -} iree_hal_hip_entry_list_node_t; +// Ready action entry struct. 
+typedef struct iree_hal_deferred_work_queue_entry_list_node_t {
+  iree_hal_deferred_work_queue_action_t* ready_list_head;
+  struct iree_hal_deferred_work_queue_entry_list_node_t* next;
+} iree_hal_deferred_work_queue_entry_list_node_t;

-typedef struct iree_hal_hip_entry_list_t {
+typedef struct iree_hal_deferred_work_queue_entry_list_t {
   iree_slim_mutex_t guard_mutex;
-  iree_hal_hip_entry_list_node_t* head IREE_GUARDED_BY(guard_mutex);
-  iree_hal_hip_entry_list_node_t* tail IREE_GUARDED_BY(guard_mutex);
-} iree_hal_hip_entry_list_t;
+  iree_hal_deferred_work_queue_entry_list_node_t* head
+      IREE_GUARDED_BY(guard_mutex);
+  iree_hal_deferred_work_queue_entry_list_node_t* tail
+      IREE_GUARDED_BY(guard_mutex);
+} iree_hal_deferred_work_queue_entry_list_t;

-static iree_hal_hip_entry_list_node_t* iree_hal_hip_entry_list_pop(
-    iree_hal_hip_entry_list_t* list) {
-  iree_hal_hip_entry_list_node_t* out = NULL;
+static iree_hal_deferred_work_queue_entry_list_node_t*
+iree_hal_deferred_work_queue_entry_list_pop(
+    iree_hal_deferred_work_queue_entry_list_t* list) {
+  iree_hal_deferred_work_queue_entry_list_node_t* out = NULL;
   iree_slim_mutex_lock(&list->guard_mutex);
   if (list->head) {
     out = list->head;
@@ -222,8 +212,9 @@ static iree_hal_hip_entry_list_node_t* iree_hal_hip_entry_list_pop(
   return out;
 }

-void iree_hal_hip_entry_list_push(iree_hal_hip_entry_list_t* list,
-                                  iree_hal_hip_entry_list_node_t* next) {
+void iree_hal_deferred_work_queue_entry_list_push(
+    iree_hal_deferred_work_queue_entry_list_t* list,
+    iree_hal_deferred_work_queue_entry_list_node_t* next) {
   iree_slim_mutex_lock(&list->guard_mutex);
   next->next = NULL;
   if (list->tail) {
@@ -236,48 +227,52 @@ void iree_hal_hip_entry_list_push(iree_hal_hip_entry_list_t* list,
   iree_slim_mutex_unlock(&list->guard_mutex);
 }

-static void iree_hal_hip_ready_action_list_deinitialize(
-    iree_hal_hip_entry_list_t* list, iree_allocator_t host_allocator) {
-  iree_hal_hip_entry_list_node_t* head = list->head;
+static void iree_hal_deferred_work_queue_ready_action_list_deinitialize(
+    iree_hal_deferred_work_queue_entry_list_t* list,
+    iree_allocator_t host_allocator) {
+  iree_hal_deferred_work_queue_entry_list_node_t* head = list->head;
   while (head) {
-    if (!head) break;
-    iree_hal_hip_queue_action_list_destroy(head->ready_list_head);
+    iree_hal_deferred_work_queue_action_list_destroy(head->ready_list_head);
     list->head = head->next;
     iree_allocator_free(host_allocator, head);
+    head = list->head;
   }
   iree_slim_mutex_deinitialize(&list->guard_mutex);
 }

-static void iree_hal_hip_ready_action_list_initialize(
-    iree_hal_hip_entry_list_t* list) {
+static void iree_hal_deferred_work_queue_ready_action_list_initialize(
+    iree_hal_deferred_work_queue_entry_list_t* list) {
   list->head = NULL;
   list->tail = NULL;
   iree_slim_mutex_initialize(&list->guard_mutex);
 }

-// Ready action atomic slist entry struct.
-typedef struct iree_hal_hip_completion_list_node_t {
+// Completion list entry struct.
+typedef struct iree_hal_deferred_work_queue_completion_list_node_t {
   // The callback and user data for that callback. To be called
   // when the associated event has completed.
   iree_status_t (*callback)(iree_status_t, void* user_data);
   void* user_data;

   // The event to wait for on the completion thread.
-  hipEvent_t event;
+  iree_hal_deferred_work_queue_native_event_t native_event;

   // If this event was created just for the completion thread, and therefore
   // needs to be cleaned up.
bool created_event; - struct iree_hal_hip_completion_list_node_t* next; -} iree_hal_hip_completion_list_node_t; + struct iree_hal_deferred_work_queue_completion_list_node_t* next; +} iree_hal_deferred_work_queue_completion_list_node_t; -typedef struct iree_hal_hip_completion_list_t { +typedef struct iree_hal_deferred_work_queue_completion_list_t { iree_slim_mutex_t guard_mutex; - iree_hal_hip_completion_list_node_t* head IREE_GUARDED_BY(guard_mutex); - iree_hal_hip_completion_list_node_t* tail IREE_GUARDED_BY(guard_mutex); -} iree_hal_hip_completion_list_t; - -static iree_hal_hip_completion_list_node_t* iree_hal_hip_completion_list_pop( - iree_hal_hip_completion_list_t* list) { - iree_hal_hip_completion_list_node_t* out = NULL; + iree_hal_deferred_work_queue_completion_list_node_t* head + IREE_GUARDED_BY(guard_mutex); + iree_hal_deferred_work_queue_completion_list_node_t* tail + IREE_GUARDED_BY(guard_mutex); +} iree_hal_deferred_work_queue_completion_list_t; + +static iree_hal_deferred_work_queue_completion_list_node_t* +iree_hal_deferred_work_queue_completion_list_pop( + iree_hal_deferred_work_queue_completion_list_t* list) { + iree_hal_deferred_work_queue_completion_list_node_t* out = NULL; iree_slim_mutex_lock(&list->guard_mutex); if (list->head) { out = list->head; @@ -290,9 +285,9 @@ static iree_hal_hip_completion_list_node_t* iree_hal_hip_completion_list_pop( return out; } -void iree_hal_hip_completion_list_push( - iree_hal_hip_completion_list_t* list, - iree_hal_hip_completion_list_node_t* next) { +void iree_hal_deferred_work_queue_completion_list_push( + iree_hal_deferred_work_queue_completion_list_t* list, + iree_hal_deferred_work_queue_completion_list_node_t* next) { iree_slim_mutex_lock(&list->guard_mutex); next->next = NULL; if (list->tail) { @@ -305,21 +300,22 @@ void iree_hal_hip_completion_list_push( iree_slim_mutex_unlock(&list->guard_mutex); } -static void iree_hal_hip_completion_list_initialize( - iree_hal_hip_completion_list_t* list) { +static void iree_hal_deferred_work_queue_completion_list_initialize( + iree_hal_deferred_work_queue_completion_list_t* list) { list->head = NULL; list->tail = NULL; iree_slim_mutex_initialize(&list->guard_mutex); } -static void iree_hal_hip_completion_list_deinitialize( - iree_hal_hip_completion_list_t* list, - const iree_hal_hip_dynamic_symbols_t* symbols, +static void iree_hal_deferred_work_queue_completion_list_deinitialize( + iree_hal_deferred_work_queue_completion_list_t* list, + iree_hal_deferred_work_queue_device_interface_t* device_interface, iree_allocator_t host_allocator) { - iree_hal_hip_completion_list_node_t* head = list->head; + iree_hal_deferred_work_queue_completion_list_node_t* head = list->head; while (head) { if (head->created_event) { - IREE_HIP_IGNORE_ERROR(symbols, hipEventDestroy(head->event)); + device_interface->vtable->destroy_native_event(device_interface, + head->native_event); } list->head = list->head->next; iree_allocator_free(host_allocator, head); @@ -327,11 +323,12 @@ static void iree_hal_hip_completion_list_deinitialize( iree_slim_mutex_deinitialize(&list->guard_mutex); } -static iree_hal_hip_queue_action_t* iree_hal_hip_entry_list_node_pop_front( - iree_hal_hip_entry_list_node_t* list) { +static iree_hal_deferred_work_queue_action_t* +iree_hal_deferred_work_queue_entry_list_node_pop_front( + iree_hal_deferred_work_queue_entry_list_node_t* list) { IREE_ASSERT(list->ready_list_head); - iree_hal_hip_queue_action_t* action = list->ready_list_head; + iree_hal_deferred_work_queue_action_t* action = 
      list->ready_list_head;
   IREE_ASSERT(!action->prev);
   list->ready_list_head = action->next;
   if (action->next) {
@@ -342,12 +339,12 @@ static iree_hal_hip_queue_action_t* iree_hal_hip_entry_list_node_pop_front(
   return action;
 }

-static void iree_hal_hip_entry_list_node_push_front(
-    iree_hal_hip_entry_list_node_t* entry,
-    iree_hal_hip_queue_action_t* action) {
+static void iree_hal_deferred_work_queue_entry_list_node_push_front(
+    iree_hal_deferred_work_queue_entry_list_node_t* entry,
+    iree_hal_deferred_work_queue_action_t* action) {
   IREE_ASSERT(!action->next && !action->prev);

-  iree_hal_hip_queue_action_t* head = entry->ready_list_head;
+  iree_hal_deferred_work_queue_action_t* head = entry->ready_list_head;
   entry->ready_list_head = action;
   if (head) {
     action->next = head;
@@ -360,16 +357,16 @@
 // States in the list have increasing priorities--meaning ones appearing
 // earlier can normally overwrite ones appearing later without checking; but
 // not the reverse order.
-typedef enum iree_hal_hip_worker_state_e {
-  IREE_HAL_HIP_WORKER_STATE_IDLE_WAITING = 0,      // Worker to any thread
-  IREE_HAL_HIP_WORKER_STATE_WORKLOAD_PENDING = 1,  // Any to worker thread
-} iree_hal_hip_worker_state_t;
+typedef enum iree_hal_deferred_work_queue_worker_state_e {
+  IREE_HAL_WORKER_STATE_IDLE_WAITING = 0,      // Worker to any thread
+  IREE_HAL_WORKER_STATE_WORKLOAD_PENDING = 1,  // Any to worker thread
+} iree_hal_deferred_work_queue_worker_state_t;

 // The data structure needed by a ready-list processing worker thread to issue
 // ready actions to the GPU.
 //
 // This data structure is shared between the parent thread, which owns the
-// whole pending actions queue, and the worker thread; so proper synchronization
+// whole deferred work queue, and the worker thread; so proper synchronization
 // is needed to touch it from both sides.
 //
 // The parent thread should push a list of ready actions to ready_worklist,
@@ -377,78 +374,80 @@
 // The worker thread waits on the state_notification and checks worker_state,
 // and pops from the ready_worklist to process. The worker thread also monitors
 // worker_state and stops processing if requested by the parent thread.
-typedef struct iree_hal_hip_working_area_t {
+typedef struct iree_hal_deferred_work_queue_working_area_t {
   // Notification from the parent thread to request worker state changes.
   iree_notification_t state_notification;
-  iree_hal_hip_entry_list_t ready_worklist;  // atomic
-  iree_atomic_int32_t worker_state;          // atomic
-} iree_hal_hip_working_area_t;
+  iree_hal_deferred_work_queue_entry_list_t ready_worklist;  // atomic
+  iree_atomic_int32_t worker_state;                          // atomic
+} iree_hal_deferred_work_queue_working_area_t;

 // This data structure is shared between the parent thread and the completion
 // thread. It is responsible for dispatching callbacks when work items
 // complete.
-// This replaces the use of hipLaunchHostFunc, which causes the stream to block
-// and wait for the CPU work to complete. It also picks up completed
-// events with significantly less latency than hipLaunchHostFunc.
-typedef struct iree_hal_hip_completion_area_t {
+// This replaces the use of Launch Host Function APIs, which cause
+// streams to block and wait for the CPU work to complete.
+// It also picks up completed events with significantly less latency than
+// Launch Host Function APIs.
+typedef struct iree_hal_deferred_work_queue_completion_area_t {
   // Notification from the parent thread to request completion state changes.
   iree_notification_t state_notification;
-  iree_hal_hip_completion_list_t completion_list;  // atomic
-  iree_atomic_int32_t worker_state;                // atomic
-} iree_hal_hip_completion_area_t;
-
-static void iree_hal_hip_working_area_initialize(
-    iree_allocator_t host_allocator, hipDevice_t device,
-    const iree_hal_hip_dynamic_symbols_t* symbols,
-    iree_hal_hip_working_area_t* working_area) {
+  iree_hal_deferred_work_queue_completion_list_t completion_list;  // atomic
+  iree_atomic_int32_t worker_state;                                // atomic
+} iree_hal_deferred_work_queue_completion_area_t;
+
+static void iree_hal_deferred_work_queue_working_area_initialize(
+    iree_allocator_t host_allocator,
+    const iree_hal_deferred_work_queue_device_interface_t* device_interface,
+    iree_hal_deferred_work_queue_working_area_t* working_area) {
   iree_notification_initialize(&working_area->state_notification);
-  iree_hal_hip_ready_action_list_initialize(&working_area->ready_worklist);
+  iree_hal_deferred_work_queue_ready_action_list_initialize(
+      &working_area->ready_worklist);
   iree_atomic_store_int32(&working_area->worker_state,
-                          IREE_HAL_HIP_WORKER_STATE_IDLE_WAITING,
+                          IREE_HAL_WORKER_STATE_IDLE_WAITING,
                           iree_memory_order_release);
 }

-static void iree_hal_hip_working_area_deinitialize(
-    iree_hal_hip_working_area_t* working_area,
+static void iree_hal_deferred_work_queue_working_area_deinitialize(
+    iree_hal_deferred_work_queue_working_area_t* working_area,
     iree_allocator_t host_allocator) {
-  iree_hal_hip_ready_action_list_deinitialize(&working_area->ready_worklist,
-                                              host_allocator);
+  iree_hal_deferred_work_queue_ready_action_list_deinitialize(
+      &working_area->ready_worklist, host_allocator);
   iree_notification_deinitialize(&working_area->state_notification);
 }

-static void iree_hal_hip_completion_area_initialize(
-    iree_allocator_t host_allocator, hipDevice_t device,
-    const iree_hal_hip_dynamic_symbols_t* symbols,
-    iree_hal_hip_completion_area_t* completion_area) {
+static void iree_hal_deferred_work_queue_completion_area_initialize(
+    iree_allocator_t host_allocator,
+    iree_hal_deferred_work_queue_device_interface_t* device_interface,
+    iree_hal_deferred_work_queue_completion_area_t* completion_area) {
   iree_notification_initialize(&completion_area->state_notification);
-  iree_hal_hip_completion_list_initialize(&completion_area->completion_list);
+  iree_hal_deferred_work_queue_completion_list_initialize(
+      &completion_area->completion_list);
   iree_atomic_store_int32(&completion_area->worker_state,
-                          IREE_HAL_HIP_WORKER_STATE_IDLE_WAITING,
+                          IREE_HAL_WORKER_STATE_IDLE_WAITING,
                           iree_memory_order_release);
 }

-static void iree_hal_hip_completion_area_deinitialize(
-    iree_hal_hip_completion_area_t* completion_area,
-    const iree_hal_hip_dynamic_symbols_t* symbols,
+static void iree_hal_deferred_work_queue_completion_area_deinitialize(
+    iree_hal_deferred_work_queue_completion_area_t* completion_area,
+    iree_hal_deferred_work_queue_device_interface_t* device_interface,
    iree_allocator_t host_allocator) {
-  iree_hal_hip_completion_list_deinitialize(&completion_area->completion_list,
-                                            symbols, host_allocator);
+  iree_hal_deferred_work_queue_completion_list_deinitialize(
+      &completion_area->completion_list, device_interface, host_allocator);
   iree_notification_deinitialize(&completion_area->state_notification);
 }

 // The main function for the ready-list processing worker thread.
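The initialize/deinitialize pairs above wire up the notification and atomic state that the two threads handshake on. The following is a minimal sketch of that handshake, assuming the IREE synchronization headers are available; has_pending_workload, notify_worker, and wait_for_work are hypothetical names condensing functions that appear later in this file.

static bool has_pending_workload(
    iree_hal_deferred_work_queue_working_area_t* area) {
  return iree_atomic_load_int32(&area->worker_state,
                                iree_memory_order_acquire) ==
         IREE_HAL_WORKER_STATE_WORKLOAD_PENDING;
}

// Parent side: publish the state change, then wake any waiter.
static void notify_worker(iree_hal_deferred_work_queue_working_area_t* area) {
  iree_atomic_store_int32(&area->worker_state,
                          IREE_HAL_WORKER_STATE_WORKLOAD_PENDING,
                          iree_memory_order_release);
  iree_notification_post(&area->state_notification, IREE_ALL_WAITERS);
}

// Worker side: sleep until work is published, then flip back to idle
// *before* draining the worklist so a concurrent producer is never missed.
static void wait_for_work(iree_hal_deferred_work_queue_working_area_t* area) {
  iree_notification_await(&area->state_notification,
                          (iree_condition_fn_t)has_pending_workload, area,
                          iree_infinite_timeout());
  iree_atomic_store_int32(&area->worker_state,
                          IREE_HAL_WORKER_STATE_IDLE_WAITING,
                          iree_memory_order_release);
}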
-static int iree_hal_hip_worker_execute( - iree_hal_hip_pending_queue_actions_t* actions); +static int iree_hal_deferred_work_queue_worker_execute( + iree_hal_deferred_work_queue_t* actions); -static int iree_hal_hip_completion_execute( - iree_hal_hip_pending_queue_actions_t* actions); +static int iree_hal_deferred_work_queue_completion_execute( + iree_hal_deferred_work_queue_t* actions); //===----------------------------------------------------------------------===// -// Pending queue actions +// Deferred work queue //===----------------------------------------------------------------------===// -struct iree_hal_hip_pending_queue_actions_t { +struct iree_hal_deferred_work_queue_t { // Abstract resource used for injecting reference counting and vtable; // must be at offset 0. iree_hal_resource_t resource; @@ -458,14 +457,15 @@ struct iree_hal_hip_pending_queue_actions_t { // The block pool to allocate resource sets from. iree_arena_block_pool_t* block_pool; - // The symbols used to create and destroy hipEvent_t objects. - const iree_hal_hip_dynamic_symbols_t* symbols; + // The device interface used to interact with the native driver. + iree_hal_deferred_work_queue_device_interface_t* device_interface; // Non-recursive mutex guarding access. iree_slim_mutex_t action_mutex; - // The double-linked list of pending actions. - iree_hal_hip_queue_action_list_t action_list IREE_GUARDED_BY(action_mutex); + // The double-linked list of deferred work. + iree_hal_deferred_work_queue_action_list_t action_list + IREE_GUARDED_BY(action_mutex); // The worker thread that monitors incoming requests and issues ready actions // to the GPU. @@ -476,19 +476,16 @@ struct iree_hal_hip_pending_queue_actions_t { iree_thread_t* completion_thread; // The worker's working area; data exchange place with the parent thread. - iree_hal_hip_working_area_t working_area; + iree_hal_deferred_work_queue_working_area_t working_area; // Completion thread's working area. - iree_hal_hip_completion_area_t completion_area; + iree_hal_deferred_work_queue_completion_area_t completion_area; // Atomic of type iree_status_t. It is a sticky error. // Once set with an error, all subsequent actions that have not completed // will fail with this error. iree_status_t status IREE_GUARDED_BY(action_mutex); - // The associated hip device. - hipDevice_t device; - // The number of asynchronous work items that are scheduled and not // complete. 
// These are @@ -504,35 +501,36 @@ struct iree_hal_hip_pending_queue_actions_t { bool exit_requested IREE_GUARDED_BY(action_mutex); }; -iree_status_t iree_hal_hip_pending_queue_actions_create( - const iree_hal_hip_dynamic_symbols_t* symbols, hipDevice_t device, +iree_status_t iree_hal_deferred_work_queue_create( + iree_hal_deferred_work_queue_device_interface_t* device_interface, iree_arena_block_pool_t* block_pool, iree_allocator_t host_allocator, - iree_hal_hip_pending_queue_actions_t** out_actions) { - IREE_ASSERT_ARGUMENT(symbols); + iree_hal_deferred_work_queue_t** out_actions) { + IREE_ASSERT_ARGUMENT(device_interface); IREE_ASSERT_ARGUMENT(block_pool); IREE_ASSERT_ARGUMENT(out_actions); IREE_TRACE_ZONE_BEGIN(z0); - iree_hal_hip_pending_queue_actions_t* actions = NULL; + iree_hal_deferred_work_queue_t* actions = NULL; IREE_RETURN_AND_END_ZONE_IF_ERROR( z0, iree_allocator_malloc(host_allocator, sizeof(*actions), (void**)&actions)); actions->host_allocator = host_allocator; actions->block_pool = block_pool; - actions->symbols = symbols; - actions->device = device; + actions->device_interface = device_interface; iree_slim_mutex_initialize(&actions->action_mutex); memset(&actions->action_list, 0, sizeof(actions->action_list)); // Initialize the working area for the ready-list processing worker. - iree_hal_hip_working_area_t* working_area = &actions->working_area; - iree_hal_hip_working_area_initialize(host_allocator, device, symbols, - working_area); + iree_hal_deferred_work_queue_working_area_t* working_area = + &actions->working_area; + iree_hal_deferred_work_queue_working_area_initialize( + host_allocator, device_interface, working_area); - iree_hal_hip_completion_area_t* completion_area = &actions->completion_area; - iree_hal_hip_completion_area_initialize(host_allocator, device, symbols, - completion_area); + iree_hal_deferred_work_queue_completion_area_t* completion_area = + &actions->completion_area; + iree_hal_deferred_work_queue_completion_area_initialize( + host_allocator, device_interface, completion_area); // Create the ready-list processing worker itself. 
   iree_thread_create_params_t params;
@@ -540,44 +538,44 @@ iree_status_t iree_hal_hip_pending_queue_actions_create(
   params.name = IREE_SV("iree-hip-queue-worker");
   params.create_suspended = false;
   iree_status_t status = iree_thread_create(
-      (iree_thread_entry_t)iree_hal_hip_worker_execute, actions, params,
-      actions->host_allocator, &actions->worker_thread);
+      (iree_thread_entry_t)iree_hal_deferred_work_queue_worker_execute, actions,
+      params, actions->host_allocator, &actions->worker_thread);

   params.name = IREE_SV("iree-hip-queue-completion");
   params.create_suspended = false;
   if (iree_status_is_ok(status)) {
     status = iree_thread_create(
-        (iree_thread_entry_t)iree_hal_hip_completion_execute, actions, params,
-        actions->host_allocator, &actions->completion_thread);
+        (iree_thread_entry_t)iree_hal_deferred_work_queue_completion_execute,
+        actions, params, actions->host_allocator, &actions->completion_thread);
   }

   if (iree_status_is_ok(status)) {
     *out_actions = actions;
   } else {
-    iree_hal_hip_pending_queue_actions_destroy((iree_hal_resource_t*)actions);
+    iree_hal_deferred_work_queue_destroy(actions);
   }

   IREE_TRACE_ZONE_END(z0);
-  return iree_ok_status();
+  return status;
 }

-static iree_hal_hip_pending_queue_actions_t*
-iree_hal_hip_pending_queue_actions_cast(iree_hal_resource_t* base_value) {
-  return (iree_hal_hip_pending_queue_actions_t*)base_value;
+static iree_hal_deferred_work_queue_t* iree_hal_deferred_work_queue_cast(
+    iree_hal_resource_t* base_value) {
+  return (iree_hal_deferred_work_queue_t*)base_value;
 }

-static void iree_hal_hip_pending_queue_actions_notify_worker_thread(
-    iree_hal_hip_working_area_t* working_area) {
+static void iree_hal_deferred_work_queue_notify_worker_thread(
+    iree_hal_deferred_work_queue_working_area_t* working_area) {
   iree_atomic_store_int32(&working_area->worker_state,
-                          IREE_HAL_HIP_WORKER_STATE_WORKLOAD_PENDING,
+                          IREE_HAL_WORKER_STATE_WORKLOAD_PENDING,
                           iree_memory_order_release);
   iree_notification_post(&working_area->state_notification, IREE_ALL_WAITERS);
 }

-static void iree_hal_hip_pending_queue_actions_notify_completion_thread(
-    iree_hal_hip_completion_area_t* completion_area) {
+static void iree_hal_deferred_work_queue_notify_completion_thread(
+    iree_hal_deferred_work_queue_completion_area_t* completion_area) {
   iree_atomic_store_int32(&completion_area->worker_state,
-                          IREE_HAL_HIP_WORKER_STATE_WORKLOAD_PENDING,
+                          IREE_HAL_WORKER_STATE_WORKLOAD_PENDING,
                           iree_memory_order_release);
   iree_notification_post(&completion_area->state_notification,
                          IREE_ALL_WAITERS);
@@ -585,55 +583,56 @@ static void iree_hal_hip_pending_queue_actions_notify_completion_thread(

 // Notifies worker and completion threads that there is work available to
 // process.
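For context, a hypothetical backend would create and tear down the queue roughly as follows. This is a sketch only; my_make_device_interface is an assumed helper standing in for a driver-specific constructor and is not part of this patch.

static iree_status_t my_device_start_work_queue(
    iree_allocator_t host_allocator, iree_arena_block_pool_t* block_pool,
    iree_hal_deferred_work_queue_t** out_queue) {
  iree_hal_deferred_work_queue_device_interface_t* device_interface = NULL;
  IREE_RETURN_IF_ERROR(
      my_make_device_interface(host_allocator, &device_interface));
  // On success the queue owns |device_interface| and both helper threads;
  // iree_hal_deferred_work_queue_destroy() joins the threads and invokes the
  // interface's destroy() vtable entry.
  return iree_hal_deferred_work_queue_create(device_interface, block_pool,
                                             host_allocator, out_queue);
}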
-static void iree_hal_hip_pending_queue_actions_notify_threads( - iree_hal_hip_pending_queue_actions_t* actions) { - iree_hal_hip_pending_queue_actions_notify_worker_thread( - &actions->working_area); - iree_hal_hip_pending_queue_actions_notify_completion_thread( +static void iree_hal_deferred_work_queue_notify_threads( + iree_hal_deferred_work_queue_t* actions) { + iree_hal_deferred_work_queue_notify_worker_thread(&actions->working_area); + iree_hal_deferred_work_queue_notify_completion_thread( &actions->completion_area); } -static void iree_hal_hip_pending_queue_actions_request_exit( - iree_hal_hip_pending_queue_actions_t* actions) { +static void iree_hal_deferred_work_queue_request_exit( + iree_hal_deferred_work_queue_t* actions) { iree_slim_mutex_lock(&actions->action_mutex); actions->exit_requested = true; iree_slim_mutex_unlock(&actions->action_mutex); - iree_hal_hip_pending_queue_actions_notify_threads(actions); + iree_hal_deferred_work_queue_notify_threads(actions); } -void iree_hal_hip_pending_queue_actions_destroy( - iree_hal_resource_t* base_actions) { +void iree_hal_deferred_work_queue_destroy( + iree_hal_deferred_work_queue_t* work_queue) { IREE_TRACE_ZONE_BEGIN(z0); - iree_hal_hip_pending_queue_actions_t* actions = - iree_hal_hip_pending_queue_actions_cast(base_actions); - iree_allocator_t host_allocator = actions->host_allocator; + iree_allocator_t host_allocator = work_queue->host_allocator; // Request the workers to exit. - iree_hal_hip_pending_queue_actions_request_exit(actions); + iree_hal_deferred_work_queue_request_exit(work_queue); + + iree_thread_join(work_queue->worker_thread); + iree_thread_release(work_queue->worker_thread); - iree_thread_join(actions->worker_thread); - iree_thread_release(actions->worker_thread); + iree_thread_join(work_queue->completion_thread); + iree_thread_release(work_queue->completion_thread); - iree_thread_join(actions->completion_thread); - iree_thread_release(actions->completion_thread); + iree_hal_deferred_work_queue_working_area_deinitialize( + &work_queue->working_area, work_queue->host_allocator); + iree_hal_deferred_work_queue_completion_area_deinitialize( + &work_queue->completion_area, work_queue->device_interface, + work_queue->host_allocator); - iree_hal_hip_working_area_deinitialize(&actions->working_area, - actions->host_allocator); - iree_hal_hip_completion_area_deinitialize( - &actions->completion_area, actions->symbols, actions->host_allocator); + iree_slim_mutex_deinitialize(&work_queue->action_mutex); + iree_hal_deferred_work_queue_action_list_destroy( + work_queue->action_list.head); - iree_slim_mutex_deinitialize(&actions->action_mutex); - iree_hal_hip_queue_action_list_destroy(actions->action_list.head); - iree_allocator_free(host_allocator, actions); + work_queue->device_interface->vtable->destroy(work_queue->device_interface); + iree_allocator_free(host_allocator, work_queue); IREE_TRACE_ZONE_END(z0); } -static void iree_hal_hip_queue_action_destroy( - iree_hal_hip_queue_action_t* action) { +static void iree_hal_deferred_work_queue_action_destroy( + iree_hal_deferred_work_queue_action_t* action) { IREE_TRACE_ZONE_BEGIN(z0); - iree_hal_hip_pending_queue_actions_t* actions = action->owning_actions; + iree_hal_deferred_work_queue_t* actions = action->owning_actions; iree_allocator_t host_allocator = actions->host_allocator; // Call user provided callback before releasing any resource. @@ -644,7 +643,7 @@ static void iree_hal_hip_queue_action_destroy( // Only release resources after callbacks have been issued. 
iree_hal_resource_set_free(action->resource_set); - iree_hal_hip_queue_action_clear_events(action); + iree_hal_deferred_work_queue_action_clear_events(action); iree_hal_resource_release(actions); @@ -653,17 +652,16 @@ static void iree_hal_hip_queue_action_destroy( IREE_TRACE_ZONE_END(z0); } -static void iree_hal_hip_queue_decrement_work_items_count( - iree_hal_hip_pending_queue_actions_t* actions) { +static void iree_hal_deferred_work_queue_decrement_work_items_count( + iree_hal_deferred_work_queue_t* actions) { iree_slim_mutex_lock(&actions->action_mutex); --actions->pending_work_items_count; iree_slim_mutex_unlock(&actions->action_mutex); } -iree_status_t iree_hal_hip_pending_queue_actions_enqueue_execution( - iree_hal_device_t* device, hipStream_t dispatch_stream, - iree_hal_hip_pending_queue_actions_t* actions, - iree_hal_hip_pending_action_cleanup_callback_t cleanup_callback, +iree_status_t iree_hal_deferred_work_queue_enque( + iree_hal_deferred_work_queue_t* actions, + iree_hal_deferred_work_queue_cleanup_callback_t cleanup_callback, void* callback_user_data, const iree_hal_semaphore_list_t wait_semaphore_list, const iree_hal_semaphore_list_t signal_semaphore_list, @@ -675,7 +673,7 @@ iree_status_t iree_hal_hip_pending_queue_actions_enqueue_execution( IREE_TRACE_ZONE_BEGIN(z0); // Embed captured tables in the action allocation. - iree_hal_hip_queue_action_t* action = NULL; + iree_hal_deferred_work_queue_action_t* action = NULL; const iree_host_size_t wait_semaphore_list_size = wait_semaphore_list.count * sizeof(*wait_semaphore_list.semaphores) + wait_semaphore_list.count * sizeof(*wait_semaphore_list.payload_values); @@ -705,12 +703,11 @@ iree_status_t iree_hal_hip_pending_queue_actions_enqueue_execution( uint8_t* action_ptr = (uint8_t*)action + sizeof(*action); action->owning_actions = actions; - action->state = IREE_HAL_HIP_QUEUE_ACTION_STATE_ALIVE; + action->device_interface = actions->device_interface; + action->state = IREE_HAL_QUEUE_ACTION_STATE_ALIVE; action->cleanup_callback = cleanup_callback; action->callback_user_data = callback_user_data; - action->kind = IREE_HAL_HIP_QUEUE_ACTION_TYPE_EXECUTION; - action->device = device; - action->dispatch_hip_stream = dispatch_stream; + action->kind = IREE_HAL_QUEUE_ACTION_TYPE_EXECUTION; // Initialize scratch fields. action->event_count = 0; @@ -809,9 +806,10 @@ iree_status_t iree_hal_hip_pending_queue_actions_enqueue_execution( status = iree_make_status( IREE_STATUS_ABORTED, "can not issue more executions, exit already requested"); - iree_hal_hip_queue_action_fail_locked(action, status); + iree_hal_deferred_work_queue_action_fail_locked(action, status); } else { - iree_hal_hip_queue_action_list_push_back(&actions->action_list, action); + iree_hal_deferred_work_queue_action_list_push_back(&actions->action_list, + action); // One work item is the callback that makes it across from the // completion thread. // The other is the cleanup of the action. @@ -829,8 +827,8 @@ iree_status_t iree_hal_hip_pending_queue_actions_enqueue_execution( } // Does not consume |status|. 
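The expected caller pattern pairs the enqueue with an issue. Below is a sketch under the assumption that the trailing parameters elided by the hunk above carry the command buffers and binding tables; my_queue_execute is a hypothetical caller, not part of this patch.

static iree_status_t my_queue_execute(
    iree_hal_deferred_work_queue_t* work_queue,
    iree_hal_semaphore_list_t wait_semaphore_list,
    iree_hal_semaphore_list_t signal_semaphore_list,
    iree_host_size_t command_buffer_count,
    iree_hal_command_buffer_t* const* command_buffers) {
  IREE_RETURN_IF_ERROR(iree_hal_deferred_work_queue_enque(
      work_queue, /*cleanup_callback=*/NULL, /*callback_user_data=*/NULL,
      wait_semaphore_list, signal_semaphore_list, command_buffer_count,
      command_buffers, /*binding_tables=*/NULL));
  // Kick the worker: actions whose wait semaphores are satisfied (or are
  // covered by recorded device events) get handed off for GPU submission.
  return iree_hal_deferred_work_queue_issue(work_queue);
}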
-static void iree_hal_hip_pending_queue_actions_fail_status_locked(
-    iree_hal_hip_pending_queue_actions_t* actions, iree_status_t status) {
+static void iree_hal_deferred_work_queue_fail_status_locked(
+    iree_hal_deferred_work_queue_t* actions, iree_status_t status) {
   if (iree_status_is_ok(actions->status) && status != actions->status) {
     actions->status = iree_status_clone(status);
   }
@@ -840,10 +838,10 @@
 // Does not consume |status|.
 // Decrements pending work items count accordingly based on the unfulfilled
 // number of work items.
-static void iree_hal_hip_queue_action_fail_locked(
-    iree_hal_hip_queue_action_t* action, iree_status_t status) {
+static void iree_hal_deferred_work_queue_action_fail_locked(
+    iree_hal_deferred_work_queue_action_t* action, iree_status_t status) {
   IREE_ASSERT(!iree_status_is_ok(status));
-  iree_hal_hip_pending_queue_actions_t* actions = action->owning_actions;
+  iree_hal_deferred_work_queue_t* actions = action->owning_actions;

   // Unlock since failing the semaphore will use |actions|.
   iree_slim_mutex_unlock(&actions->action_mutex);
@@ -852,10 +850,10 @@
   iree_host_size_t work_items_remaining = 0;
   switch (action->state) {
-    case IREE_HAL_HIP_QUEUE_ACTION_STATE_ALIVE:
+    case IREE_HAL_QUEUE_ACTION_STATE_ALIVE:
       work_items_remaining = total_work_items_to_complete_an_action;
       break;
-    case IREE_HAL_HIP_QUEUE_ACTION_STATE_ZOMBIE:
+    case IREE_HAL_QUEUE_ACTION_STATE_ZOMBIE:
       work_items_remaining = 1;
       break;
     default:
@@ -864,99 +862,103 @@
   }
   iree_slim_mutex_lock(&actions->action_mutex);
   action->owning_actions->pending_work_items_count -= work_items_remaining;
-  iree_hal_hip_pending_queue_actions_fail_status_locked(actions, status);
-  iree_hal_hip_queue_action_destroy(action);
+  iree_hal_deferred_work_queue_fail_status_locked(actions, status);
+  iree_hal_deferred_work_queue_action_destroy(action);
 }

 // Fails and destroys the given action.
 // Does not consume |status|.
-static void iree_hal_hip_queue_action_fail(iree_hal_hip_queue_action_t* action,
-                                           iree_status_t status) {
-  iree_hal_hip_pending_queue_actions_t* actions = action->owning_actions;
+static void iree_hal_deferred_work_queue_action_fail(
+    iree_hal_deferred_work_queue_action_t* action, iree_status_t status) {
+  iree_hal_deferred_work_queue_t* actions = action->owning_actions;
   iree_slim_mutex_lock(&actions->action_mutex);
-  iree_hal_hip_queue_action_fail_locked(action, status);
+  iree_hal_deferred_work_queue_action_fail_locked(action, status);
   iree_slim_mutex_unlock(&actions->action_mutex);
 }

 // Fails and destroys all actions.
 // Does not consume |status|.
-static void iree_hal_hip_queue_action_raw_list_fail_locked(
-    iree_hal_hip_queue_action_t* head_action, iree_status_t status) {
+static void iree_hal_deferred_work_queue_action_raw_list_fail_locked(
+    iree_hal_deferred_work_queue_action_t* head_action, iree_status_t status) {
   while (head_action) {
-    iree_hal_hip_queue_action_t* next_action = head_action->next;
-    iree_hal_hip_queue_action_fail_locked(head_action, status);
+    iree_hal_deferred_work_queue_action_t* next_action = head_action->next;
+    iree_hal_deferred_work_queue_action_fail_locked(head_action, status);
     head_action = next_action;
   }
 }

 // Fails and destroys all actions.
 // Does not consume |status|.
-static void iree_hal_hip_ready_action_list_fail_locked( - iree_hal_hip_entry_list_t* list, iree_status_t status) { - iree_hal_hip_entry_list_node_t* entry = iree_hal_hip_entry_list_pop(list); +static void iree_hal_deferred_work_queue_ready_action_list_fail_locked( + iree_hal_deferred_work_queue_entry_list_t* list, iree_status_t status) { + iree_hal_deferred_work_queue_entry_list_node_t* entry = + iree_hal_deferred_work_queue_entry_list_pop(list); while (entry) { - iree_hal_hip_queue_action_raw_list_fail_locked(entry->ready_list_head, - status); - entry = iree_hal_hip_entry_list_pop(list); + iree_hal_deferred_work_queue_action_raw_list_fail_locked( + entry->ready_list_head, status); + entry = iree_hal_deferred_work_queue_entry_list_pop(list); } } // Fails and destroys all actions. // Does not consume |status|. -static void iree_hal_hip_queue_action_list_fail_locked( - iree_hal_hip_queue_action_list_t* list, iree_status_t status) { - iree_hal_hip_queue_action_t* action; - if (iree_hal_hip_queue_action_list_is_empty(list)) { +static void iree_hal_deferred_work_queue_action_list_fail_locked( + iree_hal_deferred_work_queue_action_list_t* list, iree_status_t status) { + iree_hal_deferred_work_queue_action_t* action; + if (iree_hal_deferred_work_queue_action_list_is_empty(list)) { return; } do { - action = iree_hal_hip_queue_action_list_pop_front(list); - iree_hal_hip_queue_action_fail_locked(action, status); + action = iree_hal_deferred_work_queue_action_list_pop_front(list); + iree_hal_deferred_work_queue_action_fail_locked(action, status); } while (action); } // Fails and destroys all actions and sets status of |actions|. // Does not consume |status|. // Assumes the caller is holding the action_mutex. -static void iree_hal_hip_pending_queue_actions_fail_locked( - iree_hal_hip_pending_queue_actions_t* actions, iree_status_t status) { - iree_hal_hip_pending_queue_actions_fail_status_locked(actions, status); - iree_hal_hip_queue_action_list_fail_locked(&actions->action_list, status); - iree_hal_hip_ready_action_list_fail_locked( +static void iree_hal_deferred_work_queue_fail_locked( + iree_hal_deferred_work_queue_t* actions, iree_status_t status) { + iree_hal_deferred_work_queue_fail_status_locked(actions, status); + iree_hal_deferred_work_queue_action_list_fail_locked(&actions->action_list, + status); + iree_hal_deferred_work_queue_ready_action_list_fail_locked( &actions->working_area.ready_worklist, status); } // Does not consume |status|. -static void iree_hal_hip_pending_queue_actions_fail( - iree_hal_hip_pending_queue_actions_t* actions, iree_status_t status) { +static void iree_hal_deferred_work_queue_fail( + iree_hal_deferred_work_queue_t* actions, iree_status_t status) { iree_slim_mutex_lock(&actions->action_mutex); - iree_hal_hip_pending_queue_actions_fail_locked(actions, status); + iree_hal_deferred_work_queue_fail_locked(actions, status); iree_slim_mutex_unlock(&actions->action_mutex); } // Releases resources after action completion on the GPU and advances timeline -// and pending actions queue. -static iree_status_t iree_hal_hip_execution_device_signal_host_callback( +// and deferred work queue. 
+static iree_status_t
+iree_hal_deferred_work_queue_execution_device_signal_host_callback(
     iree_status_t status, void* user_data) {
   IREE_TRACE_ZONE_BEGIN(z0);
-  iree_hal_hip_queue_action_t* action = (iree_hal_hip_queue_action_t*)user_data;
-  IREE_ASSERT_EQ(action->kind, IREE_HAL_HIP_QUEUE_ACTION_TYPE_EXECUTION);
-  IREE_ASSERT_EQ(action->state, IREE_HAL_HIP_QUEUE_ACTION_STATE_ALIVE);
-  iree_hal_hip_pending_queue_actions_t* actions = action->owning_actions;
+  iree_hal_deferred_work_queue_action_t* action =
+      (iree_hal_deferred_work_queue_action_t*)user_data;
+  IREE_ASSERT_EQ(action->kind, IREE_HAL_QUEUE_ACTION_TYPE_EXECUTION);
+  IREE_ASSERT_EQ(action->state, IREE_HAL_QUEUE_ACTION_STATE_ALIVE);
+  iree_hal_deferred_work_queue_t* actions = action->owning_actions;

   if (IREE_UNLIKELY(!iree_status_is_ok(status))) {
-    iree_hal_hip_queue_action_fail(action, status);
+    iree_hal_deferred_work_queue_action_fail(action, status);
     IREE_TRACE_ZONE_END(z0);
     return iree_ok_status();
   }

   // Need to signal the list before zombifying the action, because in the
   // meantime someone else may issue the pending queue actions.
-  // If we push first to the pending actions list, the cleanup of this action
+  // If we push first to the deferred work list, the cleanup of this action
   // may run while we are still using the semaphore list, causing a crash.
   status = iree_hal_semaphore_list_signal(action->signal_semaphore_list);
   if (IREE_UNLIKELY(!iree_status_is_ok(status))) {
-    iree_hal_hip_queue_action_fail(action, status);
    IREE_TRACE_ZONE_END(z0);
     return status;
   }
@@ -965,46 +967,46 @@
   // the worker thread clean it up. Note that this is necessary because cleanup
   // may involve GPU API calls like buffer releasing or unregistering, so we can
   // not inline it here.
-  action->state = IREE_HAL_HIP_QUEUE_ACTION_STATE_ZOMBIE;
+  action->state = IREE_HAL_QUEUE_ACTION_STATE_ZOMBIE;
   iree_slim_mutex_lock(&actions->action_mutex);
-  iree_hal_hip_queue_action_list_push_back(&actions->action_list, action);
+  iree_hal_deferred_work_queue_action_list_push_back(&actions->action_list,
+                                                     action);
   // The callback (work item) is complete.
   --actions->pending_work_items_count;
   iree_slim_mutex_unlock(&actions->action_mutex);

   // We need to trigger execution of this action again, so it gets cleaned up.
-  status = iree_hal_hip_pending_queue_actions_issue(actions);
+  status = iree_hal_deferred_work_queue_issue(actions);

   IREE_TRACE_ZONE_END(z0);
   return status;
 }

 // Issues the given kernel dispatch |action| to the GPU.
-static iree_status_t iree_hal_hip_pending_queue_actions_issue_execution(
-    iree_hal_hip_queue_action_t* action) {
-  IREE_ASSERT_EQ(action->kind, IREE_HAL_HIP_QUEUE_ACTION_TYPE_EXECUTION);
+static iree_status_t iree_hal_deferred_work_queue_issue_execution(
+    iree_hal_deferred_work_queue_action_t* action) {
+  IREE_ASSERT_EQ(action->kind, IREE_HAL_QUEUE_ACTION_TYPE_EXECUTION);
   IREE_ASSERT_EQ(action->is_pending, false);
-  iree_hal_hip_pending_queue_actions_t* actions = action->owning_actions;
-  const iree_hal_hip_dynamic_symbols_t* symbols = actions->symbols;
+  iree_hal_deferred_work_queue_t* actions = action->owning_actions;
+  iree_hal_deferred_work_queue_device_interface_t* device_interface =
+      actions->device_interface;
   IREE_TRACE_ZONE_BEGIN(z0);

   // No need to lock given that this action is already detached from the
   // pending actions list; so only this thread is seeing it now.
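The callback above is also why each action makes two trips through the queue. The ordering it enforces, condensed into an illustrative comment-form sketch (not the literal code):

// Ordering enforced by the device-signal host callback (condensed):
//   (1) iree_hal_semaphore_list_signal(action->signal_semaphore_list);
//   (2) action->state = IREE_HAL_QUEUE_ACTION_STATE_ZOMBIE;
//       push |action| back onto actions->action_list;
//   (3) iree_hal_deferred_work_queue_issue(actions);  // schedules cleanup
// Swapping (1) after (2) would let another thread issue the queue and
// destroy the zombie while (1) still dereferences the semaphore list.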
-  // First wait all the device hipEvent_t in the dispatch stream.
+  // First wait all the device events in the dispatch stream.
   for (iree_host_size_t i = 0; i < action->event_count; ++i) {
-    IREE_HIP_RETURN_AND_END_ZONE_IF_ERROR(
-        z0, symbols,
-        hipStreamWaitEvent(action->dispatch_hip_stream,
-                           iree_hal_hip_event_handle(action->events[i]),
-                           /*flags=*/0),
-        "hipStreamWaitEvent");
+    IREE_RETURN_AND_END_ZONE_IF_ERROR(
+        z0, device_interface->vtable->device_wait_on_host_event(
+                device_interface, action->wait_events[i]));
   }

   // Then launch all command buffers to the dispatch stream.
   IREE_TRACE_ZONE_BEGIN(z_dispatch_command_buffers);
   IREE_TRACE_ZONE_APPEND_TEXT(z_dispatch_command_buffers,
                               "dispatch_command_buffers");
+
   for (iree_host_size_t i = 0; i < action->payload.execution.count; ++i) {
     iree_hal_command_buffer_t* command_buffer =
         action->payload.execution.command_buffers[i];
@@ -1012,22 +1014,7 @@
         action->payload.execution.binding_tables
             ? action->payload.execution.binding_tables[i]
             : iree_hal_buffer_binding_table_empty();
-    if (iree_hal_hip_stream_command_buffer_isa(command_buffer)) {
-      // Nothing much to do for an inline command buffer; all the work has
-      // already been submitted. When we support semaphores we'll still need to
-      // signal their completion but do not have to worry about any waits: if
-      // there were waits we wouldn't have been able to execute inline! We do
-      // notify that the commands were "submitted" so we can make sure to clean
-      // up our trace events.
-      iree_hal_hip_stream_notify_submitted_commands(command_buffer);
-    } else if (iree_hal_hip_graph_command_buffer_isa(command_buffer)) {
-      hipGraphExec_t exec =
-          iree_hal_hip_graph_command_buffer_handle(command_buffer);
-      IREE_HIP_RETURN_AND_END_ZONE_IF_ERROR(
-          z0, symbols, hipGraphLaunch(exec, action->dispatch_hip_stream),
-          "hipGraphLaunch");
-      iree_hal_hip_graph_tracing_notify_submitted_commands(command_buffer);
-    } else {
+    if (iree_hal_deferred_command_buffer_isa(command_buffer)) {
       iree_hal_command_buffer_t* stream_command_buffer = NULL;
       iree_hal_command_buffer_mode_t mode =
           iree_hal_command_buffer_mode(command_buffer) |
@@ -1039,38 +1026,50 @@
               ? IREE_HAL_COMMAND_BUFFER_MODE_UNVALIDATED
               : 0);
       IREE_RETURN_AND_END_ZONE_IF_ERROR(
-          z0, iree_hal_hip_device_create_stream_command_buffer(
-                  action->device, mode, IREE_HAL_COMMAND_CATEGORY_ANY,
-                  /*binding_capacity=*/0, &stream_command_buffer));
+          z0, device_interface->vtable->create_stream_command_buffer(
+                  device_interface, mode, IREE_HAL_COMMAND_CATEGORY_ANY,
+                  &stream_command_buffer));
       IREE_RETURN_AND_END_ZONE_IF_ERROR(
           z0, iree_hal_resource_set_insert(action->resource_set, 1,
                                            &stream_command_buffer));
+
       IREE_RETURN_AND_END_ZONE_IF_ERROR(
           z0, iree_hal_deferred_command_buffer_apply(
                  command_buffer, stream_command_buffer, binding_table));
-      iree_hal_hip_stream_notify_submitted_commands(stream_command_buffer);
-      // The stream_command_buffer is going to be retained by
-      // the action->resource_set and deleted after the action
-      // completes.
-      iree_hal_resource_release(stream_command_buffer);
+      command_buffer = stream_command_buffer;
+    } else {
+      iree_hal_resource_retain(command_buffer);
     }
+
+    IREE_RETURN_AND_END_ZONE_IF_ERROR(
+        z0, device_interface->vtable->submit_command_buffer(device_interface,
+                                                            command_buffer));
+
+    // The command buffer stays alive for the rest of the action: in the
+    // deferred path the stream command buffer was inserted into
+    // action->resource_set above, and other command buffers were retained
+    // just before submission; so drop our local reference here.
+    iree_hal_resource_release(command_buffer);
   }
+
   IREE_TRACE_ZONE_END(z_dispatch_command_buffers);

-  hipEvent_t completion_event = NULL;
-  // Last record hipEvent_t signals in the dispatch stream.
+  iree_hal_deferred_work_queue_native_event_t completion_event = NULL;
+  // Last, record event signals in the dispatch stream.
   for (iree_host_size_t i = 0; i < action->signal_semaphore_list.count; ++i) {
-    // Grab a hipEvent_t for this semaphore value signaling.
-    hipEvent_t event = NULL;
+    // Grab an event for this semaphore value signaling.
+    iree_hal_deferred_work_queue_native_event_t event = NULL;
+
     IREE_RETURN_AND_END_ZONE_IF_ERROR(
-        z0, iree_hal_hip_event_semaphore_acquire_timepoint_device_signal(
-                action->signal_semaphore_list.semaphores[i],
+        z0,
+        device_interface->vtable
+            ->semaphore_acquire_timepoint_device_signal_native_event(
+                device_interface, action->signal_semaphore_list.semaphores[i],
                 action->signal_semaphore_list.payload_values[i], &event));

     // Record the event signaling in the dispatch stream.
-    IREE_HIP_RETURN_AND_END_ZONE_IF_ERROR(
-        z0, symbols, hipEventRecord(event, action->dispatch_hip_stream),
-        "hipEventRecord");
+    IREE_RETURN_AND_END_ZONE_IF_ERROR(
+        z0,
+        device_interface->vtable->record_native_event(device_interface, event));
     completion_event = event;
   }
@@ -1079,19 +1078,17 @@
   // we can re-use those as a wait event. However if there are no signals
   // then we create one. In testing this is not a common case.
   if (IREE_UNLIKELY(!completion_event)) {
-    IREE_HIP_RETURN_AND_END_ZONE_IF_ERROR(
-        z0, symbols,
-        hipEventCreateWithFlags(&completion_event, hipEventDisableTiming),
-        "hipEventCreateWithFlags");
+    IREE_RETURN_AND_END_ZONE_IF_ERROR(
+        z0, device_interface->vtable->create_native_event(device_interface,
+                                                          &completion_event));
     created_event = true;
   }

-  IREE_HIP_RETURN_AND_END_ZONE_IF_ERROR(
-      z0, symbols,
-      hipEventRecord(completion_event, action->dispatch_hip_stream),
-      "hipEventRecord");
+  IREE_RETURN_AND_END_ZONE_IF_ERROR(
+      z0, device_interface->vtable->record_native_event(device_interface,
+                                                        completion_event));

-  iree_hal_hip_completion_list_node_t* entry = NULL;
+  iree_hal_deferred_work_queue_completion_list_node_t* entry = NULL;
   // TODO: avoid host allocator malloc; use some pool for the allocation.
   iree_status_t status = iree_allocator_malloc(actions->host_allocator,
                                                sizeof(*entry), (void**)&entry);
@@ -1103,14 +1100,15 @@
   // Now push the completion entry to the completion thread and have it wait
   // on the event before invoking the callback.
- entry->event = completion_event; + entry->native_event = completion_event; entry->created_event = created_event; - entry->callback = iree_hal_hip_execution_device_signal_host_callback; + entry->callback = + iree_hal_deferred_work_queue_execution_device_signal_host_callback; entry->user_data = action; - iree_hal_hip_completion_list_push(&actions->completion_area.completion_list, - entry); + iree_hal_deferred_work_queue_completion_list_push( + &actions->completion_area.completion_list, entry); - iree_hal_hip_pending_queue_actions_notify_completion_thread( + iree_hal_deferred_work_queue_notify_completion_thread( &actions->completion_area); IREE_TRACE_ZONE_END(z0); @@ -1118,38 +1116,39 @@ static iree_status_t iree_hal_hip_pending_queue_actions_issue_execution( } // Performs the given cleanup |action| on the CPU. -static void iree_hal_hip_pending_queue_actions_issue_cleanup( - iree_hal_hip_queue_action_t* action) { - iree_hal_hip_pending_queue_actions_t* actions = action->owning_actions; +static void iree_hal_deferred_work_queue_issue_cleanup( + iree_hal_deferred_work_queue_action_t* action) { + iree_hal_deferred_work_queue_t* actions = action->owning_actions; IREE_TRACE_ZONE_BEGIN(z0); - iree_hal_hip_queue_action_destroy(action); + iree_hal_deferred_work_queue_action_destroy(action); // Now we fully executed and cleaned up this action. Decrease the work items // counter. - iree_hal_hip_queue_decrement_work_items_count(actions); + iree_hal_deferred_work_queue_decrement_work_items_count(actions); IREE_TRACE_ZONE_END(z0); } -iree_status_t iree_hal_hip_pending_queue_actions_issue( - iree_hal_hip_pending_queue_actions_t* actions) { +iree_status_t iree_hal_deferred_work_queue_issue( + iree_hal_deferred_work_queue_t* actions) { IREE_TRACE_ZONE_BEGIN(z0); - iree_hal_hip_queue_action_list_t pending_list = {NULL, NULL}; - iree_hal_hip_queue_action_list_t ready_list = {NULL, NULL}; + iree_hal_deferred_work_queue_action_list_t pending_list = {NULL, NULL}; + iree_hal_deferred_work_queue_action_list_t ready_list = {NULL, NULL}; iree_slim_mutex_lock(&actions->action_mutex); - if (iree_hal_hip_queue_action_list_is_empty(&actions->action_list)) { + if (iree_hal_deferred_work_queue_action_list_is_empty( + &actions->action_list)) { iree_slim_mutex_unlock(&actions->action_mutex); IREE_TRACE_ZONE_END(z0); return iree_ok_status(); } if (IREE_UNLIKELY(!iree_status_is_ok(actions->status))) { - iree_hal_hip_queue_action_list_fail_locked(&actions->action_list, - actions->status); + iree_hal_deferred_work_queue_action_list_fail_locked(&actions->action_list, + actions->status); iree_slim_mutex_unlock(&actions->action_mutex); IREE_TRACE_ZONE_END(z0); return iree_ok_status(); @@ -1157,9 +1156,11 @@ iree_status_t iree_hal_hip_pending_queue_actions_issue( iree_status_t status = iree_ok_status(); // Scan through the list and categorize actions into pending and ready lists. 
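Condensed, the categorization rule that the scan loop below implements is the following (an illustrative summary only):

// Summary of the scan below:
//   ZOMBIE action -> ready list (only cleanup remains).
//   ALIVE action  -> ready list iff every wait semaphore has either already
//                    signaled past the target value or can hand back a
//                    recorded device event to wait on; otherwise it stays
//                    pending and is re-examined on the next issue call.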
- while (!iree_hal_hip_queue_action_list_is_empty(&actions->action_list)) { - iree_hal_hip_queue_action_t* action = - iree_hal_hip_queue_action_list_pop_front(&actions->action_list); + while (!iree_hal_deferred_work_queue_action_list_is_empty( + &actions->action_list)) { + iree_hal_deferred_work_queue_action_t* action = + iree_hal_deferred_work_queue_action_list_pop_front( + &actions->action_list); iree_hal_semaphore_t** semaphores = action->wait_semaphore_list.semaphores; uint64_t* values = action->wait_semaphore_list.payload_values; @@ -1170,7 +1171,7 @@ iree_status_t iree_hal_hip_pending_queue_actions_issue( // Cleanup actions are immediately ready to release. Otherwise, look at all // wait semaphores to make sure that they are either already ready or we can // wait on a device event. - if (action->state == IREE_HAL_HIP_QUEUE_ACTION_STATE_ALIVE) { + if (action->state == IREE_HAL_QUEUE_ACTION_STATE_ALIVE) { for (iree_host_size_t i = 0; i < action->wait_semaphore_list.count; ++i) { // If this semaphore has already signaled past the desired value, we can // just ignore it. @@ -1178,7 +1179,8 @@ iree_status_t iree_hal_hip_pending_queue_actions_issue( iree_status_t semaphore_status = iree_hal_semaphore_query(semaphores[i], &value); if (IREE_UNLIKELY(!iree_status_is_ok(semaphore_status))) { - iree_hal_hip_queue_action_fail_locked(action, semaphore_status); + iree_hal_deferred_work_queue_action_fail_locked(action, + semaphore_status); iree_status_ignore(semaphore_status); action_failed = true; break; @@ -1191,33 +1193,35 @@ iree_status_t iree_hal_hip_pending_queue_actions_issue( continue; } - // Try to acquire a HIP event from an existing device signal timepoint. + // Try to acquire an event from an existing device signal timepoint. // If so, we can use that event to wait on the device. // Otherwise, this action is still not ready for execution. // Before issuing recording on a stream, an event represents an empty // set of work so waiting on it will just return success. - // Here we must guarantee the HIP event is indeed recorded, which means + // Here we must guarantee the event is indeed recorded, which means // it's associated with some already present device signal timepoint on // the semaphore timeline. - iree_hal_hip_event_t* wait_event = NULL; - if (!iree_hal_hip_semaphore_acquire_event_host_wait( - semaphores[i], values[i], &wait_event)) { + iree_hal_deferred_work_queue_host_device_event_t wait_event = NULL; + if (!action->device_interface->vtable->acquire_host_wait_event( + action->device_interface, semaphores[i], values[i], + &wait_event)) { action->is_pending = true; break; } if (IREE_UNLIKELY(action->event_count >= - IREE_HAL_HIP_MAX_WAIT_EVENT_COUNT)) { + IREE_HAL_MAX_WAIT_EVENT_COUNT)) { status = iree_make_status( IREE_STATUS_RESOURCE_EXHAUSTED, "exceeded maximum queue action wait event limit"); - iree_hal_hip_event_release(wait_event); + action->device_interface->vtable->release_wait_event( + action->device_interface, wait_event); if (iree_status_is_ok(actions->status)) { actions->status = status; } - iree_hal_hip_queue_action_fail_locked(action, status); + iree_hal_deferred_work_queue_action_fail_locked(action, status); break; } - action->events[action->event_count++] = wait_event; + action->wait_events[action->event_count++] = wait_event; // Remove the wait timepoint as we have a corresponding event that we // will wait on. 
@@ -1228,10 +1232,11 @@ iree_status_t iree_hal_hip_pending_queue_actions_issue( if (IREE_UNLIKELY(!iree_status_is_ok(actions->status))) { if (!action_failed) { - iree_hal_hip_queue_action_fail_locked(action, actions->status); + iree_hal_deferred_work_queue_action_fail_locked(action, + actions->status); } - iree_hal_hip_queue_action_list_fail_locked(&actions->action_list, - actions->status); + iree_hal_deferred_work_queue_action_list_fail_locked( + &actions->action_list, actions->status); break; } @@ -1240,9 +1245,9 @@ iree_status_t iree_hal_hip_pending_queue_actions_issue( } if (action->is_pending) { - iree_hal_hip_queue_action_list_push_back(&pending_list, action); + iree_hal_deferred_work_queue_action_list_push_back(&pending_list, action); } else { - iree_hal_hip_queue_action_list_push_back(&ready_list, action); + iree_hal_deferred_work_queue_action_list_push_back(&ready_list, action); } } @@ -1257,7 +1262,7 @@ iree_status_t iree_hal_hip_pending_queue_actions_issue( return status; } - iree_hal_hip_entry_list_node_t* entry = NULL; + iree_hal_deferred_work_queue_entry_list_node_t* entry = NULL; // TODO: avoid host allocator malloc; use some pool for the allocation. if (iree_status_is_ok(status)) { status = iree_allocator_malloc(actions->host_allocator, sizeof(*entry), @@ -1266,8 +1271,8 @@ iree_status_t iree_hal_hip_pending_queue_actions_issue( if (IREE_UNLIKELY(!iree_status_is_ok(status))) { iree_slim_mutex_lock(&actions->action_mutex); - iree_hal_hip_pending_queue_actions_fail_status_locked(actions, status); - iree_hal_hip_queue_action_list_fail_locked(&ready_list, status); + iree_hal_deferred_work_queue_fail_status_locked(actions, status); + iree_hal_deferred_work_queue_action_list_fail_locked(&ready_list, status); iree_slim_mutex_unlock(&actions->action_mutex); IREE_TRACE_ZONE_END(z0); return status; @@ -1276,10 +1281,10 @@ iree_status_t iree_hal_hip_pending_queue_actions_issue( // Now push the ready list to the worker and have it to issue the actions to // the GPU. 
entry->ready_list_head = ready_list.head; - iree_hal_hip_entry_list_push(&actions->working_area.ready_worklist, entry); + iree_hal_deferred_work_queue_entry_list_push( + &actions->working_area.ready_worklist, entry); - iree_hal_hip_pending_queue_actions_notify_worker_thread( - &actions->working_area); + iree_hal_deferred_work_queue_notify_worker_thread(&actions->working_area); IREE_TRACE_ZONE_END(z0); return status; @@ -1289,29 +1294,29 @@ iree_status_t iree_hal_hip_pending_queue_actions_issue( // Worker routines //===----------------------------------------------------------------------===// -static bool iree_hal_hip_worker_has_incoming_request( - iree_hal_hip_working_area_t* working_area) { - iree_hal_hip_worker_state_t value = iree_atomic_load_int32( +static bool iree_hal_deferred_work_queue_worker_has_incoming_request( + iree_hal_deferred_work_queue_working_area_t* working_area) { + iree_hal_deferred_work_queue_worker_state_t value = iree_atomic_load_int32( &working_area->worker_state, iree_memory_order_acquire); - return value == IREE_HAL_HIP_WORKER_STATE_WORKLOAD_PENDING; + return value == IREE_HAL_WORKER_STATE_WORKLOAD_PENDING; } -static bool iree_hal_hip_completion_has_incoming_request( - iree_hal_hip_completion_area_t* completion_area) { - iree_hal_hip_worker_state_t value = iree_atomic_load_int32( +static bool iree_hal_deferred_work_queue_completion_has_incoming_request( + iree_hal_deferred_work_queue_completion_area_t* completion_area) { + iree_hal_deferred_work_queue_worker_state_t value = iree_atomic_load_int32( &completion_area->worker_state, iree_memory_order_acquire); - return value == IREE_HAL_HIP_WORKER_STATE_WORKLOAD_PENDING; + return value == IREE_HAL_WORKER_STATE_WORKLOAD_PENDING; } // Processes all ready actions in the given |worklist|. -static void iree_hal_hip_worker_process_ready_list( - iree_hal_hip_pending_queue_actions_t* actions) { +static void iree_hal_deferred_work_queue_worker_process_ready_list( + iree_hal_deferred_work_queue_t* actions) { IREE_TRACE_ZONE_BEGIN(z0); iree_slim_mutex_lock(&actions->action_mutex); iree_status_t status = actions->status; if (IREE_UNLIKELY(!iree_status_is_ok(status))) { - iree_hal_hip_ready_action_list_fail_locked( + iree_hal_deferred_work_queue_ready_action_list_fail_locked( &actions->working_area.ready_worklist, status); iree_slim_mutex_unlock(&actions->action_mutex); iree_status_ignore(status); @@ -1320,27 +1325,28 @@ static void iree_hal_hip_worker_process_ready_list( iree_slim_mutex_unlock(&actions->action_mutex); while (true) { - iree_hal_hip_entry_list_node_t* entry = - iree_hal_hip_entry_list_pop(&actions->working_area.ready_worklist); + iree_hal_deferred_work_queue_entry_list_node_t* entry = + iree_hal_deferred_work_queue_entry_list_pop( + &actions->working_area.ready_worklist); if (!entry) break; // Process the current batch of ready actions. 
while (entry->ready_list_head) { - iree_hal_hip_queue_action_t* action = - iree_hal_hip_entry_list_node_pop_front(entry); + iree_hal_deferred_work_queue_action_t* action = + iree_hal_deferred_work_queue_entry_list_node_pop_front(entry); switch (action->state) { - case IREE_HAL_HIP_QUEUE_ACTION_STATE_ALIVE: - status = iree_hal_hip_pending_queue_actions_issue_execution(action); + case IREE_HAL_QUEUE_ACTION_STATE_ALIVE: + status = iree_hal_deferred_work_queue_issue_execution(action); break; - case IREE_HAL_HIP_QUEUE_ACTION_STATE_ZOMBIE: - iree_hal_hip_pending_queue_actions_issue_cleanup(action); + case IREE_HAL_QUEUE_ACTION_STATE_ZOMBIE: + iree_hal_deferred_work_queue_issue_cleanup(action); break; } if (IREE_UNLIKELY(!iree_status_is_ok(status))) { - iree_hal_hip_entry_list_node_push_front(entry, action); - iree_hal_hip_entry_list_push(&actions->working_area.ready_worklist, - entry); + iree_hal_deferred_work_queue_entry_list_node_push_front(entry, action); + iree_hal_deferred_work_queue_entry_list_push( + &actions->working_area.ready_worklist, entry); break; } } @@ -1353,31 +1359,31 @@ static void iree_hal_hip_worker_process_ready_list( } if (IREE_UNLIKELY(!iree_status_is_ok(status))) { - iree_hal_hip_pending_queue_actions_fail(actions, status); + iree_hal_deferred_work_queue_fail(actions, status); iree_status_ignore(status); } IREE_TRACE_ZONE_END(z0); } -static void iree_hal_hip_worker_process_completion( - iree_hal_hip_pending_queue_actions_t* actions) { +static void iree_hal_deferred_work_queue_worker_process_completion( + iree_hal_deferred_work_queue_t* actions) { IREE_TRACE_ZONE_BEGIN(z0); - iree_hal_hip_completion_list_t* worklist = + iree_hal_deferred_work_queue_completion_list_t* worklist = &actions->completion_area.completion_list; iree_slim_mutex_lock(&actions->action_mutex); iree_status_t status = iree_status_clone(actions->status); iree_slim_mutex_unlock(&actions->action_mutex); while (true) { - iree_hal_hip_completion_list_node_t* entry = - iree_hal_hip_completion_list_pop(worklist); + iree_hal_deferred_work_queue_completion_list_node_t* entry = + iree_hal_deferred_work_queue_completion_list_pop(worklist); if (!entry) break; if (IREE_LIKELY(iree_status_is_ok(status))) { - IREE_TRACE_ZONE_BEGIN_NAMED(z1, "hipEventSynchronize"); - status = IREE_HIP_RESULT_TO_STATUS(actions->symbols, - hipEventSynchronize(entry->event)); + IREE_TRACE_ZONE_BEGIN_NAMED(z1, "synchronize_native_event"); + status = actions->device_interface->vtable->synchronize_native_event( + actions->device_interface, entry->native_event); IREE_TRACE_ZONE_END(z1); } @@ -1386,14 +1392,14 @@ static void iree_hal_hip_worker_process_completion( if (IREE_UNLIKELY(entry->created_event)) { status = iree_status_join( - status, IREE_HIP_RESULT_TO_STATUS(actions->symbols, - hipEventDestroy(entry->event))); + status, actions->device_interface->vtable->destroy_native_event( + actions->device_interface, entry->native_event)); } iree_allocator_free(actions->host_allocator, entry); } if (IREE_UNLIKELY(!iree_status_is_ok(status))) { - iree_hal_hip_pending_queue_actions_fail(actions, status); + iree_hal_deferred_work_queue_fail(actions, status); iree_status_ignore(status); } @@ -1401,21 +1407,23 @@ static void iree_hal_hip_worker_process_completion( } // The main function for the completion worker thread. 
-static int iree_hal_hip_completion_execute( - iree_hal_hip_pending_queue_actions_t* actions) { - iree_hal_hip_completion_area_t* completion_area = &actions->completion_area; +static int iree_hal_deferred_work_queue_completion_execute( + iree_hal_deferred_work_queue_t* actions) { + iree_hal_deferred_work_queue_completion_area_t* completion_area = + &actions->completion_area; - iree_status_t status = IREE_HIP_RESULT_TO_STATUS( - actions->symbols, hipSetDevice(actions->device), "hipSetDevice"); + iree_status_t status = actions->device_interface->vtable->bind_to_thread( + actions->device_interface); if (IREE_UNLIKELY(!iree_status_is_ok(status))) { - iree_hal_hip_pending_queue_actions_fail(actions, status); + iree_hal_deferred_work_queue_fail(actions, status); iree_status_ignore(status); } while (true) { iree_notification_await( &completion_area->state_notification, - (iree_condition_fn_t)iree_hal_hip_completion_has_incoming_request, + (iree_condition_fn_t) + iree_hal_deferred_work_queue_completion_has_incoming_request, completion_area, iree_infinite_timeout()); // Immediately flip the state to idle waiting if and only if the previous @@ -1424,9 +1432,9 @@ static int iree_hal_hip_completion_execute( // ready list processing but before overwriting the state from this worker // thread. iree_atomic_store_int32(&completion_area->worker_state, - IREE_HAL_HIP_WORKER_STATE_IDLE_WAITING, + IREE_HAL_WORKER_STATE_IDLE_WAITING, iree_memory_order_release); - iree_hal_hip_worker_process_completion(actions); + iree_hal_deferred_work_queue_worker_process_completion(actions); iree_slim_mutex_lock(&actions->action_mutex); if (IREE_UNLIKELY(actions->exit_requested && @@ -1441,18 +1449,19 @@ static int iree_hal_hip_completion_execute( } // The main function for the ready-list processing worker thread. -static int iree_hal_hip_worker_execute( - iree_hal_hip_pending_queue_actions_t* actions) { - iree_hal_hip_working_area_t* working_area = &actions->working_area; - - // Hip stores thread-local data based on the device. Some hip commands pull - // the device from there, and it defaults to device 0 (e.g. hipEventCreate), - // this will cause failures when using it with other devices (or streams from - // other devices). Force the correct device onto this thread. - iree_status_t status = IREE_HIP_RESULT_TO_STATUS( - actions->symbols, hipSetDevice(actions->device), "hipSetDevice"); +static int iree_hal_deferred_work_queue_worker_execute( + iree_hal_deferred_work_queue_t* actions) { + iree_hal_deferred_work_queue_working_area_t* working_area = + &actions->working_area; + + // Some APIs store thread-local data. Allow the interface to bind + // the thread-local data once for this thread rather than having to + // do it every call. + iree_status_t status = actions->device_interface->vtable->bind_to_thread( + actions->device_interface); + if (IREE_UNLIKELY(!iree_status_is_ok(status))) { - iree_hal_hip_pending_queue_actions_fail(actions, status); + iree_hal_deferred_work_queue_fail(actions, status); iree_status_ignore(status); // We can safely exit here because there are no actions in flight yet. return -1; @@ -1462,13 +1471,14 @@ static int iree_hal_hip_worker_execute( // Block waiting for incoming requests. // // TODO: When exit is requested with - // IREE_HAL_HIP_WORKER_STATE_EXIT_REQUESTED + // IREE_HAL_WORKER_STATE_EXIT_REQUESTED // we will return immediately causing a busy wait and hogging the CPU. // We need to properly wait for action cleanups to be scheduled from the // host stream callbacks. 
iree_notification_await( &working_area->state_notification, - (iree_condition_fn_t)iree_hal_hip_worker_has_incoming_request, + (iree_condition_fn_t) + iree_hal_deferred_work_queue_worker_has_incoming_request, working_area, iree_infinite_timeout()); // Immediately flip the state to idle waiting if and only if the previous @@ -1477,16 +1487,16 @@ static int iree_hal_hip_worker_execute( // ready list processing but before overwriting the state from this worker // thread. iree_atomic_store_int32(&working_area->worker_state, - IREE_HAL_HIP_WORKER_STATE_IDLE_WAITING, + IREE_HAL_WORKER_STATE_IDLE_WAITING, iree_memory_order_release); - iree_hal_hip_worker_process_ready_list(actions); + iree_hal_deferred_work_queue_worker_process_ready_list(actions); iree_slim_mutex_lock(&actions->action_mutex); if (IREE_UNLIKELY(actions->exit_requested && !actions->pending_work_items_count)) { iree_slim_mutex_unlock(&actions->action_mutex); - iree_hal_hip_pending_queue_actions_notify_completion_thread( + iree_hal_deferred_work_queue_notify_completion_thread( &actions->completion_area); return 0; } diff --git a/runtime/src/iree/hal/utils/deferred_work_queue.h b/runtime/src/iree/hal/utils/deferred_work_queue.h new file mode 100644 index 000000000000..c7c6615cd0e4 --- /dev/null +++ b/runtime/src/iree/hal/utils/deferred_work_queue.h @@ -0,0 +1,145 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#ifndef IREE_HAL_UTILS_DEFERRED_WORK_QUEUE_H_ +#define IREE_HAL_UTILS_DEFERRED_WORK_QUEUE_H_ + +#include "iree/base/api.h" +#include "iree/base/internal/arena.h" +#include "iree/hal/api.h" +#include "iree/hal/utils/semaphore_base.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +typedef struct iree_hal_deferred_work_queue_t iree_hal_deferred_work_queue_t; + +typedef struct iree_hal_deferred_work_queue_device_interface_vtable_t + iree_hal_deferred_work_queue_device_interface_vtable_t; + +// This interface allows the deferred work queue to interact with +// a specific driver. +// Calls to this vtable may be made from the deferred work queue on +// multiple threads simultaneously, so these functions must be thread-safe. +// Calls to this interface will come either from a thread that has had +// bind_to_thread called on it or as a side effect of one of the public +// functions on the deferred work queue. +typedef struct iree_hal_deferred_work_queue_device_interface_t { + const iree_hal_deferred_work_queue_device_interface_vtable_t* vtable; +} iree_hal_deferred_work_queue_device_interface_t; + +typedef void* iree_hal_deferred_work_queue_native_event_t; +typedef void* iree_hal_deferred_work_queue_host_device_event_t; + +typedef struct iree_hal_deferred_work_queue_device_interface_vtable_t { + void (*destroy)(iree_hal_deferred_work_queue_device_interface_t*); + // Binds the device work queue to a thread. May be simultaneously + // bound to multiple threads. + iree_status_t(IREE_API_PTR* bind_to_thread)( + iree_hal_deferred_work_queue_device_interface_t* device_interface); + + // Creates a native device event. + iree_status_t(IREE_API_PTR* create_native_event)( + iree_hal_deferred_work_queue_device_interface_t* device_interface, + iree_hal_deferred_work_queue_native_event_t* out_event); + + // Waits on a native device event.
+ iree_status_t(IREE_API_PTR* wait_native_event)( + iree_hal_deferred_work_queue_device_interface_t* device_interface, + iree_hal_deferred_work_queue_native_event_t event); + + // Records a native device event. + iree_status_t(IREE_API_PTR* record_native_event)( + iree_hal_deferred_work_queue_device_interface_t* device_interface, + iree_hal_deferred_work_queue_native_event_t event); + + // Synchronizes the thread on a native device event. + iree_status_t(IREE_API_PTR* synchronize_native_event)( + iree_hal_deferred_work_queue_device_interface_t* device_interface, + iree_hal_deferred_work_queue_native_event_t event); + + // Destroys a native device event. + iree_status_t(IREE_API_PTR* destroy_native_event)( + iree_hal_deferred_work_queue_device_interface_t* device_interface, + iree_hal_deferred_work_queue_native_event_t event); + + // Acquires a native device event to be signaled by the device for the + // given semaphore timepoint. + iree_status_t( + IREE_API_PTR* semaphore_acquire_timepoint_device_signal_native_event)( + iree_hal_deferred_work_queue_device_interface_t* device_interface, + struct iree_hal_semaphore_t*, uint64_t, + iree_hal_deferred_work_queue_native_event_t* out_event); + + // Gets the device to wait on the device-side event associated with the + // given host event. + iree_status_t(IREE_API_PTR* device_wait_on_host_event)( + iree_hal_deferred_work_queue_device_interface_t* device_interface, + iree_hal_deferred_work_queue_host_device_event_t event); + + // Acquires a mixed host/device event for the given timepoint. + bool(IREE_API_PTR* acquire_host_wait_event)( + iree_hal_deferred_work_queue_device_interface_t* device_interface, + struct iree_hal_semaphore_t*, uint64_t, + iree_hal_deferred_work_queue_host_device_event_t* out_event); + + // Releases a mixed host/device event previously acquired via + // acquire_host_wait_event. + void(IREE_API_PTR* release_wait_event)( + iree_hal_deferred_work_queue_device_interface_t* device_interface, + iree_hal_deferred_work_queue_host_device_event_t event); + + // Returns a device-side event from the given host/device event. + iree_hal_deferred_work_queue_native_event_t( + IREE_API_PTR* native_event_from_wait_event)( + iree_hal_deferred_work_queue_device_interface_t* device_interface, + iree_hal_deferred_work_queue_host_device_event_t event); + + // Creates a command buffer to be used to record a submitted + // iree_hal_deferred_command_buffer. + iree_status_t(IREE_API_PTR* create_stream_command_buffer)( + iree_hal_deferred_work_queue_device_interface_t* device_interface, + iree_hal_command_buffer_mode_t mode, iree_hal_command_category_t category, + iree_hal_command_buffer_t** out_command_buffer); + + // Submits a command buffer to the device. + iree_status_t(IREE_API_PTR* submit_command_buffer)( + iree_hal_deferred_work_queue_device_interface_t* device_interface, + iree_hal_command_buffer_t* command_buffer); +} iree_hal_deferred_work_queue_device_interface_vtable_t; + +iree_status_t iree_hal_deferred_work_queue_create( + iree_hal_deferred_work_queue_device_interface_t* device_interface, + iree_arena_block_pool_t* block_pool, iree_allocator_t host_allocator, + iree_hal_deferred_work_queue_t** out_queue); + +void iree_hal_deferred_work_queue_destroy( + iree_hal_deferred_work_queue_t* queue); + +typedef void(IREE_API_PTR* iree_hal_deferred_work_queue_cleanup_callback_t)( + void* user_data); + +// Enqueues command buffer submissions into the work queue to be executed +// once all wait semaphores have been satisfied.
+iree_status_t iree_hal_deferred_work_queue_enque( + iree_hal_deferred_work_queue_t* deferred_work_queue, + iree_hal_deferred_work_queue_cleanup_callback_t cleanup_callback, + void* callback_userdata, + const iree_hal_semaphore_list_t wait_semaphore_list, + const iree_hal_semaphore_list_t signal_semaphore_list, + iree_host_size_t command_buffer_count, + iree_hal_command_buffer_t* const* command_buffers, + iree_hal_buffer_binding_table_t const* binding_tables); + +// Attempts to advance the work queue by processing ready work on the +// current thread rather than on the worker thread. +iree_status_t iree_hal_deferred_work_queue_issue( + iree_hal_deferred_work_queue_t* deferred_work_queue); + +#ifdef __cplusplus +} +#endif // __cplusplus + +#endif // IREE_HAL_UTILS_DEFERRED_WORK_QUEUE_H_
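As a sketch only: a backend plugs into this utility by embedding the base interface as the first member of its own interface struct and filling in the vtable. The my_backend_* names below are hypothetical; only destroy and bind_to_thread are filled in, and the remaining entries would wrap the corresponding driver event/stream calls.

// Illustrative skeleton; my_backend_* names are hypothetical.
typedef struct my_backend_device_interface_t {
  // Must be first so the queue can treat this as the base interface type.
  iree_hal_deferred_work_queue_device_interface_t base;
  iree_allocator_t host_allocator;
  // Driver handles (context, stream, symbol table, ...) would live here.
} my_backend_device_interface_t;

static void my_backend_device_interface_destroy(
    iree_hal_deferred_work_queue_device_interface_t* base_device_interface) {
  my_backend_device_interface_t* device_interface =
      (my_backend_device_interface_t*)base_device_interface;
  iree_allocator_free(device_interface->host_allocator, device_interface);
}

static iree_status_t my_backend_device_interface_bind_to_thread(
    iree_hal_deferred_work_queue_device_interface_t* base_device_interface) {
  // Make the driver context current on the calling thread. The queue's
  // worker threads call this once before using any other entry point.
  return iree_ok_status();
}

static const iree_hal_deferred_work_queue_device_interface_vtable_t
    my_backend_device_interface_vtable = {
        .destroy = my_backend_device_interface_destroy,
        .bind_to_thread = my_backend_device_interface_bind_to_thread,
        // The remaining entries would wrap the matching driver calls.
};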
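And a sketch of the intended call pattern from a device's queue execution path, assuming a hypothetical my_backend_device_t that owns a work_queue created via iree_hal_deferred_work_queue_create; only the functions declared in the header above are real.

// Hypothetical submission path; my_backend_* names are illustrative only.
static void my_backend_submission_cleanup(void* user_data) {
  // Release any resources retained for the submission (via user_data).
}

static iree_status_t my_backend_device_queue_execute(
    my_backend_device_t* device,
    const iree_hal_semaphore_list_t wait_semaphore_list,
    const iree_hal_semaphore_list_t signal_semaphore_list,
    iree_hal_command_buffer_t* command_buffer,
    iree_hal_buffer_binding_table_t binding_table) {
  // Buffer the submission; the queue issues it to the device only once all
  // wait semaphores have been satisfied.
  IREE_RETURN_IF_ERROR(iree_hal_deferred_work_queue_enque(
      device->work_queue, my_backend_submission_cleanup,
      /*callback_userdata=*/NULL, wait_semaphore_list, signal_semaphore_list,
      /*command_buffer_count=*/1, &command_buffer, &binding_table));
  // Opportunistically advance the queue on this thread rather than waiting
  // for the worker thread to be scheduled.
  return iree_hal_deferred_work_queue_issue(device->work_queue);
}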