Skip to content

Commit

Permalink
chore(workflows): move wf gc and metrics publish into worker
Browse files Browse the repository at this point in the history
  • Loading branch information
MasterPtato committed Jan 27, 2025
1 parent 541aa12 commit f6f8a7b
Show file tree
Hide file tree
Showing 19 changed files with 689 additions and 378 deletions.
11 changes: 4 additions & 7 deletions Cargo.toml

Large diffs are not rendered by default.

249 changes: 249 additions & 0 deletions packages/common/chirp-workflow/core/src/db/crdb_nats.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
//! Implementation of a workflow database driver with PostgreSQL (CockroachDB) and NATS.
use std::{
collections::HashSet,
sync::Arc,
time::{Duration, Instant},
};
Expand Down Expand Up @@ -34,6 +35,12 @@ const QUERY_RETRY_MS: usize = 500;
const TXN_RETRY: Duration = Duration::from_millis(100);
/// Maximum times a query ran by this database adapter is retried.
const MAX_QUERY_RETRIES: usize = 16;
/// How long before considering the leases of a given worker instance "expired".
const WORKER_INSTANCE_EXPIRED_THRESHOLD_MS: i64 = rivet_util::duration::seconds(30);
/// How long before overwriting an existing GC lock.
const GC_LOCK_TIMEOUT_MS: i64 = rivet_util::duration::seconds(30);
/// How long before overwriting an existing metrics lock.
const METRICS_LOCK_TIMEOUT_MS: i64 = GC_LOCK_TIMEOUT_MS;
/// For SQL macros.
const CONTEXT_NAME: &str = "chirp_workflow_crdb_nats_engine";
/// For NATS wake mechanism.
Expand Down Expand Up @@ -159,6 +166,248 @@ impl Database for DatabaseCrdbNats {
}
}

/// Garbage-collection pass: reclaims workflows leased by worker instances that
/// look dead (no ping within `WORKER_INSTANCE_EXPIRED_THRESHOLD_MS`) so another
/// worker can pick them up.
///
/// Coordination is done through the `db_workflow.workflow_gc` lock row: only
/// the instance that wins the lock performs the failover, and a stale lock
/// (older than `GC_LOCK_TIMEOUT_MS`) can be overwritten by the next caller.
async fn clear_expired_leases(&self, worker_instance_id: Uuid) -> WorkflowResult<()> {
    // Attempt to take the GC lock. The UPDATE matches only when the lock is
    // free (`lock_ts IS NULL`) or stale (`lock_ts < now - GC_LOCK_TIMEOUT_MS`),
    // so getting a row back means we now hold the lock.
    let acquired_lock = sql_fetch_optional!(
        [self, (i64,)]
        "
        UPDATE db_workflow.workflow_gc
        SET
            worker_instance_id = $1,
            lock_ts = $2
        WHERE lock_ts IS NULL OR lock_ts < $2 - $3
        RETURNING 1
        ",
        worker_instance_id,
        rivet_util::timestamp::now(),
        GC_LOCK_TIMEOUT_MS,
    )
    .await?
    .is_some();

    if acquired_lock {
        // Reset all workflows on worker instances that have not had a ping in the last 30 seconds
        //
        // Clearing `worker_instance_id` releases the lease; resetting the wake
        // columns makes the workflow immediately eligible to be picked up again.
        let rows = sql_fetch_all!(
            [self, (Uuid, Uuid,)]
            "
            UPDATE db_workflow.workflows AS w
            SET
                worker_instance_id = NULL,
                wake_immediate = true,
                wake_deadline_ts = NULL,
                wake_signals = ARRAY[],
                wake_sub_workflow_id = NULL
            FROM db_workflow.worker_instances AS wi
            WHERE
                wi.last_ping_ts < $1 AND
                wi.worker_instance_id = w.worker_instance_id AND
                w.output IS NULL AND
                w.silence_ts IS NULL AND
                -- Check for any wake condition so we don't restart a permanently dead workflow
                (
                    w.wake_immediate OR
                    w.wake_deadline_ts IS NOT NULL OR
                    cardinality(w.wake_signals) > 0 OR
                    w.wake_sub_workflow_id IS NOT NULL
                )
            RETURNING w.workflow_id, wi.worker_instance_id
            ",
            rivet_util::timestamp::now() - WORKER_INSTANCE_EXPIRED_THRESHOLD_MS,
        )
        .await?;

        if !rows.is_empty() {
            // Log one line per GC pass with the distinct set of expired
            // instances rather than one line per reclaimed workflow.
            let unique_worker_instance_ids = rows
                .iter()
                .map(|(_, worker_instance_id)| worker_instance_id)
                .collect::<HashSet<_>>();

            tracing::info!(
                worker_instance_ids=?unique_worker_instance_ids,
                total_workflows=%rows.len(),
                "handled failover",
            );
        }

        // Clear lock
        //
        // NOTE(review): if the failover UPDATE above errors, the `?` returns
        // early and this release is skipped; the lock then stays held until
        // GC_LOCK_TIMEOUT_MS elapses. Appears intentional (timeout acts as the
        // recovery path) — confirm.
        sql_execute!(
            [self]
            "
            UPDATE db_workflow.workflow_gc
            SET
                worker_instance_id = NULL,
                lock_ts = NULL
            WHERE worker_instance_id = $1
            ",
            worker_instance_id,
        )
        .await?;
    }

    Ok(())
}

