From a07f599906f843d05525df2caeebb4892efb3db2 Mon Sep 17 00:00:00 2001 From: Daryl Lim <5508348+daryllimyt@users.noreply.github.com> Date: Tue, 21 Jan 2025 01:26:49 +0000 Subject: [PATCH] feat(engine): Disable Temporal search attributes if in cloud --- ...91261770_add_temporal_search_attributes.py | 10 ++++--- tracecat/config.py | 1 + tracecat/workflow/executions/service.py | 26 ++++++++++++------- tracecat/workflow/schedules/bridge.py | 17 ++++++++++-- 4 files changed, 39 insertions(+), 15 deletions(-) diff --git a/alembic/versions/db3c91261770_add_temporal_search_attributes.py b/alembic/versions/db3c91261770_add_temporal_search_attributes.py index 76a8ccfd4..2cb152029 100644 --- a/alembic/versions/db3c91261770_add_temporal_search_attributes.py +++ b/alembic/versions/db3c91261770_add_temporal_search_attributes.py @@ -22,7 +22,11 @@ wait_exponential, ) -from tracecat.config import TEMPORAL__API_KEY__ARN, TEMPORAL__CLUSTER_NAMESPACE +from tracecat.config import ( + TEMPORAL__API_KEY, + TEMPORAL__API_KEY__ARN, + TEMPORAL__CLUSTER_NAMESPACE, +) from tracecat.dsl.client import get_temporal_client from tracecat.logger import logger @@ -72,7 +76,7 @@ async def remove_temporal_search_attributes(): def upgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### - if TEMPORAL__API_KEY__ARN or os.environ.get("TEMPORAL__API_KEY"): + if TEMPORAL__API_KEY__ARN or TEMPORAL__API_KEY: logger.info("Using Temporal cloud, skipping upgrade (add search attributes)") else: asyncio.run(add_temporal_search_attributes()) @@ -81,7 +85,7 @@ def upgrade() -> None: def downgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### - if TEMPORAL__API_KEY__ARN or os.environ.get("TEMPORAL__API_KEY"): + if TEMPORAL__API_KEY__ARN or TEMPORAL__API_KEY: logger.info( "Using Temporal cloud, skipping downgrade (remove search attributes)" ) diff --git a/tracecat/config.py b/tracecat/config.py index 09504503f..3a96788a8 100644 --- a/tracecat/config.py +++ b/tracecat/config.py @@ -128,6 +128,7 @@ "TEMPORAL__CLUSTER_QUEUE", "tracecat-task-queue" ) TEMPORAL__API_KEY__ARN = os.environ.get("TEMPORAL__API_KEY__ARN") +TEMPORAL__API_KEY = os.environ.get("TEMPORAL__API_KEY") TEMPORAL__MTLS_ENABLED = os.environ.get("TEMPORAL__MTLS_ENABLED", "").lower() in ( "1", "true", diff --git a/tracecat/workflow/executions/service.py b/tracecat/workflow/executions/service.py index 3906fd9cd..26b942e3e 100644 --- a/tracecat/workflow/executions/service.py +++ b/tracecat/workflow/executions/service.py @@ -525,16 +525,22 @@ async def _dispatch_workflow( run_config=dsl.config, kwargs=kwargs, ) - search_attrs = [ - trigger_type.to_temporal_search_attr_pair(), - ] - if self.role.user_id is not None: - search_attrs.append( - SearchAttributePair( - key=SearchAttributeKey.for_keyword("TracecatTriggeredByUserId"), - value=str(self.role.user_id), + + if config.TEMPORAL__API_KEY__ARN or config.TEMPORAL__API_KEY: + self.logger.warning("Using Temporal cloud, skipping search attributes") + search_attrs = None + else: + pairs = [ + trigger_type.to_temporal_search_attr_pair(), + ] + if self.role.user_id is not None: + pairs.append( + SearchAttributePair( + key=SearchAttributeKey.for_keyword("TracecatTriggeredByUserId"), + value=str(self.role.user_id), + ) ) - ) + search_attrs = TypedSearchAttributes(search_attributes=pairs) try: result = await self._client.execute_workflow( DSLWorkflow.run, @@ -546,7 +552,7 @@ async def _dispatch_workflow( retry_policy=retry_policies["workflow:fail_fast"], # We don't currently differentiate between exec and run timeout as we fail fast for workflows execution_timeout=datetime.timedelta(seconds=dsl.config.timeout), - search_attributes=TypedSearchAttributes(search_attributes=search_attrs), + search_attributes=search_attrs, **kwargs, ) except WorkflowFailureError as e: diff --git a/tracecat/workflow/schedules/bridge.py b/tracecat/workflow/schedules/bridge.py index 8c3df58ea..c7c809c3e 100644 --- a/tracecat/workflow/schedules/bridge.py +++ b/tracecat/workflow/schedules/bridge.py @@ -8,10 +8,11 @@ from tracecat.dsl.client import get_temporal_client from tracecat.dsl.common import DSLRunArgs from tracecat.identifiers import ScheduleID, WorkflowID +from tracecat.logger import logger from tracecat.workflow.executions.enums import TriggerType from tracecat.workflow.schedules.models import ScheduleUpdate -search_attrs = TypedSearchAttributes( +SEARCH_ATTRS = TypedSearchAttributes( search_attributes=[TriggerType.SCHEDULED.to_temporal_search_attr_pair()] ) @@ -44,6 +45,13 @@ async def create_schedule( workflow_schedule_id = f"{workflow_id}:{schedule_id}" + if config.TEMPORAL__API_KEY__ARN or config.TEMPORAL__API_KEY: + logger.warning( + "Using Temporal cloud, skipping search attributes (add to schedule)" + ) + search_attrs = TypedSearchAttributes.empty + else: + search_attrs = SEARCH_ATTRS return await client.create_schedule( id=schedule_id, schedule=temporalio.client.Schedule( @@ -97,7 +105,12 @@ async def _update_schedule( if "status" in set_fields: state.paused = set_fields["status"] != "online" if isinstance(action, temporalio.client.ScheduleActionStartWorkflow): - action.typed_search_attributes = search_attrs + if config.TEMPORAL__API_KEY__ARN or config.TEMPORAL__API_KEY: + logger.warning( + "Using Temporal cloud, skipping search attributes (update schedule)" + ) + else: + action.typed_search_attributes = SEARCH_ATTRS if "inputs" in set_fields: action.args[0].dsl.trigger_inputs = set_fields["inputs"] # type: ignore else: