Skip to content

Commit

Permalink
feat(workflows): add loop state
Browse files Browse the repository at this point in the history
  • Loading branch information
MasterPtato committed Jan 24, 2025
1 parent 057e98b commit 9412633
Show file tree
Hide file tree
Showing 8 changed files with 88 additions and 63 deletions.
113 changes: 54 additions & 59 deletions packages/common/chirp-workflow/core/src/ctx/workflow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -715,45 +715,6 @@ impl WorkflowCtx {
Ok(signal)
}

// TODO: Currently implemented wrong, if no signal is received it should still write a signal row to the
// database so that upon replay it again receives no signal
// /// Checks if the given signal exists in the database.
// pub async fn query_signal<T: Listen>(&mut self) -> GlobalResult<Option<T>> {
// let event = self.current_history_event();

// // Signal received before
// let signal = if let Some(event) = event {
// tracing::debug!(name=%self.name, id=%self.workflow_id, "replaying signal");

// // Validate history is consistent
// let Event::Signal(signal) = event else {
// return Err(WorkflowError::HistoryDiverged(format!(
// "expected {event} at {}, found signal",
// self.loc(),
// )))
// .map_err(GlobalError::raw);
// };

// Some(T::parse(&signal.name, signal.body.clone()).map_err(GlobalError::raw)?)
// }
// // Listen for new message
// else {
// let mut ctx = ListenCtx::new(self);
// ctx.reset();

// match T::listen(&mut ctx).await {
// Ok(res) => Some(res),
// Err(err) if matches!(err, WorkflowError::NoSignalFound(_)) => None,
// Err(err) => return Err(err).map_err(GlobalError::raw),
// }
// };

// // Move to next event
// self.cursor.update();

// Ok(signal)
// }

/// Creates a message builder.
pub fn msg<M>(&mut self, body: M) -> builder::message::MessageBuilder<M>
where
Expand All @@ -762,12 +723,30 @@ impl WorkflowCtx {
builder::message::MessageBuilder::new(self, self.version, body)
}

/// Runs workflow steps in a loop. **Ensure that there are no side effects caused by the code in this
/// callback**. If you need side causes or side effects, use a native rust loop.
/// Runs workflow steps in a loop. If you need side causes, use `WorkflowCtx::loope`.
pub async fn repeat<F, T>(&mut self, mut cb: F) -> GlobalResult<T>
where
F: for<'a> FnMut(&'a mut WorkflowCtx) -> AsyncResult<'a, Loop<T>>,
T: Serialize + DeserializeOwned,
{
self.loop_inner((), |ctx, _| cb(ctx)).await
}

/// Runs workflow steps in a loop with state.
pub async fn loope<S, F, T>(&mut self, state: S, cb: F) -> GlobalResult<T>
where
S: Serialize + DeserializeOwned,
F: for<'a> FnMut(&'a mut WorkflowCtx, &'a mut S) -> AsyncResult<'a, Loop<T>>,
T: Serialize + DeserializeOwned,
{
self.loop_inner(state, cb).await
}

