Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support pulling multiple XCom values #45243

Open
Tracked by #44481
kaxil opened this issue Dec 27, 2024 · 1 comment
Open
Tracked by #44481

Support pulling multiple XCom values #45243

kaxil opened this issue Dec 27, 2024 · 1 comment
Assignees
Labels

Comments

@kaxil
Copy link
Member

kaxil commented Dec 27, 2024

Currently, the Task SDK's `ti.xcom_pull` only supports pulling a single XCom value; it should also support pulling multiple values (e.g. for several task IDs or map indexes).

Also port the following existing tests:

def test_xcom_pull(self, dag_maker):
    """Test xcom_pull, using different filtering methods.

    Two tasks push different values under the same key ``"foo"``. The test
    then pulls with no arguments, by key only, by a single task id, and by a
    list of task ids, checking the returned value(s) each time.
    """
    with dag_maker(dag_id="test_xcom") as dag:
        task_1 = EmptyOperator(task_id="test_xcom_1")
        task_2 = EmptyOperator(task_id="test_xcom_2")

    dagrun = dag_maker.create_dagrun(start_date=timezone.datetime(2016, 6, 1, 0, 0, 0))
    ti1 = dagrun.get_task_instance(task_1.task_id)

    # Push a value
    ti1.xcom_push(key="foo", value="bar")

    # Push another value with the same key (but by a different task)
    XCom.set(key="foo", value="baz", task_id=task_2.task_id, dag_id=dag.dag_id, run_id=dagrun.run_id)

    # Pull with no arguments
    result = ti1.xcom_pull()
    assert result is None

    # Pull the value pushed most recently by any task.
    # Use equality here: a substring test (``in "baz"``) would also accept
    # "", "b", "ba", "a", "az" and "z", hiding a wrong result.
    result = ti1.xcom_pull(key="foo")
    assert result == "baz"

    # Pull the value pushed by the first task
    result = ti1.xcom_pull(task_ids="test_xcom_1", key="foo")
    assert result == "bar"

    # Pull the value pushed by the second task
    result = ti1.xcom_pull(task_ids="test_xcom_2", key="foo")
    assert result == "baz"

    # Pull the values pushed by both tasks & Verify Order of task_ids pass & values returned
    result = ti1.xcom_pull(task_ids=["test_xcom_1", "test_xcom_2"], key="foo")
    assert result == ["bar", "baz"]
def test_xcom_pull_mapped(self, dag_maker, session):
    """Check xcom_pull against a mapped task, filtering by task id and map index.

    Two instances of ``task_1`` (map indexes 0 and 1) push ``"a"`` and ``"b"``
    under the return key; the ``task_2`` instance then pulls them using
    several ``task_ids`` / ``map_indexes`` combinations.
    """
    with dag_maker(dag_id="test_xcom", session=session):
        # Use the private _expand() method to avoid the empty kwargs check.
        # We don't care about how the operator runs here, only its presence.
        mapped_task = EmptyOperator.partial(task_id="task_1")._expand(EXPAND_INPUT_EMPTY, strict=False)
        EmptyOperator(task_id="task_2")

    dag_run = dag_maker.create_dagrun(start_date=timezone.datetime(2016, 6, 1, 0, 0, 0))

    # First mapped instance comes from the dag run; the second is merged in by hand.
    first_ti = dag_run.get_task_instance("task_1", session=session)
    first_ti.map_index = 0
    second_ti = session.merge(
        TI(mapped_task, run_id=dag_run.run_id, map_index=1, state=first_ti.state)
    )
    session.flush()

    first_ti.xcom_push(key=XCOM_RETURN_KEY, value="a", session=session)
    second_ti.xcom_push(key=XCOM_RETURN_KEY, value="b", session=session)

    puller = dag_run.get_task_instance("task_2", session=session)

    # Ordering not guaranteed.
    assert set(puller.xcom_pull(["task_1"], session=session)) == {"a", "b"}
    assert puller.xcom_pull(["task_1"], map_indexes=0, session=session) == ["a"]
    assert puller.xcom_pull(map_indexes=[0, 1], session=session) == ["a", "b"]
    assert puller.xcom_pull("task_1", map_indexes=[1, 0], session=session) == ["b", "a"]
    assert puller.xcom_pull(["task_1"], map_indexes=[0, 1], session=session) == ["a", "b"]
    assert puller.xcom_pull("task_1", map_indexes=1, session=session) == "b"
    assert list(puller.xcom_pull("task_1", session=session)) == ["a", "b"]
def test_xcom_pull_after_success(self, create_task_instance):
    """Check XCom set/clear for a task in a 'success' rerun scenario.

    The pushed value must survive every ``run`` call that does not actually
    execute the task, and be gone once the task has executed.
    """
    xcom_key = "xcom_key"
    xcom_value = "xcom_value"

    ti = create_task_instance(
        dag_id="test_xcom",
        schedule="@monthly",
        task_id="test_xcom",
        pool="test_xcom",
        serialized=True,
    )

    def pull():
        # Same lookup is repeated after every run() below.
        return ti.xcom_pull(task_ids="test_xcom", key=xcom_key)

    ti.run(mark_success=True)
    ti.xcom_push(key=xcom_key, value=xcom_value)
    assert pull() == xcom_value

    ti.run()
    # Check that we do not clear Xcom until the task is certain to execute
    assert pull() == xcom_value

    # Xcom shouldn't be cleared if the task doesn't execute, even if dependencies are ignored
    ti.run(ignore_all_deps=True, mark_success=True)
    assert pull() == xcom_value

    # Xcom IS finally cleared once task has executed
    ti.run(ignore_all_deps=True)
    assert pull() is None
