Skip to content

Commit

Permalink
Airflow Operator now supports extra links (armadaproject#329) (armada…
Browse files Browse the repository at this point in the history
…project#4134)

Co-authored-by: Martynas Asipauskas <[email protected]>
  • Loading branch information
masipauskas and Martynas Asipauskas authored Jan 13, 2025
1 parent 33f3a27 commit 3b0161b
Show file tree
Hide file tree
Showing 3 changed files with 133 additions and 7 deletions.
70 changes: 67 additions & 3 deletions docs/python_airflow_operator.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ This class provides integration with Airflow and Armada
## armada.operators.armada module


### _class_ armada.operators.armada.ArmadaOperator(name, channel_args, armada_queue, job_request, job_set_prefix='', lookout_url_template=None, poll_interval=30, container_logs=None, k8s_token_retriever=None, deferrable=False, job_acknowledgement_timeout=300, dry_run=False, reattach_policy=None, \*\*kwargs)
### _class_ armada.operators.armada.ArmadaOperator(name, channel_args, armada_queue, job_request, job_set_prefix='', lookout_url_template=None, poll_interval=30, container_logs=None, k8s_token_retriever=None, deferrable=False, job_acknowledgement_timeout=300, dry_run=False, reattach_policy=None, extra_links=None, \*\*kwargs)
Bases: `BaseOperator`, `LoggingMixin`

An Airflow operator that manages Job submission to Armada.
Expand Down Expand Up @@ -63,6 +63,9 @@ and handles job cancellation if the Airflow task is killed.
* **reattach_policy** (*Optional**[**str**] **| **Callable**[**[**JobState**, **str**]**, **bool**]*) –


* **extra_links** (*Optional**[**Dict**[**str**, **str**]**]*) –



#### execute(context)
Submits the job to Armada and polls for completion.
Expand Down Expand Up @@ -97,10 +100,33 @@ operator needs to be cleaned up, or it will leave ghost processes behind.



#### operator_extra_links(_: Collection[BaseOperatorLink]_ _ = (LookoutLink(),)_ )

#### _property_ pod_manager(_: KubernetesPodLogManager_ )

#### render_extra_links_urls(context, jinja_env=None)
Template all URLs listed in self.extra_links.
This pushes all URL values to xcom for values to be picked up by UI.

Args:

context (Context): The execution context provided by Airflow.


* **Parameters**


* **context** (*Context*) – Airflow Context dict with values to apply on content


* **jinja_env** (*Environment** | **None*) – jinja’s environment to use for rendering.



* **Return type**

None



#### render_template_fields(context, jinja_env=None)
Template all attributes listed in self.template_fields.
This mutates the attributes in-place and is irreversible.
Expand Down Expand Up @@ -173,7 +199,45 @@ acknowledged by Armada.
:param reattach_policy: Operator reattach policy to use (defaults to: never)
:type reattach_policy: Optional[str] | Callable[[JobState, str], bool]
:param extra_links: Extra links to be shown in addition to Lookout URL.
:type extra_links: Optional[Dict[str, str]]
:param kwargs: Additional keyword arguments to pass to the BaseOperator.


### _class_ armada.operators.armada.DynamicLink(name)
Bases: `BaseOperatorLink`, `LoggingMixin`


* **Parameters**

**name** (*str*) –



#### get_link(operator, \*, ti_key)
Link to external system.

Note: The old signature of this function was `(self, operator, dttm: datetime)`. That is still
supported at runtime but is deprecated.


* **Parameters**


* **operator** (*BaseOperator*) – The Airflow operator object this link is associated to.


* **ti_key** (*TaskInstanceKey*) – TaskInstance ID to return link for.



* **Returns**

link to external system



#### name(_: str_ )

### _class_ armada.operators.armada.LookoutLink()
Bases: `BaseOperatorLink`
Expand Down
5 changes: 4 additions & 1 deletion third_party/airflow/armada/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ def get_provider_info():
"package-name": "armada-airflow",
"name": "Armada Airflow Operator",
"description": "Armada Airflow Operator.",
"extra-links": ["armada.operators.armada.LookoutLink"],
"extra-links": [
"armada.operators.armada.LookoutLink",
"armada.operators.armada.DynamicLink",
],
"versions": ["1.0.0"],
}
65 changes: 62 additions & 3 deletions third_party/airflow/armada/operators/armada.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
# under the License.
from __future__ import annotations

import attrs
import dataclasses
import datetime
import os
Expand All @@ -25,6 +26,8 @@

import jinja2
import tenacity
import re

from airflow.configuration import conf
from airflow.exceptions import AirflowFailException
from airflow.models import BaseOperator, BaseOperatorLink, XCom
Expand Down Expand Up @@ -53,13 +56,22 @@ class LookoutLink(BaseOperatorLink):
name = "Lookout"

def get_link(self, operator: BaseOperator, *, ti_key: TaskInstanceKey):
task_state = XCom.get_value(ti_key=ti_key)
task_state = XCom.get_value(ti_key=ti_key, key="job_context")
if not task_state:
return ""

