Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AIP-72: Add support for get_current_context in Task SDK #45486

Merged
merged 1 commit into from
Jan 9, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@
from airflow.plugins_manager import integrate_macros_plugins
from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetNameRef, AssetUniqueKey, AssetUriRef
from airflow.sdk.definitions.templater import SandboxedEnvironment
from airflow.sdk.execution_time.context import _CURRENT_CONTEXT
from airflow.sentry import Sentry
from airflow.settings import task_instance_mutation_hook
from airflow.stats import Stats
Expand Down Expand Up @@ -142,7 +143,6 @@

TR = TaskReschedule

_CURRENT_CONTEXT: list[Context] = []
log = logging.getLogger(__name__)


Expand Down
25 changes: 17 additions & 8 deletions providers/src/airflow/providers/standard/operators/python.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,10 @@
)
from airflow.models.baseoperator import BaseOperator
from airflow.models.skipmixin import SkipMixin
from airflow.models.taskinstance import _CURRENT_CONTEXT
from airflow.models.variable import Variable
from airflow.operators.branch import BranchMixIn
from airflow.providers.standard.utils.python_virtualenv import prepare_virtualenv, write_python_script
from airflow.providers.standard.version_compat import (
AIRFLOW_V_2_10_PLUS,
AIRFLOW_V_3_0_PLUS,
)
from airflow.providers.standard.version_compat import AIRFLOW_V_2_10_PLUS, AIRFLOW_V_3_0_PLUS
from airflow.utils import hashlib_wrapper
from airflow.utils.context import context_copy_partial, context_merge
from airflow.utils.file import get_unique_dag_module_name
Expand Down Expand Up @@ -1122,7 +1118,7 @@ def execute(self, context: Context) -> Any:
return self.do_branch(context, super().execute(context))


def get_current_context() -> Context:
def get_current_context() -> Mapping[str, Any]:
"""
Retrieve the execution context dictionary without altering user method's signature.

Expand All @@ -1149,9 +1145,22 @@ def my_task():
Current context will only have value if this method was called after an operator
was starting to execute.
"""
if AIRFLOW_V_3_0_PLUS:
from airflow.sdk import get_current_context

return get_current_context()
else:
return _get_current_context()


def _get_current_context() -> Mapping[str, Any]:
# Airflow 2.x
# TODO: To be removed when Airflow 2 support is dropped
from airflow.models.taskinstance import _CURRENT_CONTEXT

if not _CURRENT_CONTEXT:
raise AirflowException(
raise RuntimeError(
"Current context was requested but no context was found! "
"Are you running within an airflow task?"
"Are you running within an Airflow task?"
)
return _CURRENT_CONTEXT[-1]
6 changes: 3 additions & 3 deletions providers/tests/standard/operators/test_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -1069,7 +1069,7 @@ def f():
with pytest.raises(
AirflowException,
match="Current context was requested but no context was found! "
"Are you running within an airflow task?",
"Are you running within an Airflow task?",
):
self.run_as_task(f, return_ti=True, use_airflow_context=False)

Expand Down Expand Up @@ -1890,7 +1890,7 @@ def default_kwargs(*, python_version=DEFAULT_PYTHON_VERSION, **kwargs):

class TestCurrentContext:
def test_current_context_no_context_raise(self):
with pytest.raises(AirflowException):
with pytest.raises(RuntimeError):
get_current_context()

def test_current_context_roundtrip(self):
Expand All @@ -1904,7 +1904,7 @@ def test_context_removed_after_exit(self):

with set_current_context(example_context):
pass
with pytest.raises(AirflowException):
with pytest.raises(RuntimeError):
get_current_context()

def test_nested_context(self):
Expand Down
3 changes: 3 additions & 0 deletions task_sdk/src/airflow/sdk/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
"TaskGroup",
"dag",
"Connection",
"get_current_context",
"__version__",
]

Expand All @@ -34,6 +35,7 @@
if TYPE_CHECKING:
from airflow.sdk.definitions.baseoperator import BaseOperator
from airflow.sdk.definitions.connection import Connection
from airflow.sdk.definitions.contextmanager import get_current_context
from airflow.sdk.definitions.dag import DAG, dag
from airflow.sdk.definitions.edges import EdgeModifier, Label
from airflow.sdk.definitions.taskgroup import TaskGroup
Expand All @@ -47,6 +49,7 @@
"Label": ".definitions.edges",
"Connection": ".definitions.connection",
"Variable": ".definitions.variable",
"get_current_context": ".definitions.contextmanager",
}


Expand Down
46 changes: 42 additions & 4 deletions task_sdk/src/airflow/sdk/definitions/contextmanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import sys
from collections import deque
from collections.abc import Mapping
from types import ModuleType
from typing import Any, Generic, TypeVar

Expand All @@ -27,10 +28,47 @@