async fn loop_inner<S, F, T>(&mut self, state: S, mut cb: F) -> GlobalResult<T>
where
S: Serialize + DeserializeOwned,
F: for<'a> FnMut(&'a mut WorkflowCtx, &'a mut S) -> AsyncResult<'a, Loop<T>>,
T: Serialize + DeserializeOwned,
{
let history_res = self
.cursor
Expand All @@ -776,25 +755,32 @@ impl WorkflowCtx {
let loop_location = self.cursor.current_location_for(&history_res);

// Loop existed before
let (mut iteration, output) = if let HistoryResult::Event(loop_event) = history_res {
let output = loop_event.parse_output().map_err(GlobalError::raw)?;
let (mut iteration, mut state, output) =
if let HistoryResult::Event(loop_event) = history_res {
let state = loop_event.parse_state().map_err(GlobalError::raw)?;
let output = loop_event.parse_output().map_err(GlobalError::raw)?;

(loop_event.iteration, output)
} else {
// Insert event before loop is run so the history is consistent
self.db
.upsert_workflow_loop_event(
self.workflow_id,
&loop_location,
self.version,
0,
None,
self.loop_location(),
)
.await?;
(loop_event.iteration, state, output)
} else {
let state_val = serde_json::value::to_raw_value(&state)
.map_err(WorkflowError::SerializeLoopOutput)
.map_err(GlobalError::raw)?;

(0, None)
};
// Insert event before loop is run so the history is consistent
self.db
.upsert_workflow_loop_event(
self.workflow_id,
&loop_location,
self.version,
0,
&state_val,
None,
self.loop_location(),
)
.await?;

(0, state, None)
};

// Create a branch but no branch event (loop event takes its place)
let mut loop_branch =
Expand Down Expand Up @@ -841,16 +827,21 @@ impl WorkflowCtx {
}

// Run loop
match cb(&mut iteration_branch).await? {
match cb(&mut iteration_branch, &mut state).await? {
Loop::Continue => {
iteration += 1;

let state_val = serde_json::value::to_raw_value(&state)
.map_err(WorkflowError::SerializeLoopOutput)
.map_err(GlobalError::raw)?;

self.db
.upsert_workflow_loop_event(
self.workflow_id,
&loop_location,
self.version,
iteration,
&state_val,
None,
self.loop_location(),
)
Expand All @@ -865,6 +856,9 @@ impl WorkflowCtx {
Loop::Break(res) => {
iteration += 1;

let state_val = serde_json::value::to_raw_value(&state)
.map_err(WorkflowError::SerializeLoopOutput)
.map_err(GlobalError::raw)?;
let output_val = serde_json::value::to_raw_value(&res)
.map_err(WorkflowError::SerializeLoopOutput)
.map_err(GlobalError::raw)?;
Expand All @@ -875,6 +869,7 @@ impl WorkflowCtx {
&loop_location,
self.version,
iteration,
&state_val,
Some(&output_val),
self.loop_location(),
)
Expand Down
20 changes: 18 additions & 2 deletions packages/common/chirp-workflow/core/src/db/crdb_nats.rs
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,7 @@ impl Database for DatabaseCrdbNats {
activity_name AS name,
NULL AS auxiliary_id,
input_hash AS hash,
NULL AS input,
output AS output,
create_ts AS create_ts,
(
Expand Down Expand Up @@ -450,6 +451,7 @@ impl Database for DatabaseCrdbNats {
signal_name AS name,
NULL AS auxiliary_id,
NULL AS hash,
NULL AS input,
body AS output,
NULL AS create_ts,
NULL AS error_count,
Expand All @@ -470,6 +472,7 @@ impl Database for DatabaseCrdbNats {
signal_name AS name,
signal_id AS auxiliary_id,
NULL AS hash,
NULL AS input,
NULL AS output,
NULL AS create_ts,
NULL AS error_count,
Expand All @@ -490,6 +493,7 @@ impl Database for DatabaseCrdbNats {
message_name AS name,
NULL AS auxiliary_id,
NULL AS hash,
NULL AS input,
NULL AS output,
NULL AS create_ts,
NULL AS error_count,
Expand All @@ -510,6 +514,7 @@ impl Database for DatabaseCrdbNats {
w.workflow_name AS name,
sw.sub_workflow_id AS auxiliary_id,
NULL AS hash,
NULL AS input,
NULL AS output,
NULL AS create_ts,
NULL AS error_count,
Expand All @@ -532,6 +537,7 @@ impl Database for DatabaseCrdbNats {
NULL AS name,
NULL AS auxiliary_id,
NULL AS hash,
state AS input,
output,
NULL AS create_ts,
NULL AS error_count,
Expand All @@ -552,6 +558,7 @@ impl Database for DatabaseCrdbNats {
NULL AS name,
NULL AS auxiliary_id,
NULL AS hash,
NULL AS input,
NULL AS output,
NULL AS create_ts,
NULL AS error_count,
Expand All @@ -572,6 +579,7 @@ impl Database for DatabaseCrdbNats {
NULL AS name,
NULL AS auxiliary_id,
NULL AS hash,
NULL AS input,
NULL AS output,
NULL AS create_ts,
NULL AS error_count,
Expand All @@ -592,6 +600,7 @@ impl Database for DatabaseCrdbNats {
event_name AS name,
NULL AS auxiliary_id,
NULL AS hash,
NULL AS input,
NULL AS output,
NULL AS create_ts,
NULL AS error_count,
Expand All @@ -612,6 +621,7 @@ impl Database for DatabaseCrdbNats {
NULL AS name,
NULL AS auxiliary_id,
NULL AS hash,
NULL AS input,
NULL AS output,
NULL AS create_ts,
NULL AS error_count,
Expand Down Expand Up @@ -1250,6 +1260,7 @@ impl Database for DatabaseCrdbNats {
location: &Location,
version: usize,
iteration: usize,
state: &serde_json::value::RawValue,
output: Option<&serde_json::value::RawValue>,
loop_location: Option<&Location>,
) -> WorkflowResult<()> {
Expand All @@ -1265,20 +1276,23 @@ impl Database for DatabaseCrdbNats {
location2,
version,
iteration,
state,
output,
loop_location2
)
VALUES ($1, $2, $3, $4, $5, $6)
VALUES ($1, $2, $3, $4, $5, $6, $7)
ON CONFLICT (workflow_id, location2_hash) DO UPDATE
SET
iteration = $4,
output = $5
state = $5,
output = $6
RETURNING 1
",
workflow_id,
location,
version as i64,
iteration as i64,
sqlx::types::Json(state),
output.map(sqlx::types::Json),
loop_location,
)
Expand Down Expand Up @@ -1626,6 +1640,7 @@ mod types {
name: Option<String>,
auxiliary_id: Option<Uuid>,
hash: Option<Vec<u8>>,
input: Option<RawJson>,
output: Option<RawJson>,
create_ts: Option<i64>,
error_count: Option<i64>,
Expand Down Expand Up @@ -1751,6 +1766,7 @@ mod types {

fn try_from(value: AmalgamEventRow) -> WorkflowResult<Self> {
Ok(LoopEvent {
state: value.input.ok_or(WorkflowError::MissingEventData)?.0,
output: value.output.map(|x| x.0),
iteration: value
.iteration
Expand Down
1 change: 1 addition & 0 deletions packages/common/chirp-workflow/core/src/db/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ pub trait Database: Send {
location: &Location,
version: usize,
iteration: usize,
state: &serde_json::value::RawValue,
output: Option<&serde_json::value::RawValue>,
loop_location: Option<&Location>,
) -> WorkflowResult<()>;
Expand Down
6 changes: 6 additions & 0 deletions packages/common/chirp-workflow/core/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,12 @@ pub enum WorkflowError {
#[error("tags must be a json object")]
InvalidTags,

#[error("failed to serialize loop state: {0}")]
SerializeLoopState(serde_json::Error),

#[error("failed to deserialize loop state: {0}")]
DeserializeLoopState(serde_json::Error),

#[error("failed to serialize loop output: {0}")]
SerializeLoopOutput(serde_json::Error),

Expand Down
5 changes: 5 additions & 0 deletions packages/common/chirp-workflow/core/src/history/event.rs
Original file line number Diff line number Diff line change
Expand Up @@ -163,12 +163,17 @@ pub struct SubWorkflowEvent {

#[derive(Debug)]
pub struct LoopEvent {
pub(crate) state: Box<serde_json::value::RawValue>,
/// If the loop completes, this will be some.
pub(crate) output: Option<Box<serde_json::value::RawValue>>,
pub iteration: usize,
}

impl LoopEvent {
pub fn parse_state<S: DeserializeOwned>(&self) -> WorkflowResult<S> {
serde_json::from_str(self.state.get()).map_err(WorkflowError::DeserializeLoopState)
}

pub fn parse_output<O: DeserializeOwned>(&self) -> WorkflowResult<Option<O>> {
self.output
.as_ref()
Expand Down
4 changes: 2 additions & 2 deletions packages/infra/server/src/util/wf/history.rs
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ pub struct SubWorkflowEvent {
pub name: String,
pub tags: serde_json::Value,
pub input: serde_json::Value,
pub output: Option<serde_json::Value>,
// pub output: Option<serde_json::Value>,
}

#[derive(Debug)]
Expand Down Expand Up @@ -296,7 +296,7 @@ impl TryFrom<AmalgamEventRow> for SubWorkflowEvent {
name: value.name.context("missing event data")?,
tags: value.tags.context("missing event data")?,
input: value.input.context("missing event data")?,
output: value.output,
// output: value.output,
})
}
}
Expand Down
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
ALTER TABLE workflow_loop_events
ADD COLUMN state JSONB NOT NULL DEFAULT 'null';

0 comments on commit 9412633

Please sign in to comment.