Prevent different workflows starting with the same ID #185

Merged
merged 5 commits into from Jan 23, 2025
Changes from 4 commits
13 changes: 11 additions & 2 deletions dbos/_context.py
@@ -57,6 +57,7 @@ def __init__(self) -> None:
self.request: Optional["Request"] = None

self.id_assigned_for_next_workflow: str = ""
self.is_within_set_workflow_id_block: bool = False

self.parent_workflow_id: str = ""
self.parent_workflow_fid: int = -1
@@ -78,6 +79,7 @@ def create_child(self) -> DBOSContext:
rv.logger = self.logger
rv.id_assigned_for_next_workflow = self.id_assigned_for_next_workflow
self.id_assigned_for_next_workflow = ""
rv.is_within_set_workflow_id_block = self.is_within_set_workflow_id_block
rv.parent_workflow_id = self.workflow_id
rv.parent_workflow_fid = self.function_id
rv.in_recovery = self.in_recovery
@@ -95,6 +97,10 @@ def assign_workflow_id(self) -> str:
if len(self.id_assigned_for_next_workflow) > 0:
wfid = self.id_assigned_for_next_workflow
else:
if self.is_within_set_workflow_id_block:
self.logger.warning(
f"Multiple workflow started within the SetWorkflowID block. Only the first workflow will be assigned with the specified workflow ID; subsequent workflows will use a generated workflow ID."
)
wfid = str(uuid.uuid4())
return wfid

@@ -286,7 +292,7 @@ def __exit__(

class SetWorkflowID:
"""
Set the workflow ID to be used for the enclosed workflow invocation.
Set the workflow ID to be used for the enclosed workflow invocation. Note: within a `with SetWorkflowID` block, only the first workflow is started with the specified workflow ID.

Typical Usage
```
@@ -311,7 +317,9 @@ def __enter__(self) -> SetWorkflowID:
if ctx is None:
self.created_ctx = True
_set_local_dbos_context(DBOSContext())
assert_current_dbos_context().id_assigned_for_next_workflow = self.wfid
ctx = assert_current_dbos_context()
ctx.id_assigned_for_next_workflow = self.wfid
ctx.is_within_set_workflow_id_block = True
return self

def __exit__(
@@ -321,6 +329,7 @@ def __exit__(
traceback: Optional[TracebackType],
) -> Literal[False]:
# Code to clean up the basic context if we created it
assert_current_dbos_context().is_within_set_workflow_id_block = False
if self.created_ctx:
_clear_local_dbos_context()
return False # Did not handle
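A minimal sketch of the behavior this change enforces, assuming a configured `DBOS` instance `dbos` and a hypothetical workflow `my_workflow` (not from this PR): only the first workflow started inside a `SetWorkflowID` block receives the caller-supplied ID.

```python
from dbos import DBOS, SetWorkflowID

@DBOS.workflow()
def my_workflow(x: int) -> int:
    return x * 2

with SetWorkflowID("my-unique-id"):
    # First invocation: assigned the ID "my-unique-id".
    first = dbos.start_workflow(my_workflow, 1)
    # Second invocation: triggers the new warning and falls back to a
    # generated UUID instead of reusing "my-unique-id".
    second = dbos.start_workflow(my_workflow, 2)
```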
6 changes: 5 additions & 1 deletion dbos/_core.py
@@ -188,6 +188,7 @@ def _init_workflow(
wf_status = dbos._sys_db.update_workflow_status(
status, False, ctx.in_recovery, max_recovery_attempts=max_recovery_attempts
)
# TODO: Modify the inputs if they were changed by `update_workflow_inputs`
dbos._sys_db.update_workflow_inputs(wfid, _serialization.serialize_args(inputs))
else:
# Buffer the inputs for single-transaction workflows, but don't buffer the status
@@ -422,6 +423,9 @@ def start_workflow(
or wf_status == WorkflowStatusString.ERROR.value
or wf_status == WorkflowStatusString.SUCCESS.value
):
dbos.logger.debug(
f"Workflow {new_wf_id} already completed with status {wf_status}. Directly returning a workflow handle."
)
return WorkflowHandlePolling(new_wf_id, dbos)

if fself is not None:
@@ -494,7 +498,7 @@ def init_wf() -> Callable[[Callable[[], R]], R]:
temp_wf_type=get_temp_workflow_type(func),
max_recovery_attempts=max_recovery_attempts,
)

# TODO: maybe modify the parameters if they've been changed by `_init_workflow`
dbos.logger.debug(
f"Running workflow, id: {ctx.workflow_id}, name: {get_dbos_func_name(func)}"
)
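Continuing the sketch above (same hypothetical `my_workflow` and `dbos` instance): with the new early return, re-starting a workflow whose ID already finished hands back a handle immediately instead of replaying the function.

```python
with SetWorkflowID("finished-id"):
    handle = dbos.start_workflow(my_workflow, 21)
result = handle.get_result()  # runs to completion

# Starting again with the same ID: start_workflow sees the SUCCESS
# status, logs the "already completed" debug message, and returns a
# polling handle without re-executing the workflow body.
with SetWorkflowID("finished-id"):
    handle2 = dbos.start_workflow(my_workflow, 21)
assert handle2.get_result() == result
```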
11 changes: 11 additions & 0 deletions dbos/_error.py
@@ -35,6 +35,7 @@ class DBOSErrorCode(Enum):
DeadLetterQueueError = 6
MaxStepRetriesExceeded = 7
NotAuthorized = 8
ConflictingWorkflowError = 9


class DBOSWorkflowConflictIDError(DBOSException):
@@ -47,6 +48,16 @@ def __init__(self, workflow_id: str):
)


class DBOSConflictingWorkflowError(DBOSException):
"""Exception raised different workflows started with the same workflow ID."""

def __init__(self, workflow_id: str, message: Optional[str] = None):
super().__init__(
f"Conflicting workflow invocation with the same ID ({workflow_id}): {message}",
dbos_error_code=DBOSErrorCode.ConflictingWorkflowError.value,
)


class DBOSRecoveryError(DBOSException):
"""Exception raised when a workflow recovery fails."""

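A sketch of how the new error might surface, reusing the hypothetical names from above; exactly where the exception is raised depends on the code path into `update_workflow_status`.

```python
from dbos._error import DBOSConflictingWorkflowError

@DBOS.workflow()
def another_workflow(x: int) -> int:
    return x + 1

try:
    # Reusing "my-unique-id" for a *different* function conflicts with
    # the function name recorded for the first workflow.
    with SetWorkflowID("my-unique-id"):
        dbos.start_workflow(another_workflow, 1)
except DBOSConflictingWorkflowError as e:
    print(e)  # Conflicting workflow invocation with the same ID (my-unique-id): ...
```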
18 changes: 17 additions & 1 deletion dbos/_kafka.py
@@ -1,3 +1,4 @@
import re
import threading
from typing import TYPE_CHECKING, Any, Callable, NoReturn

@@ -19,6 +20,14 @@
_in_order_kafka_queues: dict[str, Queue] = {}


def safe_group_name(method_name: str, topics: list[str]) -> str:
safe_group_id = "-".join(
re.sub(r"[^a-zA-Z0-9\-]", "", str(r)) for r in [method_name, *topics]
)

return f"dbos-kafka-group-{safe_group_id}"[:255]


def _kafka_consumer_loop(
func: _KafkaConsumerWorkflow,
config: dict[str, Any],
@@ -34,6 +43,12 @@ def on_error(err: KafkaError) -> NoReturn:
if "auto.offset.reset" not in config:
config["auto.offset.reset"] = "earliest"

if config.get("group.id") is None:
config["group.id"] = safe_group_name(func.__qualname__, topics)
dbos_logger.debug(
f"Consumer group ID not found. Using generated group.id {config['group.id']}"
)

consumer = Consumer(config)
try:
consumer.subscribe(topics)
@@ -71,8 +86,9 @@ def on_error(err: KafkaError) -> NoReturn:
topic=cmsg.topic(),
value=cmsg.value(),
)
groupID = config.get("group.id")
with SetWorkflowID(
f"kafka-unique-id-{msg.topic}-{msg.partition}-{msg.offset}"
f"kafka-unique-id-{msg.topic}-{msg.partition}-{groupID}-{msg.offset}"
):
if in_order:
assert msg.topic is not None
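`safe_group_name` strips every character outside `[a-zA-Z0-9-]` and caps the result at 255 characters, Kafka's limit for group IDs. A quick illustration, assuming the function is importable from `dbos._kafka`:

```python
from dbos._kafka import safe_group_name

name = safe_group_name("MyApp.my_consumer", ["orders.created", "orders.updated"])
print(name)
# dbos-kafka-group-MyAppmyconsumer-orderscreated-ordersupdated
```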
100 changes: 66 additions & 34 deletions dbos/_sys_db.py
@@ -28,6 +28,7 @@
from . import _serialization
from ._dbos_config import ConfigFile
from ._error import (
DBOSConflictingWorkflowError,
DBOSDeadLetterQueueError,
DBOSException,
DBOSNonExistentWorkflowError,
@@ -288,46 +289,68 @@ def update_workflow_status(
),
)
else:
cmd = cmd.on_conflict_do_nothing()
cmd = cmd.returning(SystemSchema.workflow_status.c.recovery_attempts, SystemSchema.workflow_status.c.status) # type: ignore
# A blank update so that we can return the existing status
cmd = cmd.on_conflict_do_update(
index_elements=["workflow_uuid"],
set_=dict(
recovery_attempts=SystemSchema.workflow_status.c.recovery_attempts
),
)
cmd = cmd.returning(SystemSchema.workflow_status.c.recovery_attempts, SystemSchema.workflow_status.c.status, SystemSchema.workflow_status.c.name, SystemSchema.workflow_status.c.class_name, SystemSchema.workflow_status.c.config_name, SystemSchema.workflow_status.c.queue_name) # type: ignore

if conn is not None:
results = conn.execute(cmd)
else:
with self.engine.begin() as c:
results = c.execute(cmd)

if in_recovery:
row = results.fetchone()
if row is not None:
recovery_attempts: int = row[0]
wf_status = row[1]
if recovery_attempts > max_recovery_attempts:
with self.engine.begin() as c:
c.execute(
sa.delete(SystemSchema.workflow_queue).where(
SystemSchema.workflow_queue.c.workflow_uuid
== status["workflow_uuid"]
)
row = results.fetchone()
if row is not None:
# Check the started workflow matches the expected name, class_name, config_name, and queue_name
# A mismatch indicates a different workflow being started with the same UUID, in which case we raise an exception.
recovery_attempts: int = row[0]
wf_status = row[1]
err_msg: Optional[str] = None
if row[2] != status["name"]:
err_msg = f"Workflow already exists with a different function name: {row[2]}, but the provided function name is: {status['name']}"
elif row[3] != status["class_name"]:
err_msg = f"Workflow already exists with a different class name: {row[3]}, but the provided class name is: {status['class_name']}"
elif row[4] != status["config_name"]:
err_msg = f"Workflow already exists with a different config name: {row[4]}, but the provided config name is: {status['config_name']}"
elif row[5] != status["queue_name"]:
# This is a warning because a different queue name is not necessarily an error.
dbos_logger.warning(
f"Workflow already exists in queue: {row[5]}, but the provided queue name is: {status['queue_name']}. The queue is not updated."
)
if err_msg is not None:
raise DBOSConflictingWorkflowError(status["workflow_uuid"], err_msg)

if in_recovery and recovery_attempts > max_recovery_attempts:
with self.engine.begin() as c:
c.execute(
sa.delete(SystemSchema.workflow_queue).where(
SystemSchema.workflow_queue.c.workflow_uuid
== status["workflow_uuid"]
)
c.execute(
sa.update(SystemSchema.workflow_status)
.where(
SystemSchema.workflow_status.c.workflow_uuid
== status["workflow_uuid"]
)
.where(
SystemSchema.workflow_status.c.status
== WorkflowStatusString.PENDING.value
)
.values(
status=WorkflowStatusString.RETRIES_EXCEEDED.value,
queue_name=None,
)
)
c.execute(
sa.update(SystemSchema.workflow_status)
.where(
SystemSchema.workflow_status.c.workflow_uuid
== status["workflow_uuid"]
)
.where(
SystemSchema.workflow_status.c.status
== WorkflowStatusString.PENDING.value
)
.values(
status=WorkflowStatusString.RETRIES_EXCEEDED.value,
queue_name=None,
)
raise DBOSDeadLetterQueueError(
status["workflow_uuid"], max_recovery_attempts
)
raise DBOSDeadLetterQueueError(
status["workflow_uuid"], max_recovery_attempts
)

# Record we have exported status for this single-transaction workflow
if status["workflow_uuid"] in self._temp_txn_wf_ids:
@@ -538,18 +561,27 @@ def update_workflow_inputs(
workflow_uuid=workflow_uuid,
inputs=inputs,
)
.on_conflict_do_nothing()
.on_conflict_do_update(
index_elements=["workflow_uuid"],
set_=dict(workflow_uuid=SystemSchema.workflow_inputs.c.workflow_uuid),
)
.returning(SystemSchema.workflow_inputs.c.inputs)
)
if conn is not None:
conn.execute(cmd)
row = conn.execute(cmd).fetchone()
else:
with self.engine.begin() as c:
c.execute(cmd)

row = c.execute(cmd).fetchone()
if row is not None and row[0] != inputs:
dbos_logger.warning(
f"Workflow inputs for {workflow_uuid} changed since the first call! Use the original inputs."
)
# TODO: actually change the inputs
if workflow_uuid in self._temp_txn_wf_ids:
# Clean up the single-transaction tracking sets
self._exported_temp_txn_wf_status.discard(workflow_uuid)
self._temp_txn_wf_ids.discard(workflow_uuid)
return

def get_workflow_inputs(
self, workflow_uuid: str
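The switch from `on_conflict_do_nothing()` to a no-op `on_conflict_do_update` matters because, in PostgreSQL, `INSERT ... ON CONFLICT DO NOTHING` with `RETURNING` yields no row when the conflict fires, while a self-assigning `DO UPDATE` makes `RETURNING` hand back the existing row. A standalone sketch of the pattern with a hypothetical table and connection URL (not the actual `SystemSchema`):

```python
import sqlalchemy as sa
from sqlalchemy.dialects.postgresql import insert

metadata = sa.MetaData()
# Hypothetical stand-in for SystemSchema.workflow_status.
workflow_status = sa.Table(
    "workflow_status",
    metadata,
    sa.Column("workflow_uuid", sa.Text, primary_key=True),
    sa.Column("status", sa.Text),
    sa.Column("recovery_attempts", sa.Integer),
)

cmd = (
    insert(workflow_status)
    .values(workflow_uuid="wf-1", status="PENDING", recovery_attempts=1)
    # No-op assignment: on conflict, set the column to its current value
    # so RETURNING reports the pre-existing row instead of nothing.
    .on_conflict_do_update(
        index_elements=["workflow_uuid"],
        set_=dict(recovery_attempts=workflow_status.c.recovery_attempts),
    )
    .returning(workflow_status.c.status, workflow_status.c.recovery_attempts)
)

engine = sa.create_engine("postgresql://localhost/dbos_example")
with engine.begin() as conn:
    row = conn.execute(cmd).fetchone()  # the existing row if the ID was taken
```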
9 changes: 6 additions & 3 deletions tests/test_dbos.py
@@ -223,11 +223,12 @@ def exception_workflow() -> None:
assert bad_txn_counter == 2 # Only increment once

# Test we can execute the workflow by uuid, shouldn't throw errors
dbos._sys_db._flush_workflow_status_buffer()
handle = DBOS.execute_workflow_id(wfuuid)
with pytest.raises(Exception) as exc_info:
handle.get_result()
assert "test error" == str(exc_info.value)
assert wf_counter == 4
assert wf_counter == 3  # The stored error is returned directly without re-running the workflow


def test_temp_workflow(dbos: DBOS) -> None:
@@ -1185,6 +1186,8 @@ def test_workflow_dest() -> str:
assert "Running recv" in caplog.text
caplog.clear()

dbos._sys_db._flush_workflow_status_buffer()

# Second run
with SetWorkflowID(dest_wfid):
dest_handle_2 = dbos.start_workflow(test_workflow_dest)
@@ -1204,8 +1207,8 @@

result4 = dest_handle_2.get_result()
assert result4 == result2
assert "Replaying get_event" in caplog.text
assert "Replaying recv" in caplog.text
# In start_workflow, we skip the replay of already finished workflows
assert f"Workflow {dest_wfid} already completed with status" in caplog.text

# Reset logging
logging.getLogger("dbos").propagate = original_propagate
38 changes: 37 additions & 1 deletion tests/test_kafka.py
@@ -117,6 +117,42 @@ def test_kafka_workflow(msg: KafkaMessage) -> None:
if kafka_count == 3:
event.set()

wait = event.wait(timeout=10)
wait = event.wait(timeout=15)
assert wait
assert kafka_count == 3
time.sleep(2) # Wait for things to clean up


def test_kafka_no_groupid(dbos: DBOS) -> None:
event = threading.Event()
kafka_count = 0
server = "localhost:9092"
topic1 = f"dbos-kafka-{random.randrange(1_000_000_000, 2_000_000_000)}"
topic2 = f"dbos-kafka-{random.randrange(2_000_000_000, 3_000_000_000)}"

if not send_test_messages(server, topic1):
pytest.skip("Kafka not available")

if not send_test_messages(server, topic2):
pytest.skip("Kafka not available")

@DBOS.kafka_consumer(
{
"bootstrap.servers": server,
"auto.offset.reset": "earliest",
},
[topic1, topic2],
)
@DBOS.workflow()
def test_kafka_workflow(msg: KafkaMessage) -> None:
nonlocal kafka_count
kafka_count += 1
assert b"test message key" in msg.key # type: ignore
assert b"test message value" in msg.value # type: ignore
print(msg)
if kafka_count == 6:
event.set()

wait = event.wait(timeout=10)
assert wait
assert kafka_count == 6