def test_xcom_pull_after_deferral(self, create_task_instance, session):
    """Check that XCom is not cleared before a task runs its next method after deferral."""
    xcom_key, xcom_value = "xcom_key", "xcom_value"

    ti = create_task_instance(
        dag_id="test_xcom",
        schedule="@monthly",
        task_id="test_xcom",
        pool="test_xcom",
    )
    ti.run(mark_success=True)
    ti.xcom_push(key=xcom_key, value=xcom_value)

    # Set next_method and persist it before running again.
    ti.next_method = "execute"
    session.merge(ti)
    session.commit()

    ti.run(ignore_all_deps=True)

    pulled = ti.xcom_pull(task_ids="test_xcom", key=xcom_key)
    assert pulled == xcom_value
def test_xcom_pull_different_logical_date(self, create_task_instance):
    """Check XCom fetch behaviour across different logical dates.

    Pulls both with and without ``include_prior_dates`` from a task instance
    that belongs to a later dag run than the one that pushed the value.
    """
    xcom_key = "xcom_key"
    xcom_value = "xcom_value"

    first_ti = create_task_instance(
        dag_id="test_xcom",
        schedule="@monthly",
        task_id="test_xcom",
        pool="test_xcom",
    )
    first_logical_date = first_ti.dag_run.logical_date

    first_ti.run(mark_success=True)
    first_ti.xcom_push(key=xcom_key, value=xcom_value)
    assert first_ti.xcom_pull(task_ids="test_xcom", key=xcom_key) == xcom_value
    first_ti.run()

    # Create a second dag run one day after the first one.
    next_date = first_logical_date + datetime.timedelta(days=1)
    if AIRFLOW_V_3_0_PLUS:
        triggered_by_kwargs = {"triggered_by": DagRunTriggeredByType.TEST}
    else:
        triggered_by_kwargs = {}
    second_run = first_ti.task.dag.create_dagrun(
        run_id="test2",
        data_interval=(next_date, next_date),
        state=None,
        **triggered_by_kwargs,
    )

    second_ti = TI(task=first_ti.task, run_id=second_run.run_id)
    second_ti.run()

    # We have set a new logical date (and did not pass in
    # 'include_prior_dates'), which means this task should now have a cleared
    # xcom value
    assert second_ti.xcom_pull(task_ids="test_xcom", key=xcom_key) is None

    # We *should* get a value using 'include_prior_dates'
    with_prior = second_ti.xcom_pull(task_ids="test_xcom", key=xcom_key, include_prior_dates=True)
    assert with_prior == xcom_value
def test_xcom_pull_different_run_ids(self, create_task_instance):
    """Check XCom fetch behaviour when pulling with an explicit ``run_id``.

    Two task instances with the same dag/task id push different values under
    the same key; the first instance then pulls by each instance's run id.
    """
    xcom_key = "xcom_key"
    task_id = "test_xcom"
    other_run_id = "diff_run_id"
    own_run_value = "xcom_value_same_run_id"
    other_run_value = "xcom_value_different_run_id"

    own_ti = create_task_instance(dag_id="test_xcom", task_id=task_id)
    own_ti.run(mark_success=True)
    own_ti.xcom_push(key=xcom_key, value=own_run_value)

    other_ti = create_task_instance(dag_id="test_xcom", task_id=task_id, run_id=other_run_id)
    other_ti.run(mark_success=True)
    other_ti.xcom_push(key=xcom_key, value=other_run_value)

    # Pulling with the instance's own run id returns its own value ...
    from_own_run = own_ti.xcom_pull(run_id=own_ti.dag_run.run_id, task_ids=task_id, key=xcom_key)
    assert from_own_run == own_run_value

    # ... while pulling with the other run id returns the other run's value.
    from_other_run = own_ti.xcom_pull(run_id=other_ti.dag_run.run_id, task_ids=task_id, key=xcom_key)
    assert from_other_run == other_run_value
def test_xcom_push_flag(self, dag_maker):
    """Check that an operator created with ``do_xcom_push=False`` stores no return-value XCom."""
    returned = "hello"
    task_id = "test_no_xcom_push"

    with dag_maker(dag_id="test_xcom", serialized=True):
        # nothing saved to XCom
        operator = PythonOperator(
            task_id=task_id,
            python_callable=lambda: returned,
            do_xcom_push=False,
        )

    dag_run = dag_maker.create_dagrun(logical_date=timezone.utcnow())
    ti = dag_run.task_instances[0]
    ti.task = operator
    ti.run()

    pulled = ti.xcom_pull(task_ids=task_id, key=XCOM_RETURN_KEY)
    assert pulled is None

@kaxil
Copy link
Member Author

kaxil commented Jan 6, 2025

Alternatively, we could keep the server side simple (returning one XCom value per request) and loop over the requested values on the client side.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
Development

No branches or pull requests

1 participant