/// Publishes workflow/signal gauge metrics for the whole database.
///
/// Every caller updates its own `WORKER_LAST_PING` gauge, but the expensive
/// aggregate queries run only on the single instance that wins the
/// `db_workflow.workflow_metrics` lock row (stale locks are overwritten after
/// `METRICS_LOCK_TIMEOUT_MS`), so the gauges are published once per interval
/// instead of once per worker.
async fn publish_metrics(&self, worker_instance_id: Uuid) -> WorkflowResult<()> {
    // Always update ping
    metrics::WORKER_LAST_PING
        .with_label_values(&[&worker_instance_id.to_string()])
        .set(rivet_util::timestamp::now());

    // Same lock-acquisition pattern as the GC lock: a returned row means the
    // lock was free or stale and is now ours.
    let acquired_lock = sql_fetch_optional!(
        [self, (i64,)]
        "
        UPDATE db_workflow.workflow_metrics
        SET
            worker_instance_id = $1,
            lock_ts = $2
        WHERE lock_ts IS NULL OR lock_ts < $2 - $3
        RETURNING 1
        ",
        worker_instance_id,
        rivet_util::timestamp::now(),
        METRICS_LOCK_TIMEOUT_MS,
    )
    .await?
    .is_some();

    if acquired_lock {
        // Five aggregate counts gathered concurrently. All queries read
        // `AS OF SYSTEM TIME '-1s'` — a slightly stale historical read,
        // presumably to avoid contending with live workflow writes (CRDB
        // follower reads) — confirm.
        let (
            total_workflow_count,
            active_workflow_count,
            dead_workflow_count,
            sleeping_workflow_count,
            pending_signal_count,
        ) = tokio::try_join!(
            // Total workflows per name, regardless of state.
            sql_fetch_all!(
                [self, (String, i64)]
                "
                SELECT workflow_name, COUNT(*)
                FROM db_workflow.workflows AS OF SYSTEM TIME '-1s'
                GROUP BY workflow_name
                ",
            ),
            // Active: currently leased by a worker and not finished/silenced.
            sql_fetch_all!(
                [self, (String, i64)]
                "
                SELECT workflow_name, COUNT(*)
                FROM db_workflow.workflows AS OF SYSTEM TIME '-1s'
                WHERE
                    output IS NULL AND
                    worker_instance_id IS NOT NULL AND
                    silence_ts IS NULL
                GROUP BY workflow_name
                ",
            ),
            // Dead: errored, unfinished, and with no wake condition left, so
            // nothing will ever resume them. Grouped by error for labeling.
            sql_fetch_all!(
                [self, (String, String, i64)]
                "
                SELECT workflow_name, error, COUNT(*)
                FROM db_workflow.workflows AS OF SYSTEM TIME '-1s'
                WHERE
                    error IS NOT NULL AND
                    output IS NULL AND
                    silence_ts IS NULL AND
                    wake_immediate = FALSE AND
                    wake_deadline_ts IS NULL AND
                    cardinality(wake_signals) = 0 AND
                    wake_sub_workflow_id IS NULL
                GROUP BY workflow_name, error
                ",
            ),
            // Sleeping: unleased but with at least one wake condition, i.e.
            // waiting to be picked up.
            sql_fetch_all!(
                [self, (String, i64)]
                "
                SELECT workflow_name, COUNT(*)
                FROM db_workflow.workflows AS OF SYSTEM TIME '-1s'
                WHERE
                    worker_instance_id IS NULL AND
                    output IS NULL AND
                    silence_ts IS NULL AND
                    (
                        wake_immediate OR
                        wake_deadline_ts IS NOT NULL OR
                        cardinality(wake_signals) > 0 OR
                        wake_sub_workflow_id IS NOT NULL
                    )
                GROUP BY workflow_name
                ",
            ),
            // Pending signals: unacked, unsilenced, across both plain and
            // tagged signal tables.
            sql_fetch_all!(
                [self, (String, i64)]
                "
                SELECT signal_name, COUNT(*)
                FROM (
                    SELECT signal_name
                    FROM db_workflow.signals
                    WHERE
                        ack_ts IS NULL AND
                        silence_ts IS NULL
                    UNION ALL
                    SELECT signal_name
                    FROM db_workflow.tagged_signals
                    WHERE
                        ack_ts IS NULL AND
                        silence_ts IS NULL
                ) AS OF SYSTEM TIME '-1s'
                GROUP BY signal_name
                ",
            ),
        )?;

        // Get rid of metrics that don't exist in the db anymore (declarative)
        metrics::WORKFLOW_TOTAL.reset();
        metrics::WORKFLOW_ACTIVE.reset();
        metrics::WORKFLOW_DEAD.reset();
        metrics::WORKFLOW_SLEEPING.reset();
        metrics::SIGNAL_PENDING.reset();

        for (workflow_name, count) in total_workflow_count {
            metrics::WORKFLOW_TOTAL
                .with_label_values(&[&workflow_name])
                .set(count);
        }

        for (workflow_name, count) in active_workflow_count {
            metrics::WORKFLOW_ACTIVE
                .with_label_values(&[&workflow_name])
                .set(count);
        }

        for (workflow_name, error, count) in dead_workflow_count {
            metrics::WORKFLOW_DEAD
                .with_label_values(&[&workflow_name, &error])
                .set(count);
        }

        for (workflow_name, count) in sleeping_workflow_count {
            metrics::WORKFLOW_SLEEPING
                .with_label_values(&[&workflow_name])
                .set(count);
        }

        for (signal_name, count) in pending_signal_count {
            metrics::SIGNAL_PENDING
                .with_label_values(&[&signal_name])
                .set(count);
        }

        // Clear lock
        //
        // NOTE(review): as with the GC lock, an early `?` return above skips
        // this release and the lock is only reclaimed after
        // METRICS_LOCK_TIMEOUT_MS — confirm that's intended.
        sql_execute!(
            [self]
            "
            UPDATE db_workflow.workflow_metrics
            SET
                worker_instance_id = NULL,
                lock_ts = NULL
            WHERE worker_instance_id = $1
            ",
            worker_instance_id,
        )
        .await?;
    }

    Ok(())
}