return task_state.get("armada_lookout_url", "")


@attrs.define(init=True)
class DynamicLink(BaseOperatorLink, LoggingMixin):
    """
    Operator extra link whose URL is read from XCom.

    The URL is looked up under the XCom key ``armada_<name>_url`` (with
    ``name`` lower-cased) for the task instance the link is requested for.

    :param name: Name of the link.
    :type name: str
    """

    name: str

    def get_link(self, operator: BaseOperator, *, ti_key: TaskInstanceKey):
        """
        Return the URL stored in XCom for this link.

        :param operator: The Airflow operator object this link is associated to.
        :param ti_key: TaskInstance ID to return link for.
        :return: The stored URL, or an empty string when no value is stored
            (the same fallback LookoutLink.get_link uses).
        """
        url = XCom.get_value(ti_key=ti_key, key=f"armada_{self.name.lower()}_url")
        # Always hand back a string, never None, when no value is stored.
        return url or ""


class ArmadaOperator(BaseOperator, LoggingMixin):
"""
An Airflow operator that manages Job submission to Armada.
Expand All @@ -68,8 +80,6 @@ class ArmadaOperator(BaseOperator, LoggingMixin):
and handles job cancellation if the Airflow task is killed.
"""

operator_extra_links = (LookoutLink(),)

template_fields: Sequence[str] = ("job_request", "job_set_prefix")
template_fields_renderers: Dict[str, str] = {"job_request": "py"}

Expand Down Expand Up @@ -108,6 +118,9 @@ class ArmadaOperator(BaseOperator, LoggingMixin):
:param reattach_policy: Operator reattach policy to use (defaults to: never)
:type reattach_policy: Optional[str] | Callable[[JobState, str], bool]
:param extra_links: Extra links to be shown in addition to Lookout URL.
:type extra_links: Optional[Dict[str, str]]
:param kwargs: Additional keyword arguments to pass to the BaseOperator.
"""

def __init__(
Expand Down Expand Up @@ -136,6 +149,7 @@ def __init__(
"armada_operator", "default_dry_run", fallback=False
),
reattach_policy: Optional[str] | Callable[[JobState, str], bool] = None,
extra_links: Optional[Dict[str, str]] = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
Expand Down Expand Up @@ -176,6 +190,15 @@ def __init__(
"logs from Kubernetes"
)

self.extra_links = extra_links or {}
operator_links = []

if self.lookout_url_template:
operator_links.append(LookoutLink())

operator_links.extend([DynamicLink(name) for name in self.extra_links])
self.operator_extra_links = operator_links

@log_exceptions
def execute(self, context) -> None:
"""
Expand Down Expand Up @@ -264,6 +287,9 @@ def render_template_fields(
super().render_template_fields(context, jinja_env)
self._xcom_push(context, key="job_request", value=self.job_request)

# We should render extra links here.
self.render_extra_links_urls(context, jinja_env)

self.job_request = ParseDict(self.job_request, JobSubmitRequestItem())
self._annotate_job_request(context, self.job_request)

Expand All @@ -277,6 +303,35 @@ def render_template_fields(
)
self.container_logs = None

def render_extra_links_urls(
    self, context: Context, jinja_env: Optional[jinja2.Environment] = None
) -> None:
    """
    Template all URLs listed in self.extra_links.
    This pushes all URL values to xcom for values to be picked up by UI.

    A link whose URL is a compiled regex is skipped with a warning, and a
    link whose template fails to render is logged as an error; neither
    stops the remaining links from being processed.

    :param context: Airflow Context dict with values to apply on content
    :param jinja_env: jinja’s environment to use for rendering. A default
        ``jinja2.Environment`` is created when omitted.
    """
    if jinja_env is None:
        jinja_env = jinja2.Environment()

    for name, url in self.extra_links.items():
        if isinstance(url, re.Pattern):
            self.log.warning(
                f"Skipping link {name} because the URL appears to be a regex: {url}"
            )
            continue
        # Keep the try body limited to the rendering step, which is the only
        # part expected to raise a template error.
        try:
            rendered_url = jinja_env.from_string(url).render(context)
        except jinja2.TemplateError as e:
            self.log.error(f"Error rendering template for {name} ({url}): {e}")
            continue
        # The XCom key must match the one DynamicLink.get_link reads.
        self._xcom_push(
            context, key=f"armada_{name.lower()}_url", value=rendered_url
        )

def on_kill(self) -> None:
if self.job_context is not None:
self.log.info(
Expand Down Expand Up @@ -471,6 +526,10 @@ def _annotate_job_request(self, context, request: JobSubmitRequestItem):
request.annotations[annotation_key_prefix + "taskId"] = task_id
request.annotations[annotation_key_prefix + "taskRunId"] = run_id
request.annotations[annotation_key_prefix + "dagId"] = dag_id
request.annotations[annotation_key_prefix + "jobSet"] = (
f"{self.job_set_prefix}{run_id}"
)

request.annotations[annotation_key_prefix + "externalJobUri"] = (
external_job_uri(context)
)

0 comments on commit 3b0161b

Please sign in to comment.