Skip to content

Commit

Permalink
feat(engine): Disable Temporal search attributes if in cloud
Browse files Browse the repository at this point in the history
  • Loading branch information
daryllimyt committed Jan 21, 2025
1 parent 927660e commit a07f599
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 15 deletions.
10 changes: 7 additions & 3 deletions alembic/versions/db3c91261770_add_temporal_search_attributes.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,11 @@
wait_exponential,
)

from tracecat.config import TEMPORAL__API_KEY__ARN, TEMPORAL__CLUSTER_NAMESPACE
from tracecat.config import (
TEMPORAL__API_KEY,
TEMPORAL__API_KEY__ARN,
TEMPORAL__CLUSTER_NAMESPACE,
)
from tracecat.dsl.client import get_temporal_client
from tracecat.logger import logger

Expand Down Expand Up @@ -72,7 +76,7 @@ async def remove_temporal_search_attributes():

def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
if TEMPORAL__API_KEY__ARN or os.environ.get("TEMPORAL__API_KEY"):
if TEMPORAL__API_KEY__ARN or TEMPORAL__API_KEY:
logger.info("Using Temporal cloud, skipping upgrade (add search attributes)")
else:
asyncio.run(add_temporal_search_attributes())
Expand All @@ -81,7 +85,7 @@ def upgrade() -> None:

def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
if TEMPORAL__API_KEY__ARN or os.environ.get("TEMPORAL__API_KEY"):
if TEMPORAL__API_KEY__ARN or TEMPORAL__API_KEY:
logger.info(
"Using Temporal cloud, skipping downgrade (remove search attributes)"
)
Expand Down
1 change: 1 addition & 0 deletions tracecat/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@
"TEMPORAL__CLUSTER_QUEUE", "tracecat-task-queue"
)
TEMPORAL__API_KEY__ARN = os.environ.get("TEMPORAL__API_KEY__ARN")
TEMPORAL__API_KEY = os.environ.get("TEMPORAL__API_KEY")
TEMPORAL__MTLS_ENABLED = os.environ.get("TEMPORAL__MTLS_ENABLED", "").lower() in (
"1",
"true",
Expand Down
26 changes: 16 additions & 10 deletions tracecat/workflow/executions/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,16 +525,22 @@ async def _dispatch_workflow(
run_config=dsl.config,
kwargs=kwargs,
)
search_attrs = [
trigger_type.to_temporal_search_attr_pair(),
]
if self.role.user_id is not None:
search_attrs.append(
SearchAttributePair(
key=SearchAttributeKey.for_keyword("TracecatTriggeredByUserId"),
value=str(self.role.user_id),

if config.TEMPORAL__API_KEY__ARN or config.TEMPORAL__API_KEY:
self.logger.warning("Using Temporal cloud, skipping search attributes")
search_attrs = None
else:
pairs = [
trigger_type.to_temporal_search_attr_pair(),
]
if self.role.user_id is not None:
pairs.append(
SearchAttributePair(
key=SearchAttributeKey.for_keyword("TracecatTriggeredByUserId"),
value=str(self.role.user_id),
)
)
)
search_attrs = TypedSearchAttributes(search_attributes=pairs)
try:
result = await self._client.execute_workflow(
DSLWorkflow.run,
Expand All @@ -546,7 +552,7 @@ async def _dispatch_workflow(
retry_policy=retry_policies["workflow:fail_fast"],
# We don't currently differentiate between exec and run timeout as we fail fast for workflows
execution_timeout=datetime.timedelta(seconds=dsl.config.timeout),
search_attributes=TypedSearchAttributes(search_attributes=search_attrs),
search_attributes=search_attrs,
**kwargs,
)
except WorkflowFailureError as e:
Expand Down
17 changes: 15 additions & 2 deletions tracecat/workflow/schedules/bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@
from tracecat.dsl.client import get_temporal_client
from tracecat.dsl.common import DSLRunArgs
from tracecat.identifiers import ScheduleID, WorkflowID
from tracecat.logger import logger
from tracecat.workflow.executions.enums import TriggerType
from tracecat.workflow.schedules.models import ScheduleUpdate

search_attrs = TypedSearchAttributes(
SEARCH_ATTRS = TypedSearchAttributes(
search_attributes=[TriggerType.SCHEDULED.to_temporal_search_attr_pair()]
)

Expand Down Expand Up @@ -44,6 +45,13 @@ async def create_schedule(

workflow_schedule_id = f"{workflow_id}:{schedule_id}"

if config.TEMPORAL__API_KEY__ARN or config.TEMPORAL__API_KEY:
logger.warning(
"Using Temporal cloud, skipping search attributes (add to schedule)"
)
search_attrs = TypedSearchAttributes.empty
else:
search_attrs = SEARCH_ATTRS
return await client.create_schedule(
id=schedule_id,
schedule=temporalio.client.Schedule(
Expand Down Expand Up @@ -97,7 +105,12 @@ async def _update_schedule(
if "status" in set_fields:
state.paused = set_fields["status"] != "online"
if isinstance(action, temporalio.client.ScheduleActionStartWorkflow):
action.typed_search_attributes = search_attrs
if config.TEMPORAL__API_KEY__ARN or config.TEMPORAL__API_KEY:
logger.warning(
"Using Temporal cloud, skipping search attributes (update schedule)"
)
else:
action.typed_search_attributes = SEARCH_ATTRS
if "inputs" in set_fields:
action.args[0].dsl.trigger_inputs = set_fields["inputs"] # type: ignore
else:
Expand Down

0 comments on commit a07f599

Please sign in to comment.