T = TypeVar("T")

__all__ = [
"DagContext",
"TaskGroupContext",
]
__all__ = ["DagContext", "TaskGroupContext", "get_current_context"]

# This is a global variable that stores the current Task context.
# It is used to push the Context dictionary when Task starts execution
# and it is used to retrieve the current context in PythonOperator or Taskflow API via
# the `get_current_context` function.
_CURRENT_CONTEXT: list[Mapping[str, Any]] = []


def get_current_context() -> Mapping[str, Any]:
    """
    Retrieve the execution context dictionary without altering user method's signature.

    This is the simplest method of retrieving the execution context dictionary.

    **Old style:**

    .. code:: python

        def my_task(**context):
            ti = context["ti"]

    **New style:**

    .. code:: python

        from airflow.sdk import get_current_context


        def my_task():
            context = get_current_context()
            ti = context["ti"]

    A current context is only available once one has been pushed onto the
    module-level ``_CURRENT_CONTEXT`` stack, i.e. after a task has started executing.

    :return: The most recently pushed context (the last entry of ``_CURRENT_CONTEXT``).
    :raises RuntimeError: If ``_CURRENT_CONTEXT`` is empty.
    """
    if not _CURRENT_CONTEXT:
        raise RuntimeError(
            "Current context was requested but no context was found! Are you running within an Airflow task?"
        )
    # The last pushed entry is the current one.
    return _CURRENT_CONTEXT[-1]


# In order to add a `@classproperty`-like thing we need to define a property on a metaclass.
Expand Down
27 changes: 25 additions & 2 deletions task_sdk/src/airflow/sdk/execution_time/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,13 @@
# under the License.
from __future__ import annotations

import contextlib
from collections.abc import Generator, Mapping
from typing import TYPE_CHECKING, Any

import structlog

from airflow.sdk.definitions.contextmanager import _CURRENT_CONTEXT
from airflow.sdk.exceptions import AirflowRuntimeError, ErrorType
from airflow.sdk.types import NOTSET

Expand All @@ -28,6 +31,8 @@
from airflow.sdk.definitions.variable import Variable
from airflow.sdk.execution_time.comms import ConnectionResult, VariableResult

log = structlog.get_logger(logger_name="task")


def _convert_connection_result_conn(conn_result: ConnectionResult) -> Connection:
from airflow.sdk.definitions.connection import Connection
Expand Down Expand Up @@ -55,7 +60,6 @@ def _get_connection(conn_id: str) -> Connection:
from airflow.sdk.execution_time.comms import ErrorResponse, GetConnection
from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS

log = structlog.get_logger(logger_name="task")
SUPERVISOR_COMMS.send_request(log=log, msg=GetConnection(conn_id=conn_id))
msg = SUPERVISOR_COMMS.get_message()
if isinstance(msg, ErrorResponse):
Expand All @@ -75,7 +79,6 @@ def _get_variable(key: str, deserialize_json: bool) -> Variable:
from airflow.sdk.execution_time.comms import ErrorResponse, GetVariable
from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS

log = structlog.get_logger(logger_name="task")
SUPERVISOR_COMMS.send_request(log=log, msg=GetVariable(key=key))
msg = SUPERVISOR_COMMS.get_message()
if isinstance(msg, ErrorResponse):
Expand Down Expand Up @@ -157,3 +160,23 @@ def __eq__(self, other: object) -> bool:
if not isinstance(other, MacrosAccessor):
return False
return True


@contextlib.contextmanager
def set_current_context(context: Mapping[str, Any]) -> Generator[Mapping[str, Any], None, None]:
    """
    Make ``context`` the current execution context for the duration of the ``with`` block.

    The context is appended to ``_CURRENT_CONTEXT`` on entry and the top entry is
    popped again on exit, whether or not the body raised. If the popped entry does
    not compare equal to ``context``, a warning is logged.

    This should be called once per Task execution, before calling operator.execute.

    :param context: The context to expose while the block runs.
    :return: A context manager yielding the same ``context`` object.
    """
    _CURRENT_CONTEXT.append(context)
    try:
        yield context
    finally:
        # Always pop, so an exception in the body does not leave a stale entry behind.
        popped = _CURRENT_CONTEXT.pop()
        if popped != context:
            log.warning(
                "Current context is not equal to the state at context stack.",
                expected=context,
                got=popped,
            )
65 changes: 39 additions & 26 deletions task_sdk/src/airflow/sdk/execution_time/task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,21 @@
ToTask,
XComResult,
)
from airflow.sdk.execution_time.context import ConnectionAccessor, MacrosAccessor, VariableAccessor
from airflow.sdk.execution_time.context import (
ConnectionAccessor,
MacrosAccessor,
VariableAccessor,
set_current_context,
)

if TYPE_CHECKING:
import jinja2
from structlog.typing import FilteringBoundLogger as Logger


