From f8c62b8ae79fc8a353f243ce4a7c0d1d0b591f57 Mon Sep 17 00:00:00 2001 From: leodziki Date: Fri, 30 Aug 2024 23:46:14 +0000 Subject: [PATCH] Implement DropGuard to ensure future Completion To manage state transitions safely across `await` boundaries, we introduced a `DropGuard` wrapper. This ensures that even if a future is dropped before completion, it will continue executing in the background via `tokio::spawn`. This approach allows us to avoid the pitfalls of incomplete state transitions in critical futures without the overhead of spawning all futures immediately. The `DropGuard` checks if the inner future has completed; if not, it safely spawns the future in the background during its drop. This solution strikes a balance between stack-based future management and the necessity to complete crucial futures. --- .../tests/utils/mock_scheduler.rs | 10 ++--- nativelink-util/BUILD.bazel | 1 + nativelink-util/src/drop_guard.rs | 42 +++++++++++++++++++ nativelink-util/src/evicting_map.rs | 7 +++- nativelink-util/src/lib.rs | 1 + 5 files changed, 54 insertions(+), 7 deletions(-) create mode 100644 nativelink-util/src/drop_guard.rs diff --git a/nativelink-scheduler/tests/utils/mock_scheduler.rs b/nativelink-scheduler/tests/utils/mock_scheduler.rs index fe0e37035..113a3d25c 100644 --- a/nativelink-scheduler/tests/utils/mock_scheduler.rs +++ b/nativelink-scheduler/tests/utils/mock_scheduler.rs @@ -17,12 +17,10 @@ use std::sync::Arc; use async_trait::async_trait; use nativelink_error::{make_input_err, Error}; use nativelink_metric::{MetricsComponent, RootMetricsComponent}; -use nativelink_util::{ - action_messages::{ActionInfo, OperationId}, - known_platform_property_provider::KnownPlatformPropertyProvider, - operation_state_manager::{ - ActionStateResult, ActionStateResultStream, ClientStateManager, OperationFilter, - }, +use nativelink_util::action_messages::{ActionInfo, OperationId}; +use nativelink_util::known_platform_property_provider::KnownPlatformPropertyProvider; +use nativelink_util::operation_state_manager::{ + ActionStateResult, ActionStateResultStream, ClientStateManager, OperationFilter, }; use tokio::sync::{mpsc, Mutex}; diff --git a/nativelink-util/BUILD.bazel b/nativelink-util/BUILD.bazel index b5598b872..5a47a82e6 100644 --- a/nativelink-util/BUILD.bazel +++ b/nativelink-util/BUILD.bazel @@ -17,6 +17,7 @@ rust_library( "src/connection_manager.rs", "src/default_store_key_subscribe.rs", "src/digest_hasher.rs", + "src/drop_guard.rs", "src/evicting_map.rs", "src/fastcdc.rs", "src/fs.rs", diff --git a/nativelink-util/src/drop_guard.rs b/nativelink-util/src/drop_guard.rs new file mode 100644 index 000000000..b1c8d0a30 --- /dev/null +++ b/nativelink-util/src/drop_guard.rs @@ -0,0 +1,42 @@ +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; + +pub struct DropGuard { + future: Option>>, +} + +impl DropGuard { + pub fn new(future: F) -> Self { + Self { + future: Some(Box::pin(future)), + } + } +} + +impl Future for DropGuard { + type Output = F::Output; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + if let Some(future) = self.future.as_mut() { + match future.as_mut().poll(cx) { + Poll::Ready(output) => { + self.future = None; // Set future to None after it completes + Poll::Ready(output) + } + Poll::Pending => Poll::Pending, + } + } else { + panic!("Future already completed"); + } + } +} + +impl Drop for DropGuard { + fn drop(&mut self) { + if let Some(future) = self.future.take() { + // Block on the future to ensure it completes. + futures::executor::block_on(future); + } + } +} diff --git a/nativelink-util/src/evicting_map.rs b/nativelink-util/src/evicting_map.rs index 90a1d5597..bc588edff 100644 --- a/nativelink-util/src/evicting_map.rs +++ b/nativelink-util/src/evicting_map.rs @@ -28,6 +28,7 @@ use nativelink_metric::MetricsComponent; use serde::{Deserialize, Serialize}; use tracing::{event, Level}; +use crate::drop_guard::DropGuard; use crate::instant_wrapper::InstantWrapper; use crate::metrics_utils::{Counter, CounterWithTime}; @@ -434,7 +435,11 @@ where data, }; - if let Some(old_item) = state.put(key, eviction_item).await { + let fut = state.put(key, eviction_item); + + let drop_guard = DropGuard::new(fut); + + if let Some(old_item) = drop_guard.await { replaced_items.push(old_item); } state.sum_store_size += new_item_size; diff --git a/nativelink-util/src/lib.rs b/nativelink-util/src/lib.rs index 69f6edaa2..6cb6772aa 100644 --- a/nativelink-util/src/lib.rs +++ b/nativelink-util/src/lib.rs @@ -20,6 +20,7 @@ pub mod common; pub mod connection_manager; pub mod default_store_key_subscribe; pub mod digest_hasher; +pub mod drop_guard; pub mod evicting_map; pub mod fastcdc; pub mod fs;