use priority queue in scheduler to ensure loop conditions are met
probably temporary
MingweiSamuel committed Jan 15, 2025
Commit a47a92a, parent 2f04fa6
Showing 3 changed files with 34 additions and 14 deletions.
dfir_rs/src/scheduled/context.rs (5 changes: 3 additions & 2 deletions)
```diff
@@ -3,7 +3,7 @@
 //! Provides APIs for state and scheduling.
 use std::any::Any;
-use std::collections::VecDeque;
+use std::collections::BinaryHeap;
 use std::future::Future;
 use std::marker::PhantomData;
 use std::ops::DerefMut;
@@ -28,7 +28,8 @@ pub struct Context {

     /// TODO(mingwei): separate scheduler into its own struct/trait?
     /// Index is stratum, value is FIFO queue for that stratum.
-    pub(super) stratum_queues: Vec<VecDeque<SubgraphId>>,
+    /// PriorityQueue, usize is depth. Larger/deeper is higher priority.
+    pub(super) stratum_queues: Vec<BinaryHeap<(usize, SubgraphId)>>,
     /// Receive events; the second arg indicates if it is an external "important" event (true).
     pub(super) event_queue_recv: UnboundedReceiver<(SubgraphId, bool)>,
     /// If external events or data can justify starting the next tick.
```
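The new field relies on two standard-library facts: `BinaryHeap` is a max-heap, and tuples compare lexicographically, so the entry with the greatest loop depth is popped first. A minimal sketch of that ordering, with a plain `usize` standing in for the crate's `SubgraphId` key type:

```rust
use std::collections::BinaryHeap;

// Stand-in for the crate's `SubgraphId` key type (illustrative only).
type SubgraphId = usize;

fn main() {
    let mut stratum_queue: BinaryHeap<(usize, SubgraphId)> = BinaryHeap::new();
    stratum_queue.push((0, 7)); // top-level subgraph, depth 0
    stratum_queue.push((2, 3)); // nested two loops deep
    stratum_queue.push((1, 5)); // nested one loop deep

    // Max-heap plus lexicographic tuple ordering: deepest subgraph pops first.
    let order: Vec<_> = std::iter::from_fn(|| stratum_queue.pop()).collect();
    assert_eq!(order, vec![(2, 3), (1, 5), (0, 7)]);
}
```

Note that subgraphs at equal depth are no longer FIFO as with the old `VecDeque`; ties fall back to the `SubgraphId` ordering added in `slot_vec.rs` below, which may be part of why the commit message says "probably temporary".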
dfir_rs/src/scheduled/graph.rs (33 changes: 21 additions & 12 deletions)
```diff
@@ -271,8 +271,8 @@ impl<'a> Dfir<'a> {

         let mut work_done = false;

-        while let Some(sg_id) =
-            self.context.stratum_queues[self.context.current_stratum].pop_front()
+        while let Some((_depth, sg_id)) =
+            self.context.stratum_queues[self.context.current_stratum].pop()
         {
             work_done = true;
             {
@@ -303,7 +303,8 @@ impl<'a> Dfir<'a> {
                     }
                     // Add subgraph to stratum queue if it is not already scheduled.
                     if !succ_sg_data.is_scheduled.replace(true) {
-                        self.context.stratum_queues[succ_sg_data.stratum].push_back(succ_id);
+                        self.context.stratum_queues[succ_sg_data.stratum]
+                            .push((succ_sg_data.loop_depth, succ_id));
                     }
                 }
             }
@@ -456,7 +457,7 @@ impl<'a> Dfir<'a> {
                 "Event received."
             );
             if !sg_data.is_scheduled.replace(true) {
-                self.context.stratum_queues[sg_data.stratum].push_back(sg_id);
+                self.context.stratum_queues[sg_data.stratum].push((sg_data.loop_depth, sg_id));
                 enqueued_count += 1;
             }
             if is_external {
@@ -494,7 +495,7 @@ impl<'a> Dfir<'a> {
                 "Event received."
             );
             if !sg_data.is_scheduled.replace(true) {
-                self.context.stratum_queues[sg_data.stratum].push_back(sg_id);
+                self.context.stratum_queues[sg_data.stratum].push((sg_data.loop_depth, sg_id));
                 count += 1;
             }
             if is_external {
@@ -539,7 +540,7 @@ impl<'a> Dfir<'a> {
                 "Event received."
             );
             if !sg_data.is_scheduled.replace(true) {
-                self.context.stratum_queues[sg_data.stratum].push_back(sg_id);
+                self.context.stratum_queues[sg_data.stratum].push((sg_data.loop_depth, sg_id));
                 count += 1;
             }
             if is_external {
@@ -570,7 +571,7 @@ impl<'a> Dfir<'a> {
         let sg_data = &self.subgraphs[sg_id];
         let already_scheduled = sg_data.is_scheduled.replace(true);
         if !already_scheduled {
-            self.context.stratum_queues[sg_data.stratum].push_back(sg_id);
+            self.context.stratum_queues[sg_data.stratum].push((sg_data.loop_depth, sg_id));
             true
         } else {
             false
@@ -614,6 +615,8 @@ impl<'a> Dfir<'a> {
         W: 'static + PortList<SEND>,
         F: 'a + for<'ctx> FnMut(&'ctx mut Context, R::Ctx<'ctx>, W::Ctx<'ctx>),
     {
+        let loop_depth = loop_id.map_or(0, |loop_id| self.loop_depth[loop_id]);
+
         let sg_id = self.subgraphs.insert_with_key(|sg_id| {
             let (mut subgraph_preds, mut subgraph_succs) = Default::default();
             recv_ports.set_graph_meta(&mut self.handoffs, &mut subgraph_preds, sg_id, true);
@@ -634,10 +637,11 @@ impl<'a> Dfir<'a> {
                 true,
                 laziness,
                 loop_id,
+                loop_depth,
             )
         });
         self.context.init_stratum(stratum);
-        self.context.stratum_queues[stratum].push_back(sg_id);
+        self.context.stratum_queues[stratum].push((loop_depth, sg_id));

         sg_id
     }
@@ -728,11 +732,12 @@ impl<'a> Dfir<'a> {
                 true,
                 false,
                 None,
+                0,
             )
         });

         self.context.init_stratum(stratum);
-        self.context.stratum_queues[stratum].push_back(sg_id);
+        self.context.stratum_queues[stratum].push((0, sg_id));

         sg_id
     }
@@ -916,6 +921,8 @@ pub(super) struct SubgraphData<'a> {
     /// The subgraph's loop ID, or `None` for the top level.
     #[expect(dead_code, reason = "TODO(mingwei): WIP")]
     loop_id: Option<LoopId>,
+    /// The subgraph's loop depth.
+    loop_depth: usize,
 }
 impl<'a> SubgraphData<'a> {
     #[expect(clippy::too_many_arguments, reason = "internal use")]
@@ -926,19 +933,21 @@ impl<'a> SubgraphData<'a> {
         preds: Vec<HandoffId>,
         succs: Vec<HandoffId>,
         is_scheduled: bool,
-        laziness: bool,
+        is_lazy: bool,
         loop_id: Option<LoopId>,
+        loop_depth: usize,
     ) -> Self {
         Self {
             name,
             stratum,
             subgraph: Box::new(subgraph),
             preds,
             succs,
-            loop_id,
             is_scheduled: Cell::new(is_scheduled),
             last_tick_run_in: None,
-            is_lazy: laziness,
+            is_lazy,
+            loop_id,
+            loop_depth,
         }
     }
 }
```
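Two smaller points in this file: every push site is guarded by `is_scheduled.replace(true)`, where `Cell::replace` swaps in `true` and returns the previous flag, so a subgraph is enqueued at most once until the flag is cleared after it runs; and `loop_id.map_or(0, ...)` gives subgraphs outside any loop depth 0, the lowest priority. A sketch of the check-and-set guard in isolation (the `schedule` helper is illustrative, not from the crate):

```rust
use std::cell::Cell;

// Check-and-set guard: `Cell::replace` writes `true` and returns the old
// value, so only the first caller actually enqueues the subgraph.
fn schedule(is_scheduled: &Cell<bool>, enqueue: impl FnOnce()) -> bool {
    if !is_scheduled.replace(true) {
        enqueue();
        true
    } else {
        false // already queued: skip the duplicate push
    }
}

fn main() {
    let flag = Cell::new(false);
    let mut pushes = 0;
    assert!(schedule(&flag, || pushes += 1)); // first call enqueues
    assert!(!schedule(&flag, || pushes += 1)); // second call is deduplicated
    assert_eq!(pushes, 1);
}
```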
dfir_rs/src/util/slot_vec.rs (10 changes: 10 additions & 0 deletions)
```diff
@@ -26,6 +26,16 @@ impl<Tag: ?Sized> Clone for Key<Tag> {
     }
 }
 impl<Tag: ?Sized> Copy for Key<Tag> {}
+impl<Tag: ?Sized> PartialOrd for Key<Tag> {
+    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
+        Some(self.cmp(other))
+    }
+}
+impl<Tag: ?Sized> Ord for Key<Tag> {
+    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
+        self.index.cmp(&other.index)
+    }
+}
 impl<Tag: ?Sized> PartialEq for Key<Tag> {
     fn eq(&self, other: &Self) -> bool {
         self.index == other.index
```
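These impls exist so that `SubgraphId`, which is a `Key`, satisfies the `Ord` bound that `BinaryHeap<(usize, SubgraphId)>` requires of its elements. They are written by hand rather than derived, presumably because `#[derive(PartialOrd, Ord)]` would add a spurious `Tag: Ord` bound even though `Tag` is only a phantom type parameter. A simplified stand-in (the field layout is assumed, not copied from the crate) showing the pattern:

```rust
use std::cmp::Ordering;
use std::marker::PhantomData;

// Simplified `Key<Tag>`: the tag only distinguishes key types at compile
// time, so comparisons should depend on `index` alone.
pub struct Key<Tag: ?Sized> {
    index: usize,
    _phantom: PhantomData<*const Tag>,
}

impl<Tag: ?Sized> PartialEq for Key<Tag> {
    fn eq(&self, other: &Self) -> bool {
        self.index == other.index
    }
}
impl<Tag: ?Sized> Eq for Key<Tag> {}
impl<Tag: ?Sized> PartialOrd for Key<Tag> {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other)) // delegate to `Ord` to keep the two consistent
    }
}
impl<Tag: ?Sized> Ord for Key<Tag> {
    fn cmp(&self, other: &Self) -> Ordering {
        self.index.cmp(&other.index) // ignore `Tag` entirely
    }
}

fn main() {
    let a = Key::<str> { index: 1, _phantom: PhantomData };
    let b = Key::<str> { index: 2, _phantom: PhantomData };
    assert!(a < b); // ordering follows `index`, regardless of `Tag`
}
```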