# TODO: Move this entire class into a separate file:
# `airflow/sdk/execution_time/task_instance.py`
# or `airflow/sdk/execution_time/runtime_ti.py`
class RuntimeTaskInstance(TaskInstance):
model_config = ConfigDict(arbitrary_types_allowed=True)

Expand Down Expand Up @@ -426,37 +434,18 @@ def run(ti: RuntimeTaskInstance, log: Logger):
# TODO: Get a real context object
ti.task = ti.task.prepare_for_execution()
context = ti.get_template_context()
jinja_env = ti.task.dag.get_template_env()
ti.task = ti.render_templates(context=context, jinja_env=jinja_env)
with set_current_context(context):
jinja_env = ti.task.dag.get_template_env()
ti.task = ti.render_templates(context=context, jinja_env=jinja_env)
result = _execute_task(context, ti.task)

_push_xcom_if_needed(result, ti)

# TODO: Get things from _execute_task_with_callbacks
# - Clearing XCom
# - Setting Current Context (set_current_context)
# - Render Templates
# - Update RTIF
# - Pre Execute
# etc

result = None
if ti.task.execution_timeout:
# TODO: handle timeout in case of deferral
from airflow.utils.timeout import timeout

timeout_seconds = ti.task.execution_timeout.total_seconds()
try:
# It's possible we're already timed out, so fast-fail if true
if timeout_seconds <= 0:
raise AirflowTaskTimeout()
# Run task in timeout wrapper
with timeout(timeout_seconds):
result = ti.task.execute(context) # type: ignore[attr-defined]
except AirflowTaskTimeout:
# TODO: handle on kill callback here
raise
else:
result = ti.task.execute(context) # type: ignore[attr-defined]

_push_xcom_if_needed(result, ti)
msg = TaskState(state=TerminalTIState.SUCCESS, end_date=datetime.now(tz=timezone.utc))
except TaskDeferred as defer:
classpath, trigger_kwargs = defer.trigger.serialize()
Expand Down Expand Up @@ -524,6 +513,30 @@ def run(ti: RuntimeTaskInstance, log: Logger):
SUPERVISOR_COMMS.send_request(msg=msg, log=log)


def _execute_task(context: Mapping[str, Any], task: BaseOperator):
    """
    Execute the task, optionally under a timeout, and return its result.

    This function does not push XCom values itself; it only returns whatever
    ``task.execute`` returned.

    :param context: The context passed to ``task.execute``.
    :param task: The operator to run. Its ``execution_timeout`` decides whether the
        call is wrapped in a timeout.
    :return: The value returned by ``task.execute(context)``.
    :raises AirflowTaskTimeout: If ``execution_timeout`` is zero or negative; any
        ``AirflowTaskTimeout`` raised inside the timeout block is re-raised unchanged.
    """
    from airflow.exceptions import AirflowTaskTimeout

    if task.execution_timeout:
        # TODO: handle timeout in case of deferral
        from airflow.utils.timeout import timeout

        timeout_seconds = task.execution_timeout.total_seconds()
        try:
            # It's possible we're already timed out, so fast-fail if true
            if timeout_seconds <= 0:
                raise AirflowTaskTimeout()
            # Run task in timeout wrapper
            with timeout(timeout_seconds):
                result = task.execute(context)  # type: ignore[attr-defined]
        except AirflowTaskTimeout:
            # TODO: handle on kill callback here
            raise
    else:
        result = task.execute(context)  # type: ignore[attr-defined]
    return result


def _push_xcom_if_needed(result: Any, ti: RuntimeTaskInstance):
"""Push XCom values when task has ``do_xcom_push`` set to ``True`` and the task returns a result."""
if ti.task.do_xcom_push:
Expand Down
39 changes: 39 additions & 0 deletions task_sdk/tests/defintions/test_contextmanager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import pytest

from airflow.sdk import get_current_context


class TestCurrentContext:
    """Tests for ``get_current_context`` as exported from ``airflow.sdk``."""

    def test_current_context_no_context_raise(self):
        """Calling without any pushed context raises ``RuntimeError``."""
        with pytest.raises(RuntimeError):
            get_current_context()

    def test_get_current_context_with_context(self, monkeypatch):
        """The single entry on the patched stack is what gets returned."""
        fake_context = {"ti": "task_instance", "key": "value"}
        monkeypatch.setattr("airflow.sdk.definitions.contextmanager._CURRENT_CONTEXT", [fake_context])
        assert get_current_context() == fake_context

    def test_get_current_context_without_context(self, monkeypatch):
        """An explicitly empty stack raises with the expected message."""
        monkeypatch.setattr("airflow.sdk.definitions.contextmanager._CURRENT_CONTEXT", [])
        with pytest.raises(RuntimeError, match="Current context was requested but no context was found!"):
            get_current_context()
Loading