Skip to content

Commit

Permalink
Implement DropProtectedFuture to ensure future completion
Browse files Browse the repository at this point in the history
We introduced a `DropProtectedFuture` wrapper to manage state transitions
safely across `await` boundaries. It ensures that the future will continue
executing in the background via `tokio::spawn` even if a future is dropped
before completion.
  • Loading branch information
leodziki committed Sep 16, 2024
1 parent 506a297 commit fcc7fd5
Show file tree
Hide file tree
Showing 6 changed files with 111 additions and 28 deletions.
10 changes: 4 additions & 6 deletions nativelink-scheduler/tests/utils/mock_scheduler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,10 @@ use std::sync::Arc;
use async_trait::async_trait;
use nativelink_error::{make_input_err, Error};
use nativelink_metric::{MetricsComponent, RootMetricsComponent};
use nativelink_util::{
action_messages::{ActionInfo, OperationId},
known_platform_property_provider::KnownPlatformPropertyProvider,
operation_state_manager::{
ActionStateResult, ActionStateResultStream, ClientStateManager, OperationFilter,
},
use nativelink_util::action_messages::{ActionInfo, OperationId};
use nativelink_util::known_platform_property_provider::KnownPlatformPropertyProvider;
use nativelink_util::operation_state_manager::{
ActionStateResult, ActionStateResultStream, ClientStateManager, OperationFilter,
};
use tokio::sync::{mpsc, Mutex};

Expand Down
1 change: 1 addition & 0 deletions nativelink-util/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ rust_library(
"src/common.rs",
"src/connection_manager.rs",
"src/digest_hasher.rs",
"src/drop_protected_future.rs",
"src/evicting_map.rs",
"src/fastcdc.rs",
"src/fs.rs",
Expand Down
68 changes: 68 additions & 0 deletions nativelink-util/src/drop_protected_future.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
// Copyright 2024 The NativeLink Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};

use tracing::{Instrument, Span};

use crate::origin_context::{ContextAwareFuture, OriginContext};
use crate::spawn;
#[derive(Clone)]
pub struct DropProtectedFuture<F, T>
where
T: Send + 'static,
F: Future<Output = T> + Send + 'static + Unpin + Clone,
{
future: ContextAwareFuture<F>,
}

impl<F, T> DropProtectedFuture<F, T>
where
T: Send + 'static,
F: Future<Output = T> + Send + 'static + Unpin + Clone,
{
pub fn new(f: F, span: Span, ctx: Option<Arc<OriginContext>>) -> Self {
Self {
future: ContextAwareFuture::new(ctx, f.instrument(span)),
}
}
}

impl<F, T> Future for DropProtectedFuture<F, T>
where
T: Send + 'static,
F: Future<Output = T> + Send + 'static + Unpin + Clone,
{
type Output = F::Output;

fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match Pin::new(&mut self.future).poll(cx) {
Poll::Pending => Poll::Pending,
Poll::Ready(output) => Poll::Ready(output),
}
}
}

impl<F, T> Drop for DropProtectedFuture<F, T>
where
T: Send + 'static,
F: Future<Output = T> + Send + 'static + Unpin + Clone,
{
fn drop(&mut self) {
spawn!("DropProtectedFuture::drop", self.future.clone());
}
}
58 changes: 36 additions & 22 deletions nativelink-util/src/evicting_map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,12 @@ use lru::LruCache;
use nativelink_config::stores::EvictionPolicy;
use nativelink_metric::MetricsComponent;
use serde::{Deserialize, Serialize};
use tracing::{event, Level};
use tracing::{event, info_span, Level};

use crate::drop_protected_future::DropProtectedFuture;
use crate::instant_wrapper::InstantWrapper;
use crate::metrics_utils::{Counter, CounterWithTime};
use crate::origin_context::ActiveOriginContext;

#[derive(Serialize, Deserialize, PartialEq, Debug, Clone)]
pub struct SerializedLRU<K> {
Expand Down Expand Up @@ -98,7 +100,7 @@ impl<T: LenEntry + Send + Sync> LenEntry for Arc<T> {
}

#[derive(MetricsComponent)]
struct State<K: Ord + Hash + Eq + Clone + Debug, T: LenEntry + Debug> {
struct State<K: Ord + Hash + Eq + Clone + Debug + Send, T: LenEntry + Debug> {
lru: LruCache<K, EvictionItem<T>>,
btree: Option<BTreeSet<K>>,
#[metric(help = "Total size of all items in the store")]
Expand All @@ -116,7 +118,7 @@ struct State<K: Ord + Hash + Eq + Clone + Debug, T: LenEntry + Debug> {
lifetime_inserted_bytes: Counter,
}

impl<K: Ord + Hash + Eq + Clone + Debug, T: LenEntry + Debug + Sync> State<K, T> {
impl<K: Ord + Hash + Eq + Clone + Debug + Send, T: LenEntry + Debug + Sync> State<K, T> {
/// Removes an item from the cache.
async fn remove<Q>(&mut self, key: &Q, eviction_item: &EvictionItem<T>, replaced: bool)
where
Expand Down Expand Up @@ -153,7 +155,11 @@ impl<K: Ord + Hash + Eq + Clone + Debug, T: LenEntry + Debug + Sync> State<K, T>
}

#[derive(MetricsComponent)]
pub struct EvictingMap<K: Ord + Hash + Eq + Clone + Debug, T: LenEntry + Debug, I: InstantWrapper> {
pub struct EvictingMap<
K: Ord + Hash + Eq + Clone + Debug + Send,
T: LenEntry + Debug,
I: InstantWrapper,
> {
#[metric]
state: Mutex<State<K, T>>,
anchor_time: I,
Expand All @@ -169,7 +175,7 @@ pub struct EvictingMap<K: Ord + Hash + Eq + Clone + Debug, T: LenEntry + Debug,

impl<K, T, I> EvictingMap<K, T, I>
where
K: Ord + Hash + Eq + Clone + Debug,
K: Ord + Hash + Eq + Clone + Debug + Send + Sync,
T: LenEntry + Debug + Clone + Send + Sync,
I: InstantWrapper,
{
Expand Down Expand Up @@ -403,6 +409,7 @@ where
let mut state = self.state.lock().await;
let results = self
.inner_insert_many(&mut state, [(key, data)], seconds_since_anchor)
.await
.await;
results.into_iter().next()
}
Expand All @@ -418,30 +425,37 @@ where
let state = &mut self.state.lock().await;
self.inner_insert_many(state, inserts, self.anchor_time.elapsed().as_secs() as i32)
.await
.await
}

async fn inner_insert_many(
&self,
mut state: &mut State<K, T>,
inserts: impl IntoIterator<Item = (K, T)>,
seconds_since_anchor: i32,
) -> Vec<T> {
let mut replaced_items = Vec::new();
for (key, data) in inserts.into_iter() {
let new_item_size = data.len() as u64;
let eviction_item = EvictionItem {
seconds_since_anchor,
data,
};

if let Some(old_item) = state.put(key, eviction_item).await {
replaced_items.push(old_item);
}
state.sum_store_size += new_item_size;
state.lifetime_inserted_bytes.add(new_item_size);
self.evict_items(state.deref_mut()).await;
}
replaced_items
) -> impl Future<Output = Vec<T>> + Send {
DropProtectedFuture::new(
async move {
let mut replaced_items = Vec::new();
for (key, data) in inserts.into_iter() {
let new_item_size = data.len() as u64;
let eviction_item = EvictionItem {
seconds_since_anchor,
data,
};

if let Some(old_item) = state.put(key, eviction_item).await {
replaced_items.push(old_item);
}
state.sum_store_size += new_item_size;
state.lifetime_inserted_bytes.add(new_item_size);
self.evict_items(state.deref_mut()).await;
}
replaced_items
},
info_span!("EvictingMap::inner_insert_many"),
ActiveOriginContext::get(),
)
}

pub async fn remove<Q>(&self, key: &Q) -> bool
Expand Down
1 change: 1 addition & 0 deletions nativelink-util/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ pub mod chunked_stream;
pub mod common;
pub mod connection_manager;
pub mod digest_hasher;
pub mod drop_protected_future;
pub mod evicting_map;
pub mod fastcdc;
pub mod fs;
Expand Down
1 change: 1 addition & 0 deletions nativelink-util/src/origin_context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,7 @@ impl Drop for ContextDropGuard {

pin_project! {
#[must_use = "futures do nothing unless you `.await` or poll them"]
#[derive(Clone)]
pub struct ContextAwareFuture<F> {
// `ManuallyDrop` is used so we can call `self.span.enter()` in the `drop()`
// of our inner future, then drop the span.
Expand Down

0 comments on commit fcc7fd5

Please sign in to comment.