From c7a2a4e2a3cb48c79aadab8a59c7fac0cf1f721b Mon Sep 17 00:00:00 2001 From: Martynas Asipauskas Date: Tue, 26 Nov 2024 19:34:44 +0000 Subject: [PATCH] Re-attach logic - final fixes (#4064) --- docs/python_airflow_operator.md | 7 +++- .../airflow/armada/operators/armada.py | 36 ++++++++++++------- third_party/airflow/pyproject.toml | 4 +-- 3 files changed, 31 insertions(+), 16 deletions(-) diff --git a/docs/python_airflow_operator.md b/docs/python_airflow_operator.md index 486d47cde1c..8037c0e8ace 100644 --- a/docs/python_airflow_operator.md +++ b/docs/python_airflow_operator.md @@ -12,7 +12,7 @@ This class provides integration with Airflow and Armada ## armada.operators.armada module -### _class_ armada.operators.armada.ArmadaOperator(name, channel_args, armada_queue, job_request, job_set_prefix='', lookout_url_template=None, poll_interval=30, container_logs=None, k8s_token_retriever=None, deferrable=False, job_acknowledgement_timeout=300, dry_run=False, \*\*kwargs) +### _class_ armada.operators.armada.ArmadaOperator(name, channel_args, armada_queue, job_request, job_set_prefix='', lookout_url_template=None, poll_interval=30, container_logs=None, k8s_token_retriever=None, deferrable=False, job_acknowledgement_timeout=300, dry_run=False, reattach_policy=None, \*\*kwargs) Bases: `BaseOperator`, `LoggingMixin` An Airflow operator that manages Job submission to Armada. @@ -60,6 +60,9 @@ and handles job cancellation if the Airflow task is killed. * **dry_run** (*bool*) – + * **reattach_policy** (*Optional**[**str**] **| **Callable**[**[**JobState**, **str**]**, **bool**]*) – + + #### execute(context) Submits the job to Armada and polls for completion. @@ -167,6 +170,8 @@ acknowledged by Armada. :type job_acknowledgement_timeout: int :param dry_run: Run Operator in dry-run mode - render Armada request and terminate. :type dry_run: bool +:param reattach_policy: Operator reattach policy to use (defaults to: never) +:type reattach_policy: Optional[str] | Callable[[JobState, str], bool] :param kwargs: Additional keyword arguments to pass to the BaseOperator. diff --git a/third_party/airflow/armada/operators/armada.py b/third_party/airflow/armada/operators/armada.py index 158a80cc479..0c683928fb1 100644 --- a/third_party/airflow/armada/operators/armada.py +++ b/third_party/airflow/armada/operators/armada.py @@ -105,8 +105,8 @@ class ArmadaOperator(BaseOperator, LoggingMixin): :type job_acknowledgement_timeout: int :param dry_run: Run Operator in dry-run mode - render Armada request and terminate. :type dry_run: bool -:param reattach_policy: Operator reattach policy to use (defaults to: always) -:type reattach_policy: Optional[str] +:param reattach_policy: Operator reattach policy to use (defaults to: never) +:type reattach_policy: Optional[str] | Callable[[JobState, str], bool] :param kwargs: Additional keyword arguments to pass to the BaseOperator. """ @@ -135,7 +135,7 @@ def __init__( dry_run: bool = conf.getboolean( "armada_operator", "default_dry_run", fallback=False ), - reattach_policy: Optional[str] = None, + reattach_policy: Optional[str] | Callable[[JobState, str], bool] = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -154,14 +154,21 @@ def __init__( self.dry_run = dry_run self.job_context = None - configured_reattach_policy: str = resolve_parameter_value( - "reattach_policy", reattach_policy, kwargs, "never" - ) - self.log.info( - f"Configured reattach policy to: '{configured_reattach_policy}'," - f" max retries: {self.retries}" - ) - self.reattach_policy = policy(configured_reattach_policy) + if reattach_policy is callable(reattach_policy): + self.log.info( + f"Configured reattach policy with callable'," + f" max retries: {self.retries}" + ) + self.reattach_policy = reattach_policy + else: + configured_reattach_policy: str = resolve_parameter_value( + "reattach_policy", reattach_policy, kwargs, "never" + ) + self.log.info( + f"Configured reattach policy to: '{configured_reattach_policy}'," + f" max retries: {self.retries}" + ) + self.reattach_policy = policy(configured_reattach_policy) if self.container_logs and self.k8s_token_retriever is None: self.log.warning( @@ -342,8 +349,11 @@ def _try_reattach_to_running_job( self, context: Context ) -> Optional[RunningJobContext]: # On first try we intentionally do not re-attach. - self.log.info(context) - if context["ti"].try_number == 1: + new_run = ( + context["ti"].max_tries - context["ti"].try_number + 1 + == context["ti"].task.retries + ) + if new_run: return None expected_job_uri = external_job_uri(context) diff --git a/third_party/airflow/pyproject.toml b/third_party/airflow/pyproject.toml index 67c7b91f678..3e278fade54 100644 --- a/third_party/airflow/pyproject.toml +++ b/third_party/airflow/pyproject.toml @@ -4,13 +4,13 @@ build-backend = "setuptools.build_meta" [project] name = "armada_airflow" -version = "1.0.10" +version = "1.0.11" description = "Armada Airflow Operator" readme='README.md' authors = [{name = "Armada-GROSS", email = "armada@armadaproject.io"}] license = { text = "Apache Software License" } dependencies=[ - 'armada-client>=0.4.7', + 'armada-client>=0.4.8', 'apache-airflow>=2.6.3', 'types-protobuf==4.24.0.1', 'kubernetes>=23.6.0',