async fn dispatch_workflow(
&self,
ray_id: Uuid,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use foundationdb::future::FdbValue;

pub mod signal;
pub mod wake;
pub mod worker_instance;
pub mod workflow;

pub trait FormalKey {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
use std::{borrow::Cow, result::Result::Ok};

// TODO: Use concrete error types
use anyhow::*;
use foundationdb::tuple::{PackResult, TupleDepth, TuplePack, TupleUnpack, VersionstampOffset};
use uuid::Uuid;

use super::FormalKey;

/// FDB key for the last-ping timestamp of a single worker instance.
/// Packs to the tuple `("worker_instance", "data", worker_instance_id, "last_ping_ts")`.
#[derive(Debug)]
pub struct LastPingTsKey {
    worker_instance_id: Uuid,
}

impl LastPingTsKey {
pub fn new(worker_instance_id: Uuid) -> Self {
LastPingTsKey { worker_instance_id }
}
}

impl FormalKey for LastPingTsKey {
    // Millisecond timestamp, stored as 8 big-endian bytes.
    type Value = i64;

    fn deserialize(&self, raw: &[u8]) -> Result<Self::Value> {
        // Fails if `raw` is not exactly 8 bytes long.
        let bytes: [u8; 8] = raw.try_into()?;
        Ok(i64::from_be_bytes(bytes))
    }

    fn serialize(&self, value: Self::Value) -> Result<Vec<u8>> {
        Ok(Vec::from(value.to_be_bytes()))
    }
}

impl TuplePack for LastPingTsKey {
    fn pack<W: std::io::Write>(
        &self,
        w: &mut W,
        tuple_depth: TupleDepth,
    ) -> std::io::Result<VersionstampOffset> {
        // Encode as the 4-element tuple that the matching TupleUnpack expects.
        (
            "worker_instance",
            "data",
            self.worker_instance_id,
            "last_ping_ts",
        )
            .pack(w, tuple_depth)
    }
}

impl<'de> TupleUnpack<'de> for LastPingTsKey {
    fn unpack(input: &[u8], tuple_depth: TupleDepth) -> PackResult<(&[u8], Self)> {
        // NOTE(review): the three string components are consumed but not
        // validated against "worker_instance"/"data"/"last_ping_ts" — assumes
        // callers only unpack bytes read from the matching subspace; confirm.
        let (rest, (_, _, worker_instance_id, _)) =
            <(Cow<str>, Cow<str>, Uuid, Cow<str>)>::unpack(input, tuple_depth)?;

        Ok((rest, LastPingTsKey { worker_instance_id }))
    }
}

/// FDB key for the singleton metrics-publish lock.
/// Packs to the tuple `("worker_instance", "metrics_lock")`.
///
/// Derives `Default` because the key is stateless and `new()` takes no
/// arguments (clippy `new_without_default`).
#[derive(Debug, Default)]
pub struct MetricsLockKey {}

impl MetricsLockKey {
pub fn new() -> Self {
MetricsLockKey {}
}
}

impl FormalKey for MetricsLockKey {
    // Millisecond timestamp, stored as 8 big-endian bytes.
    type Value = i64;

    fn deserialize(&self, raw: &[u8]) -> Result<Self::Value> {
        // Errors unless `raw` is exactly 8 bytes.
        let arr = <[u8; 8]>::try_from(raw)?;
        Ok(i64::from_be_bytes(arr))
    }

    fn serialize(&self, value: Self::Value) -> Result<Vec<u8>> {
        Ok(value.to_be_bytes().into())
    }
}

impl TuplePack for MetricsLockKey {
    fn pack<W: std::io::Write>(
        &self,
        w: &mut W,
        tuple_depth: TupleDepth,
    ) -> std::io::Result<VersionstampOffset> {
        // Fixed 2-element tuple: this lock has no per-instance component.
        ("worker_instance", "metrics_lock").pack(w, tuple_depth)
    }
}

impl<'de> TupleUnpack<'de> for MetricsLockKey {
    fn unpack(input: &[u8], tuple_depth: TupleDepth) -> PackResult<(&[u8], Self)> {
        // NOTE(review): both string components are discarded without checking
        // their values — assumes input comes from the matching subspace.
        let (rest, (_, _)) = <(Cow<str>, Cow<str>)>::unpack(input, tuple_depth)?;

        Ok((rest, MetricsLockKey {}))
    }
}
Loading

0 comments on commit f6f8a7b

Please sign in to comment.