From 8f3657f15e87384f3d5475286b73ccdfd0c007c9 Mon Sep 17 00:00:00 2001 From: David Blain Date: Tue, 10 Dec 2024 11:10:01 +0100 Subject: [PATCH 01/20] refactor: Refactored the GenericTransfer operator to support paginated reads (in deferred mode) and introduce a SQLExecuteQueryTrigger --- .../providers/common/sql/provider.yaml | 5 + .../providers/common/sql/triggers/__init__.py | 16 ++ .../providers/common/sql/triggers/sql.py | 97 +++++++++++ .../standard/operators/generic_transfer.py | 158 +++++++++++++++--- 4 files changed, 254 insertions(+), 22 deletions(-) create mode 100644 providers/src/airflow/providers/common/sql/triggers/__init__.py create mode 100644 providers/src/airflow/providers/common/sql/triggers/sql.py diff --git a/providers/src/airflow/providers/common/sql/provider.yaml b/providers/src/airflow/providers/common/sql/provider.yaml index 530cc35188265..d778d73c57cc2 100644 --- a/providers/src/airflow/providers/common/sql/provider.yaml +++ b/providers/src/airflow/providers/common/sql/provider.yaml @@ -103,6 +103,11 @@ hooks: - airflow.providers.common.sql.hooks.handlers - airflow.providers.common.sql.hooks.sql +triggers: + - integration-name: Common SQL + python-modules: + - airflow.providers.common.sql.triggers.sql + sensors: - integration-name: Common SQL python-modules: diff --git a/providers/src/airflow/providers/common/sql/triggers/__init__.py b/providers/src/airflow/providers/common/sql/triggers/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/providers/src/airflow/providers/common/sql/triggers/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/providers/src/airflow/providers/common/sql/triggers/sql.py b/providers/src/airflow/providers/common/sql/triggers/sql.py new file mode 100644 index 0000000000000..627c9dedede8e --- /dev/null +++ b/providers/src/airflow/providers/common/sql/triggers/sql.py @@ -0,0 +1,97 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
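+# Illustrative usage (a sketch, not executed here): an operator can hand control
+# to this trigger from its execute() method; "my_source_db" is a hypothetical
+# connection ID::
+#
+#     self.defer(
+#         trigger=SQLExecuteQueryTrigger(
+#             sql="SELECT id, name FROM users LIMIT 1000 OFFSET 0",
+#             conn_id="my_source_db",
+#         ),
+#         method_name="execute_complete",
+#     )
+#
+# The trigger then yields a single TriggerEvent whose payload is either
+# {"status": "success", "results": ...} or {"status": "failure", "message": ...}.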
+from __future__ import annotations + +from typing import TYPE_CHECKING + +from airflow.exceptions import AirflowException +from airflow.hooks.base import BaseHook +from airflow.providers.common.sql.hooks.sql import DbApiHook +from airflow.triggers.base import BaseTrigger, TriggerEvent + +if TYPE_CHECKING: + from typing import ( + Any, + AsyncIterator, + ) + + +class SQLExecuteQueryTrigger(BaseTrigger): + """ + A trigger that executes SQL code in async mode. + + :param sql: the sql statement to be executed (str) or a list of sql statements to execute + :param conn_id: the connection ID used to connect to the database + :param hook_params: hook parameters + """ + def __init__( + self, + sql: str | list[str], + conn_id: str | None = None, + hook_params: dict | None = None, + **kwargs, + ): + super().__init__(**kwargs) + self.sql = sql + self.conn_id = conn_id + self.hook_params = hook_params + + def serialize(self) -> tuple[str, dict[str, Any]]: + """Serialize the SQLExecuteQueryTrigger arguments and classpath.""" + return ( + f"{self.__class__.__module__}.{self.__class__.__name__}", + { + "sql": self.sql, + "conn_id": self.conn_id, + "hook_params": self.hook_params, + }, + ) + + async def run(self) -> AsyncIterator[TriggerEvent]: + try: + hook = BaseHook.get_hook(self.conn_id, hook_params=self.hook_params) + + if not isinstance(hook, DbApiHook): + raise AirflowException( + f"You are trying to use `common-sql` with {hook.__class__.__name__}," + " but its provider does not support it. Please upgrade the provider to a version that" + " supports `common-sql`. The hook class should be a subclass of" + " `airflow.providers.common.sql.hooks.sql.DbApiHook`." + f" Got {hook.__class__.__name__} Hook with class hierarchy: {hook.__class__.mro()}" + ) + + self.log.info("Extracting data from %s", self.conn_id) + self.log.info("Executing: \n %s", self.sql) + + get_records = getattr(hook, "get_records", None) + + if not callable(get_records): + raise RuntimeError( + f"Hook for connection {self.source_conn_id!r} " + f"({type(hook).__name__}) has no `get_records` method" + ) + else: + self.log.info("Reading records from %s", self.conn_id) + results = get_records(self.sql) + self.log.info("Reading records from %s done!", self.conn_id) + + self.log.debug("results: %s", results) + yield TriggerEvent({"status": "success", "results": results}) + except Exception as e: + self.log.exception("An error occurred: %s", e) + yield TriggerEvent({"status": "failure", "message": str(e)}) diff --git a/providers/src/airflow/providers/standard/operators/generic_transfer.py b/providers/src/airflow/providers/standard/operators/generic_transfer.py index 1cb3448f8a578..ea2de13e04c1d 100644 --- a/providers/src/airflow/providers/standard/operators/generic_transfer.py +++ b/providers/src/airflow/providers/standard/operators/generic_transfer.py @@ -18,10 +18,13 @@ from __future__ import annotations from collections.abc import Sequence -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any +import jinja2 +from airflow.exceptions import AirflowException from airflow.hooks.base import BaseHook from airflow.models import BaseOperator +from airflow.providers.common.sql.triggers.sql import SQLExecuteQueryTrigger if TYPE_CHECKING: from airflow.utils.context import Context @@ -40,10 +43,13 @@ class GenericTransfer(BaseOperator): :param sql: SQL query to execute against the source database. (templated) :param destination_table: target table. (templated) :param source_conn_id: source connection. 
(templated) + :param source_hook_params: source hook parameters. :param destination_conn_id: destination connection. (templated) + :param destination_hook_params: destination hook parameters. :param preoperator: sql statement or list of statements to be executed prior to loading the data. (templated) :param insert_args: extra params for `insert_rows` method. + :param chunk_size: number of records to be read in paginated mode (optional). """ template_fields: Sequence[str] = ( @@ -72,6 +78,7 @@ def __init__( destination_hook_params: dict | None = None, preoperator: str | list[str] | None = None, insert_args: dict | None = None, + chunk_size: int | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -83,6 +90,10 @@ def __init__( self.destination_hook_params = destination_hook_params self.preoperator = preoperator self.insert_args = insert_args or {} + self.chunk_size = chunk_size + self._paginated_sql_statement_format = kwargs.get( + "paginated_sql_statement_format", "{} LIMIT {} OFFSET {}" + ) @classmethod def get_hook(cls, conn_id: str, hook_params: dict | None = None) -> BaseHook: @@ -96,22 +107,29 @@ def get_hook(cls, conn_id: str, hook_params: dict | None = None) -> BaseHook: connection = BaseHook.get_connection(conn_id) return connection.get_hook(hook_params=hook_params) - def execute(self, context: Context): - source_hook = self.get_hook(conn_id=self.source_conn_id, hook_params=self.source_hook_params) - destination_hook = self.get_hook( - conn_id=self.destination_conn_id, hook_params=self.destination_hook_params + def get_paginated_sql(self, offset: int) -> str: + """Format the paginated SQL statement using the current format.""" + return self._paginated_sql_statement_format.format( + self.sql, self.chunk_size, offset ) - self.log.info("Extracting data from %s", self.source_conn_id) - self.log.info("Executing: \n %s", self.sql) - get_records = getattr(source_hook, "get_records", None) - if not callable(get_records): - raise RuntimeError( - f"Hook for connection {self.source_conn_id!r} " - f"({type(source_hook).__name__}) has no `get_records` method" - ) - else: - results = get_records(self.sql) + def render_template_fields( + self, + context: Context, + jinja_env: jinja2.Environment | None = None, + ) -> None: + super().render_template_fields(context=context, jinja_env=jinja_env) + + # Make sure string are converted to integers + if isinstance(self.chunk_size, str): + self.chunk_size = int(self.chunk_size) + if isinstance(self.insert_args.get("commit_every"), str): + self.insert_args["commit_every"] = int(self.insert_args.get("commit_every")) + + def paginated_execute(self, context: Context): + destination_hook = BaseHook.get_hook( + self.destination_conn_id, hook_params=self.destination_hook_params + ) if self.preoperator: run = getattr(destination_hook, "run", None) @@ -124,11 +142,107 @@ def execute(self, context: Context): self.log.info(self.preoperator) run(self.preoperator) - insert_rows = getattr(destination_hook, "insert_rows", None) - if not callable(insert_rows): - raise RuntimeError( - f"Hook for connection {self.destination_conn_id!r} " - f"({type(destination_hook).__name__}) has no `insert_rows` method" + if isinstance(self.insert_args.get("commit_every"), str): + self.insert_args["commit_every"] = int(self.insert_args.get("commit_every")) + + if self.chunk_size and isinstance(self.sql, str): + self.defer( + trigger=SQLExecuteQueryTrigger( + conn_id=self.source_conn_id, + hook_params=self.source_hook_params, + sql=self.get_paginated_sql(0), + ), + 
method_name=self.execute_complete.__name__, + ) + else: + source_hook = BaseHook.get_hook( + self.source_conn_id, hook_params=self.source_hook_params + ) + + self.log.info("Extracting data from %s", self.source_conn_id) + self.log.info("Executing: \n %s", self.sql) + + get_records = getattr(source_hook, "get_records", None) + + if not callable(get_records): + raise RuntimeError( + f"Hook for connection {self.source_conn_id!r} " + f"({type(source_hook).__name__}) has no `get_records` method" + ) + + results = get_records(self.sql) + insert_rows = getattr(destination_hook, "insert_rows", None) + + if not callable(insert_rows): + raise RuntimeError( + f"Hook for connection {self.destination_conn_id!r} " + f"({type(destination_hook).__name__}) has no `insert_rows` method" + ) + + self.log.info("Inserting rows into %s", self.destination_conn_id) + insert_rows(table=self.destination_table, rows=results, **self.insert_args) + + def execute_complete( + self, + context: Context, + event: dict[Any, Any] | None = None, + ) -> Any: + if event: + if event.get("status") == "failure": + raise AirflowException(event.get("message")) + + destination_hook = BaseHook.get_hook( + self.destination_conn_id, hook_params=self.destination_hook_params ) - self.log.info("Inserting rows into %s", self.destination_conn_id) - insert_rows(table=self.destination_table, rows=results, **self.insert_args) + + results = event.get("results") + + if results: + insert_rows = getattr(destination_hook, "insert_rows", None) + + if not callable(insert_rows): + raise RuntimeError( + f"Hook for connection {self.destination_conn_id!r} " + f"({type(destination_hook).__name__}) has no `insert_rows` method" + ) + + map_index = context["ti"].map_index + offset = ( + context["ti"].xcom_pull( + key="offset", + task_ids=self.task_id, + dag_id=self.dag_id, + map_indexes=map_index, + default=0, + ) + + self.chunk_size + ) + + self.log.info("Offset increased to %d", offset) + self.xcom_push(context=context, key="offset", value=offset) + + self.log.info( + "Inserting %d rows into %s", len(results), self.destination_conn_id + ) + insert_rows( + table=self.destination_table, rows=results, **self.insert_args + ) + self.log.info( + "Inserting %d rows into %s done!", + len(results), + self.destination_conn_id, + ) + + self.defer( + trigger=SQLExecuteQueryTrigger( + conn_id=self.source_conn_id, + hook_params=self.source_hook_params, + sql=self.get_paginated_sql(offset), + ), + method_name=self.execute_complete.__name__, + ) + else: + self.log.info( + "No more rows to fetch into %s; ending transfer.", + self.destination_table, + ) From 9b36a148da71b5bf52242a4350bfff4b77c281c7 Mon Sep 17 00:00:00 2001 From: dabla Date: Tue, 10 Dec 2024 11:34:35 +0100 Subject: [PATCH 02/20] refactor: updated provider dependencies --- generated/provider_dependencies.json | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json index f58c010b3f58e..96473b369f8f4 100644 --- a/generated/provider_dependencies.json +++ b/generated/provider_dependencies.json @@ -1294,7 +1294,9 @@ ], "devel-deps": [], "plugins": [], - "cross-providers-deps": [], + "cross-providers-deps": [ + "common.sql" + ], "excluded-python-versions": [], "state": "ready" }, From eedcd528b37f4fff59fa59365300af00546b895f Mon Sep 17 00:00:00 2001 From: David Blain Date: Tue, 10 Dec 2024 12:08:50 +0100 Subject: [PATCH 03/20] refactor: Added TestSQLExecuteQueryTrigger and moved test code which allows you to run deferrable 
operators in test common test utils --- .../tests/common/sql/triggers/__init__.py | 16 +++++ .../tests/common/sql/triggers/test_sql.py | 48 +++++++++++++ providers/tests/microsoft/azure/base.py | 52 +------------- .../microsoft/azure/operators/test_msgraph.py | 13 ++-- .../microsoft/azure/operators/test_powerbi.py | 3 +- .../microsoft/azure/sensors/test_msgraph.py | 5 +- .../microsoft/azure/triggers/test_msgraph.py | 9 +-- providers/tests/microsoft/conftest.py | 53 +------------- tests_common/test_utils/mock_context.py | 67 +++++++++++++++++ .../test_utils/operators/run_deferable.py | 71 +++++++++++++++++++ 10 files changed, 224 insertions(+), 113 deletions(-) create mode 100644 providers/tests/common/sql/triggers/__init__.py create mode 100644 providers/tests/common/sql/triggers/test_sql.py create mode 100644 tests_common/test_utils/mock_context.py create mode 100644 tests_common/test_utils/operators/run_deferable.py diff --git a/providers/tests/common/sql/triggers/__init__.py b/providers/tests/common/sql/triggers/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/providers/tests/common/sql/triggers/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/providers/tests/common/sql/triggers/test_sql.py b/providers/tests/common/sql/triggers/test_sql.py new file mode 100644 index 0000000000000..dfa2fd5d5bcc3 --- /dev/null +++ b/providers/tests/common/sql/triggers/test_sql.py @@ -0,0 +1,48 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
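+# A sketch of the flow exercised below: the shared run_trigger helper collects
+# the TriggerEvents emitted by SQLExecuteQueryTrigger.run(), so a test can do::
+#
+#     events = run_trigger(SQLExecuteQueryTrigger(sql="SELECT 1", conn_id="test_conn_id"))
+#     assert events[0].payload["status"] == "success"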
+from __future__ import annotations + +import json +from unittest import mock + +import pytest +from airflow.providers.common.sql.hooks.sql import DbApiHook +from airflow.providers.common.sql.triggers.sql import SQLExecuteQueryTrigger +from airflow.triggers.base import TriggerEvent + +from providers.tests.microsoft.azure.base import Base +from tests_common.test_utils.operators.run_deferable import run_trigger +from tests_common.test_utils.version_compat import AIRFLOW_V_2_9_PLUS + +pytestmark = pytest.mark.skipif(not AIRFLOW_V_2_9_PLUS, reason="Tests for Airflow 2.8.0+ only") + + +class TestSQLExecuteQueryTrigger(Base): + @mock.patch("airflow.providers.common.sql.sensors.sql.BaseHook") + def test_run(self, mock_hook): + data = [(1, "Alice"), (2, "Bob")] + mock_hook.get_connection.return_value.get_hook.return_value = mock.MagicMock(spec=DbApiHook) + mock_get_records = mock_hook.get_connection.return_value.get_hook.return_value.get_records + mock_get_records.return_value = data + + trigger = SQLExecuteQueryTrigger(sql="SELECT * FROM users;", conn_id="test_conn_id") + actual = run_trigger(trigger) + + assert len(actual) == 1 + assert isinstance(actual[0], TriggerEvent) + assert actual[0].payload["status"] == "success" + assert actual[0].payload["results"] == json.dumps(data) diff --git a/providers/tests/microsoft/azure/base.py b/providers/tests/microsoft/azure/base.py index 600e4ce488e08..7f37b1bb524aa 100644 --- a/providers/tests/microsoft/azure/base.py +++ b/providers/tests/microsoft/azure/base.py @@ -16,22 +16,13 @@ # under the License. from __future__ import annotations -import asyncio from contextlib import contextmanager -from copy import deepcopy -from typing import TYPE_CHECKING, Any from unittest.mock import patch -from kiota_http.httpx_request_adapter import HttpxRequestAdapter - -from airflow.exceptions import TaskDeferred from airflow.providers.microsoft.azure.hooks.msgraph import KiotaRequestAdapterHook +from kiota_http.httpx_request_adapter import HttpxRequestAdapter -from providers.tests.microsoft.conftest import get_airflow_connection, mock_context - -if TYPE_CHECKING: - from airflow.models import Operator - from airflow.triggers.base import BaseTrigger, TriggerEvent +from providers.tests.microsoft.conftest import get_airflow_connection class Base: @@ -49,42 +40,3 @@ def patch_hook_and_request_adapter(self, response): else: mock_get_http_response.return_value = response yield - - @staticmethod - async def _run_tigger(trigger: BaseTrigger) -> list[TriggerEvent]: - events = [] - async for event in trigger.run(): - events.append(event) - return events - - def run_trigger(self, trigger: BaseTrigger) -> list[TriggerEvent]: - return asyncio.run(self._run_tigger(trigger)) - - def execute_operator(self, operator: Operator) -> tuple[Any, Any]: - context = mock_context(task=operator) - return asyncio.run(self.deferrable_operator(context, operator)) - - async def deferrable_operator(self, context, operator): - result = None - triggered_events = [] - try: - operator.render_template_fields(context=context) - result = operator.execute(context=context) - except TaskDeferred as deferred: - task = deferred - - while task: - events = await self._run_tigger(task.trigger) - - if not events: - break - - triggered_events.extend(deepcopy(events)) - - try: - method = getattr(operator, task.method_name) - result = method(context=context, event=next(iter(events)).payload) - task = None - except TaskDeferred as exception: - task = exception - return result, triggered_events diff --git 
a/providers/tests/microsoft/azure/operators/test_msgraph.py b/providers/tests/microsoft/azure/operators/test_msgraph.py index b722c4c9f0e85..13770cf1d034d 100644 --- a/providers/tests/microsoft/azure/operators/test_msgraph.py +++ b/providers/tests/microsoft/azure/operators/test_msgraph.py @@ -35,6 +35,7 @@ mock_json_response, mock_response, ) +from tests_common.test_utils.operators.run_deferable import execute_operator from tests_common.test_utils.version_compat import AIRFLOW_V_2_10_PLUS if TYPE_CHECKING: @@ -56,7 +57,7 @@ def test_execute(self): result_processor=lambda context, result: result.get("value"), ) - results, events = self.execute_operator(operator) + results, events = execute_operator(operator) assert len(results) == 30 assert results == users.get("value") + next_users.get("value") @@ -84,7 +85,7 @@ def test_execute_when_do_xcom_push_is_false(self): do_xcom_push=False, ) - results, events = self.execute_operator(operator) + results, events = execute_operator(operator) assert isinstance(results, dict) assert len(events) == 1 @@ -104,7 +105,7 @@ def test_execute_when_an_exception_occurs(self): ) with pytest.raises(AirflowException): - self.execute_operator(operator) + execute_operator(operator) @pytest.mark.db_test def test_execute_when_an_exception_occurs_on_custom_event_handler(self): @@ -124,7 +125,7 @@ def custom_event_handler(context: Context, event: dict[Any, Any] | None = None): event_handler=custom_event_handler, ) - results, events = self.execute_operator(operator) + results, events = execute_operator(operator) assert not results assert len(events) == 1 @@ -148,7 +149,7 @@ def test_execute_when_response_is_bytes(self): path_parameters={"drive_id": drive_id}, ) - results, events = self.execute_operator(operator) + results, events = execute_operator(operator) assert operator.path_parameters == {"drive_id": drive_id} assert results == base64_encoded_content @@ -175,7 +176,7 @@ def test_execute_with_lambda_parameter_when_response_is_bytes(self): path_parameters=lambda context, jinja_env: {"drive_id": drive_id}, ) - results, events = self.execute_operator(operator) + results, events = execute_operator(operator) assert operator.path_parameters == {"drive_id": drive_id} assert results == base64_encoded_content diff --git a/providers/tests/microsoft/azure/operators/test_powerbi.py b/providers/tests/microsoft/azure/operators/test_powerbi.py index a115b4c52dc50..bf7c03b8c9ee2 100644 --- a/providers/tests/microsoft/azure/operators/test_powerbi.py +++ b/providers/tests/microsoft/azure/operators/test_powerbi.py @@ -32,7 +32,8 @@ from airflow.utils import timezone from providers.tests.microsoft.azure.base import Base -from providers.tests.microsoft.conftest import get_airflow_connection, mock_context +from providers.tests.microsoft.conftest import get_airflow_connection +from tests_common.test_utils.mock_context import mock_context DEFAULT_CONNECTION_CLIENT_SECRET = "powerbi_conn_id" TASK_ID = "run_powerbi_operator" diff --git a/providers/tests/microsoft/azure/sensors/test_msgraph.py b/providers/tests/microsoft/azure/sensors/test_msgraph.py index 9ad03ccf17020..a44be23072373 100644 --- a/providers/tests/microsoft/azure/sensors/test_msgraph.py +++ b/providers/tests/microsoft/azure/sensors/test_msgraph.py @@ -25,6 +25,7 @@ from providers.tests.microsoft.azure.base import Base from providers.tests.microsoft.conftest import load_json, mock_json_response +from tests_common.test_utils.operators.run_deferable import execute_operator from tests_common.test_utils.version_compat import 
AIRFLOW_V_2_10_PLUS @@ -43,7 +44,7 @@ def test_execute(self): timeout=350.0, ) - results, events = self.execute_operator(sensor) + results, events = execute_operator(sensor) assert sensor.path_parameters == {"scanId": "0a1b1bf3-37de-48f7-9863-ed4cda97a9ef"} assert isinstance(results, str) @@ -69,7 +70,7 @@ def test_execute_with_lambda_parameter(self): timeout=350.0, ) - results, events = self.execute_operator(sensor) + results, events = execute_operator(sensor) assert sensor.path_parameters == {"scanId": "0a1b1bf3-37de-48f7-9863-ed4cda97a9ef"} assert isinstance(results, str) diff --git a/providers/tests/microsoft/azure/triggers/test_msgraph.py b/providers/tests/microsoft/azure/triggers/test_msgraph.py index 0784d8d83177c..e822f6c1e2cf9 100644 --- a/providers/tests/microsoft/azure/triggers/test_msgraph.py +++ b/providers/tests/microsoft/azure/triggers/test_msgraph.py @@ -40,6 +40,7 @@ mock_json_response, mock_response, ) +from tests_common.test_utils.operators.run_deferable import run_trigger class TestMSGraphTrigger(Base): @@ -49,7 +50,7 @@ def test_run_when_valid_response(self): with self.patch_hook_and_request_adapter(response): trigger = MSGraphTrigger("users/delta", conn_id="msgraph_api") - actual = self.run_trigger(trigger) + actual = run_trigger(trigger) assert len(actual) == 1 assert isinstance(actual[0], TriggerEvent) @@ -62,7 +63,7 @@ def test_run_when_response_is_none(self): with self.patch_hook_and_request_adapter(response): trigger = MSGraphTrigger("users/delta", conn_id="msgraph_api") - actual = self.run_trigger(trigger) + actual = run_trigger(trigger) assert len(actual) == 1 assert isinstance(actual[0], TriggerEvent) @@ -73,7 +74,7 @@ def test_run_when_response_is_none(self): def test_run_when_response_cannot_be_converted_to_json(self): with self.patch_hook_and_request_adapter(AirflowException()): trigger = MSGraphTrigger("users/delta", conn_id="msgraph_api") - actual = next(iter(self.run_trigger(trigger))) + actual = next(iter(run_trigger(trigger))) assert isinstance(actual, TriggerEvent) assert actual.payload["status"] == "failure" @@ -89,7 +90,7 @@ def test_run_when_response_is_bytes(self): "https://graph.microsoft.com/v1.0/me/drive/items/1b30fecf-4330-4899-b249-104c2afaf9ed/content" ) trigger = MSGraphTrigger(url, response_type="bytes", conn_id="msgraph_api") - actual = next(iter(self.run_trigger(trigger))) + actual = next(iter(run_trigger(trigger))) assert isinstance(actual, TriggerEvent) assert actual.payload["status"] == "success" diff --git a/providers/tests/microsoft/conftest.py b/providers/tests/microsoft/conftest.py index 240f33e335d7e..c68335f3de5e8 100644 --- a/providers/tests/microsoft/conftest.py +++ b/providers/tests/microsoft/conftest.py @@ -21,7 +21,6 @@ import random import re import string -from collections.abc import Iterable from inspect import currentframe from json import JSONDecodeError from os.path import dirname, join @@ -29,15 +28,13 @@ from unittest.mock import MagicMock import pytest -from httpx import Headers, Response -from msgraph_core import APIVersion - from airflow.models import Connection from airflow.providers.microsoft.azure.hooks.powerbi import PowerBIHook -from airflow.utils.context import Context +from httpx import Headers, Response +from msgraph_core import APIVersion if TYPE_CHECKING: - from sqlalchemy.orm import Session + pass T = TypeVar("T", dict, str, Connection) @@ -112,50 +109,6 @@ def mock_response(status_code, content: Any = None, headers: dict | None = None) return response -def mock_context(task) -> Context: - from 
airflow.models import TaskInstance - from airflow.utils.session import NEW_SESSION - from airflow.utils.state import TaskInstanceState - from airflow.utils.xcom import XCOM_RETURN_KEY - - values: dict[str, Any] = {} - - class MockedTaskInstance(TaskInstance): - def __init__( - self, - task, - run_id: str | None = "run_id", - state: str | None = TaskInstanceState.RUNNING, - map_index: int = -1, - ): - super().__init__(task=task, run_id=run_id, state=state, map_index=map_index) - self.values: dict[str, Any] = {} - - def xcom_pull( - self, - task_ids: str | Iterable[str] | None = None, - dag_id: str | None = None, - key: str = XCOM_RETURN_KEY, - include_prior_dates: bool = False, - session: Session = NEW_SESSION, - *, - map_indexes: int | Iterable[int] | None = None, - default: Any = None, - run_id: str | None = None, - ) -> Any: - if map_indexes: - return values.get(f"{task_ids or self.task_id}_{dag_id or self.dag_id}_{key}_{map_indexes}") - return values.get(f"{task_ids or self.task_id}_{dag_id or self.dag_id}_{key}") - - def xcom_push(self, key: str, value: Any, session: Session = NEW_SESSION, **kwargs) -> None: - values[f"{self.task_id}_{self.dag_id}_{key}_{self.map_index}"] = value - - values["ti"] = MockedTaskInstance(task=task) - - # See https://github.com/python/mypy/issues/8890 - mypy does not support passing typed dict to TypedDict - return Context(values) # type: ignore[misc] - - def remove_license_header(content: str) -> str: """ Removes license header from the given content. diff --git a/tests_common/test_utils/mock_context.py b/tests_common/test_utils/mock_context.py new file mode 100644 index 0000000000000..de88fdb74f0cb --- /dev/null +++ b/tests_common/test_utils/mock_context.py @@ -0,0 +1,67 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
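+# Illustrative usage (a sketch): build a minimal Context around an operator so
+# that XCom push/pull is served from an in-memory dict during unit tests::
+#
+#     context = mock_context(task=operator)
+#     operator.execute(context=context)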
+from __future__ import annotations + +from collections.abc import Iterable +from typing import Any + +from airflow.utils.context import Context + + +def mock_context(task) -> Context: + from airflow.models import TaskInstance + from airflow.utils.session import NEW_SESSION + from airflow.utils.state import TaskInstanceState + from airflow.utils.xcom import XCOM_RETURN_KEY + from sqlalchemy.orm import Session + + values: dict[str, Any] = {} + + class MockedTaskInstance(TaskInstance): + def __init__( + self, + task, + run_id: str | None = "run_id", + state: str | None = TaskInstanceState.RUNNING, + map_index: int = -1, + ): + super().__init__(task=task, run_id=run_id, state=state, map_index=map_index) + self.values: dict[str, Any] = {} + + def xcom_pull( + self, + task_ids: str | Iterable[str] | None = None, + dag_id: str | None = None, + key: str = XCOM_RETURN_KEY, + include_prior_dates: bool = False, + session: Session = NEW_SESSION, + *, + map_indexes: int | Iterable[int] | None = None, + default: Any = None, + run_id: str | None = None, + ) -> Any: + if map_indexes: + return values.get(f"{task_ids or self.task_id}_{dag_id or self.dag_id}_{key}_{map_indexes}") + return values.get(f"{task_ids or self.task_id}_{dag_id or self.dag_id}_{key}") + + def xcom_push(self, key: str, value: Any, session: Session = NEW_SESSION, **kwargs) -> None: + values[f"{self.task_id}_{self.dag_id}_{key}_{self.map_index}"] = value + + values["ti"] = MockedTaskInstance(task=task) + + # See https://github.com/python/mypy/issues/8890 - mypy does not support passing typed dict to TypedDict + return Context(values) # type: ignore[misc] diff --git a/tests_common/test_utils/operators/run_deferable.py b/tests_common/test_utils/operators/run_deferable.py new file mode 100644 index 0000000000000..cdcc781767648 --- /dev/null +++ b/tests_common/test_utils/operators/run_deferable.py @@ -0,0 +1,71 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
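+# Illustrative usage (a sketch): run a deferrable operator end-to-end by
+# replaying each TaskDeferred, i.e. executing its trigger and feeding the
+# resulting event back into the operator's resume method::
+#
+#     results, events = execute_operator(operator)
+#     assert all(event.payload["status"] == "success" for event in events)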
+from __future__ import annotations + +import asyncio +from copy import deepcopy +from typing import TYPE_CHECKING, Any + +from airflow.exceptions import TaskDeferred + +from tests_common.test_utils.mock_context import mock_context + +if TYPE_CHECKING: + from airflow.models import Operator + from airflow.triggers.base import BaseTrigger, TriggerEvent + + +async def run_tigger(trigger: BaseTrigger) -> list[TriggerEvent]: + events = [] + async for event in trigger.run(): + events.append(event) + return events + + +def run_trigger(trigger: BaseTrigger) -> list[TriggerEvent]: + return asyncio.run(run_tigger(trigger)) + + +def execute_operator(operator: Operator) -> tuple[Any, Any]: + context = mock_context(task=operator) + return asyncio.run(deferrable_operator(context, operator)) + + +async def deferrable_operator(context, operator): + result = None + triggered_events = [] + try: + operator.render_template_fields(context=context) + result = operator.execute(context=context) + except TaskDeferred as deferred: + task = deferred + + while task: + events = await run_tigger(task.trigger) + + if not events: + break + + triggered_events.extend(deepcopy(events)) + + try: + method = getattr(operator, task.method_name) + result = method(context=context, event=next(iter(events)).payload) + task = None + except TaskDeferred as exception: + task = exception + return result, triggered_events From 38f654a825231d6f5ae2a7edd6ffcc43d736b195 Mon Sep 17 00:00:00 2001 From: David Blain Date: Tue, 10 Dec 2024 13:18:54 +0100 Subject: [PATCH 04/20] refactor: Fixed static checks --- .../airflow/providers/common/sql/triggers/sql.py | 7 +++---- .../standard/operators/generic_transfer.py | 16 ++++------------ providers/tests/common/sql/triggers/test_sql.py | 1 + providers/tests/microsoft/azure/base.py | 3 ++- providers/tests/microsoft/conftest.py | 8 +++----- tests_common/test_utils/mock_context.py | 3 ++- 6 files changed, 15 insertions(+), 23 deletions(-) diff --git a/providers/src/airflow/providers/common/sql/triggers/sql.py b/providers/src/airflow/providers/common/sql/triggers/sql.py index 627c9dedede8e..e120759fc86d0 100644 --- a/providers/src/airflow/providers/common/sql/triggers/sql.py +++ b/providers/src/airflow/providers/common/sql/triggers/sql.py @@ -25,10 +25,8 @@ from airflow.triggers.base import BaseTrigger, TriggerEvent if TYPE_CHECKING: - from typing import ( - Any, - AsyncIterator, - ) + from collections.abc import AsyncIterator + from typing import Any class SQLExecuteQueryTrigger(BaseTrigger): @@ -39,6 +37,7 @@ class SQLExecuteQueryTrigger(BaseTrigger): :param conn_id: the connection ID used to connect to the database :param hook_params: hook parameters """ + def __init__( self, sql: str | list[str], diff --git a/providers/src/airflow/providers/standard/operators/generic_transfer.py b/providers/src/airflow/providers/standard/operators/generic_transfer.py index ea2de13e04c1d..bb77d49c148e6 100644 --- a/providers/src/airflow/providers/standard/operators/generic_transfer.py +++ b/providers/src/airflow/providers/standard/operators/generic_transfer.py @@ -109,9 +109,7 @@ def get_hook(cls, conn_id: str, hook_params: dict | None = None) -> BaseHook: def get_paginated_sql(self, offset: int) -> str: """Format the paginated SQL statement using the current format.""" - return self._paginated_sql_statement_format.format( - self.sql, self.chunk_size, offset - ) + return self._paginated_sql_statement_format.format(self.sql, self.chunk_size, offset) def render_template_fields( self, @@ -191,9 +189,7 @@ def 
execute_complete( if event.get("status") == "failure": raise AirflowException(event.get("message")) - destination_hook = BaseHook.get_hook( - self.destination_conn_id, hook_params=self.destination_hook_params - ) + destination_hook = BaseHook.get_hook(self.destination_conn_id, hook_params=self.destination_hook_params) results = event.get("results") @@ -221,12 +217,8 @@ def execute_complete( self.log.info("Offset increased to %d", offset) self.xcom_push(context=context, key="offset", value=offset) - self.log.info( - "Inserting %d rows into %s", len(results), self.destination_conn_id - ) - insert_rows( - table=self.destination_table, rows=results, **self.insert_args - ) + self.log.info("Inserting %d rows into %s", len(results), self.destination_conn_id) + insert_rows(table=self.destination_table, rows=results, **self.insert_args) self.log.info( "Inserting %d rows into %s done!", len(results), diff --git a/providers/tests/common/sql/triggers/test_sql.py b/providers/tests/common/sql/triggers/test_sql.py index dfa2fd5d5bcc3..ba42c97c57def 100644 --- a/providers/tests/common/sql/triggers/test_sql.py +++ b/providers/tests/common/sql/triggers/test_sql.py @@ -20,6 +20,7 @@ from unittest import mock import pytest + from airflow.providers.common.sql.hooks.sql import DbApiHook from airflow.providers.common.sql.triggers.sql import SQLExecuteQueryTrigger from airflow.triggers.base import TriggerEvent diff --git a/providers/tests/microsoft/azure/base.py b/providers/tests/microsoft/azure/base.py index 7f37b1bb524aa..f22212cc3ea12 100644 --- a/providers/tests/microsoft/azure/base.py +++ b/providers/tests/microsoft/azure/base.py @@ -19,9 +19,10 @@ from contextlib import contextmanager from unittest.mock import patch -from airflow.providers.microsoft.azure.hooks.msgraph import KiotaRequestAdapterHook from kiota_http.httpx_request_adapter import HttpxRequestAdapter +from airflow.providers.microsoft.azure.hooks.msgraph import KiotaRequestAdapterHook + from providers.tests.microsoft.conftest import get_airflow_connection diff --git a/providers/tests/microsoft/conftest.py b/providers/tests/microsoft/conftest.py index c68335f3de5e8..e0ec3172cfa2a 100644 --- a/providers/tests/microsoft/conftest.py +++ b/providers/tests/microsoft/conftest.py @@ -24,17 +24,15 @@ from inspect import currentframe from json import JSONDecodeError from os.path import dirname, join -from typing import TYPE_CHECKING, Any, TypeVar +from typing import Any, TypeVar from unittest.mock import MagicMock import pytest -from airflow.models import Connection -from airflow.providers.microsoft.azure.hooks.powerbi import PowerBIHook from httpx import Headers, Response from msgraph_core import APIVersion -if TYPE_CHECKING: - pass +from airflow.models import Connection +from airflow.providers.microsoft.azure.hooks.powerbi import PowerBIHook T = TypeVar("T", dict, str, Connection) diff --git a/tests_common/test_utils/mock_context.py b/tests_common/test_utils/mock_context.py index de88fdb74f0cb..a52b9a04113fb 100644 --- a/tests_common/test_utils/mock_context.py +++ b/tests_common/test_utils/mock_context.py @@ -23,11 +23,12 @@ def mock_context(task) -> Context: + from sqlalchemy.orm import Session + from airflow.models import TaskInstance from airflow.utils.session import NEW_SESSION from airflow.utils.state import TaskInstanceState from airflow.utils.xcom import XCOM_RETURN_KEY - from sqlalchemy.orm import Session values: dict[str, Any] = {} From 15383e3a831b2efeb6955d4ba650c663646f9944 Mon Sep 17 00:00:00 2001 From: David Blain Date: Tue, 10 Dec 2024 
14:58:37 +0100 Subject: [PATCH 05/20] refactor: Fixed static checks --- .../src/airflow/providers/common/sql/triggers/sql.py | 4 ++-- .../providers/standard/operators/generic_transfer.py | 8 ++------ 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/providers/src/airflow/providers/common/sql/triggers/sql.py b/providers/src/airflow/providers/common/sql/triggers/sql.py index e120759fc86d0..46c5d55cb4441 100644 --- a/providers/src/airflow/providers/common/sql/triggers/sql.py +++ b/providers/src/airflow/providers/common/sql/triggers/sql.py @@ -25,8 +25,8 @@ from airflow.triggers.base import BaseTrigger, TriggerEvent if TYPE_CHECKING: - from collections.abc import AsyncIterator - from typing import Any + from collections.abc import AsyncIterator + from typing import Any class SQLExecuteQueryTrigger(BaseTrigger): diff --git a/providers/src/airflow/providers/standard/operators/generic_transfer.py b/providers/src/airflow/providers/standard/operators/generic_transfer.py index bb77d49c148e6..974bc225eb140 100644 --- a/providers/src/airflow/providers/standard/operators/generic_transfer.py +++ b/providers/src/airflow/providers/standard/operators/generic_transfer.py @@ -125,9 +125,7 @@ def render_template_fields( self.insert_args["commit_every"] = int(self.insert_args.get("commit_every")) def paginated_execute(self, context: Context): - destination_hook = BaseHook.get_hook( - self.destination_conn_id, hook_params=self.destination_hook_params - ) + destination_hook = BaseHook.get_hook(self.destination_conn_id, hook_params=self.destination_hook_params) if self.preoperator: run = getattr(destination_hook, "run", None) @@ -153,9 +151,7 @@ def paginated_execute(self, context: Context): method_name=self.execute_complete.__name__, ) else: - source_hook = BaseHook.get_hook( - self.source_conn_id, hook_params=self.source_hook_params - ) + source_hook = BaseHook.get_hook(self.source_conn_id, hook_params=self.source_hook_params) self.log.info("Extracting data from %s", self.source_conn_id) self.log.info("Executing: \n %s", self.sql) From a93b49a00fbc2ed07acd2ccd99caef12a6d04da3 Mon Sep 17 00:00:00 2001 From: David Blain Date: Tue, 10 Dec 2024 15:05:08 +0100 Subject: [PATCH 06/20] refactor: Fixed static checks --- providers/src/airflow/providers/common/sql/triggers/sql.py | 4 ++-- .../airflow/providers/standard/operators/generic_transfer.py | 5 +++-- providers/tests/microsoft/azure/operators/test_msgraph.py | 2 +- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/providers/src/airflow/providers/common/sql/triggers/sql.py b/providers/src/airflow/providers/common/sql/triggers/sql.py index 46c5d55cb4441..8c154b1a82f08 100644 --- a/providers/src/airflow/providers/common/sql/triggers/sql.py +++ b/providers/src/airflow/providers/common/sql/triggers/sql.py @@ -41,7 +41,7 @@ class SQLExecuteQueryTrigger(BaseTrigger): def __init__( self, sql: str | list[str], - conn_id: str | None = None, + conn_id: str, hook_params: dict | None = None, **kwargs, ): @@ -81,7 +81,7 @@ async def run(self) -> AsyncIterator[TriggerEvent]: if not callable(get_records): raise RuntimeError( - f"Hook for connection {self.source_conn_id!r} " + f"Hook for connection {self.conn_id!r} " f"({type(hook).__name__}) has no `get_records` method" ) else: diff --git a/providers/src/airflow/providers/standard/operators/generic_transfer.py b/providers/src/airflow/providers/standard/operators/generic_transfer.py index 974bc225eb140..2d5ad0c8fa0a6 100644 --- a/providers/src/airflow/providers/standard/operators/generic_transfer.py +++ 
b/providers/src/airflow/providers/standard/operators/generic_transfer.py @@ -121,8 +121,9 @@ def render_template_fields( # Make sure string are converted to integers if isinstance(self.chunk_size, str): self.chunk_size = int(self.chunk_size) - if isinstance(self.insert_args.get("commit_every"), str): - self.insert_args["commit_every"] = int(self.insert_args.get("commit_every")) + commit_every = self.insert_args.get("commit_every") + if isinstance(commit_every, str): + self.insert_args["commit_every"] = int(commit_every) def paginated_execute(self, context: Context): destination_hook = BaseHook.get_hook(self.destination_conn_id, hook_params=self.destination_hook_params) diff --git a/providers/tests/microsoft/azure/operators/test_msgraph.py b/providers/tests/microsoft/azure/operators/test_msgraph.py index 13770cf1d034d..670ae0ee7de79 100644 --- a/providers/tests/microsoft/azure/operators/test_msgraph.py +++ b/providers/tests/microsoft/azure/operators/test_msgraph.py @@ -31,10 +31,10 @@ from providers.tests.microsoft.conftest import ( load_file, load_json, - mock_context, mock_json_response, mock_response, ) +from tests_common.test_utils.mock_context import mock_context from tests_common.test_utils.operators.run_deferable import execute_operator from tests_common.test_utils.version_compat import AIRFLOW_V_2_10_PLUS From 8a1d2ded6d0d096fae02dd55665cbcee5ccf39a7 Mon Sep 17 00:00:00 2001 From: David Blain Date: Tue, 10 Dec 2024 15:55:28 +0100 Subject: [PATCH 07/20] refactor: Reformatted GenericTransfer --- .../providers/standard/operators/generic_transfer.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/providers/src/airflow/providers/standard/operators/generic_transfer.py b/providers/src/airflow/providers/standard/operators/generic_transfer.py index 2d5ad0c8fa0a6..5c41f7a129afa 100644 --- a/providers/src/airflow/providers/standard/operators/generic_transfer.py +++ b/providers/src/airflow/providers/standard/operators/generic_transfer.py @@ -21,6 +21,7 @@ from typing import TYPE_CHECKING, Any import jinja2 + from airflow.exceptions import AirflowException from airflow.hooks.base import BaseHook from airflow.models import BaseOperator @@ -126,7 +127,9 @@ def render_template_fields( self.insert_args["commit_every"] = int(commit_every) def paginated_execute(self, context: Context): - destination_hook = BaseHook.get_hook(self.destination_conn_id, hook_params=self.destination_hook_params) + destination_hook = BaseHook.get_hook( + self.destination_conn_id, hook_params=self.destination_hook_params + ) if self.preoperator: run = getattr(destination_hook, "run", None) @@ -152,7 +155,9 @@ def paginated_execute(self, context: Context): method_name=self.execute_complete.__name__, ) else: - source_hook = BaseHook.get_hook(self.source_conn_id, hook_params=self.source_hook_params) + source_hook = BaseHook.get_hook( + self.source_conn_id, hook_params=self.source_hook_params + ) self.log.info("Extracting data from %s", self.source_conn_id) self.log.info("Executing: \n %s", self.sql) From 757ab68d3a1aaacaf9cd9c42ba89106bfb0b2ee2 Mon Sep 17 00:00:00 2001 From: David Blain Date: Tue, 10 Dec 2024 16:28:27 +0100 Subject: [PATCH 08/20] refactor: Moved source and destination hooks into cached properties --- .../standard/operators/generic_transfer.py | 35 +++++++++---------- 1 file changed, 17 insertions(+), 18 deletions(-) diff --git a/providers/src/airflow/providers/standard/operators/generic_transfer.py b/providers/src/airflow/providers/standard/operators/generic_transfer.py index 
5c41f7a129afa..a59ccda2df760 100644 --- a/providers/src/airflow/providers/standard/operators/generic_transfer.py +++ b/providers/src/airflow/providers/standard/operators/generic_transfer.py @@ -18,6 +18,7 @@ from __future__ import annotations from collections.abc import Sequence +from functools import cached_property from typing import TYPE_CHECKING, Any import jinja2 @@ -126,17 +127,21 @@ def render_template_fields( if isinstance(commit_every, str): self.insert_args["commit_every"] = int(commit_every) - def paginated_execute(self, context: Context): - destination_hook = BaseHook.get_hook( - self.destination_conn_id, hook_params=self.destination_hook_params - ) + @cached_property + def source_hook(self) -> BaseHook: + return BaseHook.get_hook(self.source_conn_id, hook_params=self.source_hook_params) + @cached_property + def destination_hook(self) -> BaseHook: + return BaseHook.get_hook(self.destination_conn_id, hook_params=self.destination_hook_params) + + def paginated_execute(self, context: Context): if self.preoperator: - run = getattr(destination_hook, "run", None) + run = getattr(self.destination_hook, "run", None) if not callable(run): raise RuntimeError( f"Hook for connection {self.destination_conn_id!r} " - f"({type(destination_hook).__name__}) has no `run` method" + f"({type(self.destination_hook).__name__}) has no `run` method" ) self.log.info("Running preoperator") self.log.info(self.preoperator) @@ -155,28 +160,24 @@ def paginated_execute(self, context: Context): method_name=self.execute_complete.__name__, ) else: - source_hook = BaseHook.get_hook( - self.source_conn_id, hook_params=self.source_hook_params - ) - self.log.info("Extracting data from %s", self.source_conn_id) self.log.info("Executing: \n %s", self.sql) - get_records = getattr(source_hook, "get_records", None) + get_records = getattr(self.source_hook, "get_records", None) if not callable(get_records): raise RuntimeError( f"Hook for connection {self.source_conn_id!r} " - f"({type(source_hook).__name__}) has no `get_records` method" + f"({type(self.source_hook).__name__}) has no `get_records` method" ) results = get_records(self.sql) - insert_rows = getattr(destination_hook, "insert_rows", None) + insert_rows = getattr(self.destination_hook, "insert_rows", None) if not callable(insert_rows): raise RuntimeError( f"Hook for connection {self.destination_conn_id!r} " - f"({type(destination_hook).__name__}) has no `insert_rows` method" + f"({type(self.destination_hook).__name__}) has no `insert_rows` method" ) self.log.info("Inserting rows into %s", self.destination_conn_id) @@ -191,17 +192,15 @@ def execute_complete( if event.get("status") == "failure": raise AirflowException(event.get("message")) - destination_hook = BaseHook.get_hook(self.destination_conn_id, hook_params=self.destination_hook_params) - results = event.get("results") if results: - insert_rows = getattr(destination_hook, "insert_rows", None) + insert_rows = getattr(self.destination_hook, "insert_rows", None) if not callable(insert_rows): raise RuntimeError( f"Hook for connection {self.destination_conn_id!r} " - f"({type(destination_hook).__name__}) has no `insert_rows` method" + f"({type(self.destination_hook).__name__}) has no `insert_rows` method" ) map_index = context["ti"].map_index From 0744383a61150aaa3685975e6203a2549399cc0f Mon Sep 17 00:00:00 2001 From: David Blain Date: Wed, 11 Dec 2024 08:10:15 +0100 Subject: [PATCH 09/20] refactor: Moved imports to type checking block --- .../providers/standard/operators/generic_transfer.py | 4 ++-- 
tests_common/test_utils/mock_context.py | 7 +++++-- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/providers/src/airflow/providers/standard/operators/generic_transfer.py b/providers/src/airflow/providers/standard/operators/generic_transfer.py index a59ccda2df760..f93d793a7b71c 100644 --- a/providers/src/airflow/providers/standard/operators/generic_transfer.py +++ b/providers/src/airflow/providers/standard/operators/generic_transfer.py @@ -21,14 +21,14 @@ from functools import cached_property from typing import TYPE_CHECKING, Any -import jinja2 - from airflow.exceptions import AirflowException from airflow.hooks.base import BaseHook from airflow.models import BaseOperator from airflow.providers.common.sql.triggers.sql import SQLExecuteQueryTrigger if TYPE_CHECKING: + import jinja2 + from airflow.utils.context import Context diff --git a/tests_common/test_utils/mock_context.py b/tests_common/test_utils/mock_context.py index a52b9a04113fb..c8a259e684d43 100644 --- a/tests_common/test_utils/mock_context.py +++ b/tests_common/test_utils/mock_context.py @@ -17,13 +17,16 @@ from __future__ import annotations from collections.abc import Iterable -from typing import Any +from typing import TYPE_CHECKING, Any from airflow.utils.context import Context +if TYPE_CHECKING: + from sqlalchemy.orm import Session + def mock_context(task) -> Context: - from sqlalchemy.orm import Session + from airflow.models import TaskInstance from airflow.utils.session import NEW_SESSION From edff5f443f75fc15ab7599180dc5e51ae878325a Mon Sep 17 00:00:00 2001 From: David Blain Date: Wed, 11 Dec 2024 08:29:48 +0100 Subject: [PATCH 10/20] refactor: Fixed execute method of GenericTransfer --- .../airflow/providers/standard/operators/generic_transfer.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/providers/src/airflow/providers/standard/operators/generic_transfer.py b/providers/src/airflow/providers/standard/operators/generic_transfer.py index f93d793a7b71c..8550bd52f4d6d 100644 --- a/providers/src/airflow/providers/standard/operators/generic_transfer.py +++ b/providers/src/airflow/providers/standard/operators/generic_transfer.py @@ -135,7 +135,7 @@ def source_hook(self) -> BaseHook: def destination_hook(self) -> BaseHook: return BaseHook.get_hook(self.destination_conn_id, hook_params=self.destination_hook_params) - def paginated_execute(self, context: Context): + def execute(self, context: Context): if self.preoperator: run = getattr(self.destination_hook, "run", None) if not callable(run): @@ -147,9 +147,6 @@ def paginated_execute(self, context: Context): self.log.info(self.preoperator) run(self.preoperator) - if isinstance(self.insert_args.get("commit_every"), str): - self.insert_args["commit_every"] = int(self.insert_args.get("commit_every")) - if self.chunk_size and isinstance(self.sql, str): self.defer( trigger=SQLExecuteQueryTrigger( From f3b2893d73b3cadfa5f864213f82abce134ce969 Mon Sep 17 00:00:00 2001 From: David Blain Date: Wed, 11 Dec 2024 08:42:16 +0100 Subject: [PATCH 11/20] refactor: Refactored get_hook method of GenericTransfer which checks if hook is instance of DbApiHook instead of checking presence of get_records and insert_rows method --- .../standard/operators/generic_transfer.py | 69 +++++++------------ 1 file changed, 24 insertions(+), 45 deletions(-) diff --git a/providers/src/airflow/providers/standard/operators/generic_transfer.py b/providers/src/airflow/providers/standard/operators/generic_transfer.py index 8550bd52f4d6d..6f53ea3482193 100644 --- 
a/providers/src/airflow/providers/standard/operators/generic_transfer.py +++ b/providers/src/airflow/providers/standard/operators/generic_transfer.py @@ -24,6 +24,7 @@ from airflow.exceptions import AirflowException from airflow.hooks.base import BaseHook from airflow.models import BaseOperator +from airflow.providers.common.sql.hooks.sql import DbApiHook from airflow.providers.common.sql.triggers.sql import SQLExecuteQueryTrigger if TYPE_CHECKING: @@ -98,16 +99,29 @@ def __init__( ) @classmethod - def get_hook(cls, conn_id: str, hook_params: dict | None = None) -> BaseHook: + def get_hook(cls, conn_id: str, hook_params: dict | None = None) -> DbApiHook: """ - Return default hook for this connection id. + Return DbApiHook for this connection id. :param conn_id: connection id :param hook_params: hook parameters - :return: default hook for this connection + :return: DbApiHook for this connection """ connection = BaseHook.get_connection(conn_id) - return connection.get_hook(hook_params=hook_params) + hook = connection.get_hook(hook_params=hook_params) + if not isinstance(hook, DbApiHook): + raise RuntimeError( + f"Hook for connection {conn_id!r} must be of type {DbApiHook.__name__}" + ) + return hook + + @cached_property + def source_hook(self) -> DbApiHook: + return self.get_hook(conn_id=self.source_conn_id, hook_params=self.source_hook_params) + + @cached_property + def destination_hook(self) -> DbApiHook: + return self.get_hook(conn_id=self.destination_conn_id, hook_params=self.destination_hook_params) def get_paginated_sql(self, offset: int) -> str: """Format the paginated SQL statement using the current format.""" @@ -127,25 +141,11 @@ def render_template_fields( if isinstance(commit_every, str): self.insert_args["commit_every"] = int(commit_every) - @cached_property - def source_hook(self) -> BaseHook: - return BaseHook.get_hook(self.source_conn_id, hook_params=self.source_hook_params) - - @cached_property - def destination_hook(self) -> BaseHook: - return BaseHook.get_hook(self.destination_conn_id, hook_params=self.destination_hook_params) - def execute(self, context: Context): if self.preoperator: - run = getattr(self.destination_hook, "run", None) - if not callable(run): - raise RuntimeError( - f"Hook for connection {self.destination_conn_id!r} " - f"({type(self.destination_hook).__name__}) has no `run` method" - ) self.log.info("Running preoperator") self.log.info(self.preoperator) - run(self.preoperator) + self.destination_hook.run(self.preoperator) if self.chunk_size and isinstance(self.sql, str): self.defer( @@ -160,25 +160,10 @@ def execute(self, context: Context): self.log.info("Extracting data from %s", self.source_conn_id) self.log.info("Executing: \n %s", self.sql) - get_records = getattr(self.source_hook, "get_records", None) - - if not callable(get_records): - raise RuntimeError( - f"Hook for connection {self.source_conn_id!r} " - f"({type(self.source_hook).__name__}) has no `get_records` method" - ) - - results = get_records(self.sql) - insert_rows = getattr(self.destination_hook, "insert_rows", None) - - if not callable(insert_rows): - raise RuntimeError( - f"Hook for connection {self.destination_conn_id!r} " - f"({type(self.destination_hook).__name__}) has no `insert_rows` method" - ) + results = self.destination_hook.get_records(self.sql) self.log.info("Inserting rows into %s", self.destination_conn_id) - insert_rows(table=self.destination_table, rows=results, **self.insert_args) + self.destination_hook.insert_rows(table=self.destination_table, rows=results, 
**self.insert_args) def execute_complete( self, @@ -192,14 +177,6 @@ def execute_complete( results = event.get("results") if results: - insert_rows = getattr(self.destination_hook, "insert_rows", None) - - if not callable(insert_rows): - raise RuntimeError( - f"Hook for connection {self.destination_conn_id!r} " - f"({type(self.destination_hook).__name__}) has no `insert_rows` method" - ) - map_index = context["ti"].map_index offset = ( context["ti"].xcom_pull( @@ -216,7 +193,9 @@ def execute_complete( self.xcom_push(context=context, key="offset", value=offset) self.log.info("Inserting %d rows into %s", len(results), self.destination_conn_id) - insert_rows(table=self.destination_table, rows=results, **self.insert_args) + self.destination_hook.insert_rows( + table=self.destination_table, rows=results, **self.insert_args + ) self.log.info( "Inserting %d rows into %s done!", len(results), From ac4df0283bda8974a64176c787a9693d45f5fd3a Mon Sep 17 00:00:00 2001 From: David Blain Date: Wed, 11 Dec 2024 09:05:37 +0100 Subject: [PATCH 12/20] refactor: Removed blank lines from mock_context --- tests_common/test_utils/mock_context.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests_common/test_utils/mock_context.py b/tests_common/test_utils/mock_context.py index c8a259e684d43..5a17e058f5f11 100644 --- a/tests_common/test_utils/mock_context.py +++ b/tests_common/test_utils/mock_context.py @@ -26,8 +26,6 @@ def mock_context(task) -> Context: - - from airflow.models import TaskInstance from airflow.utils.session import NEW_SESSION from airflow.utils.state import TaskInstanceState From 0a771d85ec0973b34338338c0183817ec1a1aafc Mon Sep 17 00:00:00 2001 From: David Blain Date: Wed, 11 Dec 2024 09:06:16 +0100 Subject: [PATCH 13/20] refactor: Reformatted get_hook in GenericTransfer operator --- .../airflow/providers/standard/operators/generic_transfer.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/providers/src/airflow/providers/standard/operators/generic_transfer.py b/providers/src/airflow/providers/standard/operators/generic_transfer.py index 6f53ea3482193..1a9b5af9b6264 100644 --- a/providers/src/airflow/providers/standard/operators/generic_transfer.py +++ b/providers/src/airflow/providers/standard/operators/generic_transfer.py @@ -110,9 +110,7 @@ def get_hook(cls, conn_id: str, hook_params: dict | None = None) -> DbApiHook: connection = BaseHook.get_connection(conn_id) hook = connection.get_hook(hook_params=hook_params) if not isinstance(hook, DbApiHook): - raise RuntimeError( - f"Hook for connection {conn_id!r} must be of type {DbApiHook.__name__}" - ) + raise RuntimeError(f"Hook for connection {conn_id!r} must be of type {DbApiHook.__name__}") return hook @cached_property From 0e426dcaf01f1871c9f6071e2ef92395ac72916e Mon Sep 17 00:00:00 2001 From: David Blain Date: Mon, 16 Dec 2024 11:52:14 +0100 Subject: [PATCH 14/20] refactor: Added sql.pyi for SQLExecuteQueryTrigger --- .../providers/common/sql/triggers/sql.pyi | 51 +++++++++++++++++++ 1 file changed, 51 insertions(+) create mode 100644 providers/src/airflow/providers/common/sql/triggers/sql.pyi diff --git a/providers/src/airflow/providers/common/sql/triggers/sql.pyi b/providers/src/airflow/providers/common/sql/triggers/sql.pyi new file mode 100644 index 0000000000000..c333364d74f82 --- /dev/null +++ b/providers/src/airflow/providers/common/sql/triggers/sql.pyi @@ -0,0 +1,51 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements.
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +# This is automatically generated stub for the `common.sql` provider +# +# This file is generated automatically by the `update-common-sql-api stubs` pre-commit +# and the .pyi file represents part of the "public" API that the +# `common.sql` provider exposes to other providers. +# +# Any, potentially breaking change in the stubs will require deliberate manual action from the contributor +# making a change to the `common.sql` provider. Those stubs are also used by MyPy automatically when checking +# if only public API of the common.sql provider is used by all the other providers. +# +# You can read more in the README_API.md file +# +""" +Definition of the public interface for airflow.providers.common.sql.sensors.sql +isort:skip_file +""" +from collections.abc import AsyncIterator +from typing import Any + +from airflow.triggers.base import BaseTrigger, TriggerEvent + + +class SQLExecuteQueryTrigger(BaseTrigger): + def __init__( + self, + sql: str | list[str], + conn_id: str, + hook_params: dict | None = None, + **kwargs, + ) -> None: ... + + def serialize(self) -> tuple[str, dict[str, Any]]: ... + + async def run(self) -> AsyncIterator[TriggerEvent]:... From 4bc7d6f9a435bf9d53d02734f361b757db1d5def Mon Sep 17 00:00:00 2001 From: David Blain Date: Mon, 16 Dec 2024 12:16:56 +0100 Subject: [PATCH 15/20] refactor: Reformatted SQLExecuteQueryTrigger definition --- .../airflow/providers/common/sql/triggers/sql.pyi | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/providers/src/airflow/providers/common/sql/triggers/sql.pyi b/providers/src/airflow/providers/common/sql/triggers/sql.pyi index c333364d74f82..ba820099e5b5c 100644 --- a/providers/src/airflow/providers/common/sql/triggers/sql.pyi +++ b/providers/src/airflow/providers/common/sql/triggers/sql.pyi @@ -28,24 +28,20 @@ # You can read more in the README_API.md file # """ -Definition of the public interface for airflow.providers.common.sql.sensors.sql +Definition of the public interface for airflow.providers.common.sql.triggers.sql isort:skip_file """ +from airflow.triggers.base import BaseTrigger, TriggerEvent + from collections.abc import AsyncIterator from typing import Any -from airflow.triggers.base import BaseTrigger, TriggerEvent - class SQLExecuteQueryTrigger(BaseTrigger): def __init__( - self, - sql: str | list[str], - conn_id: str, - hook_params: dict | None = None, - **kwargs, + self, sql: str | list[str], conn_id: str, hook_params: dict | None = None, **kwargs, ) -> None: ... def serialize(self) -> tuple[str, dict[str, Any]]: ... - async def run(self) -> AsyncIterator[TriggerEvent]:... + async def run(self) -> AsyncIterator[TriggerEvent]: ... 
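For context, the stub above pins down the contract that deferrable operators rely on: defer to SQLExecuteQueryTrigger, then consume the TriggerEvent payload in a completion callback. Below is a minimal sketch of that hand-off, assuming a hypothetical DeferredQueryOperator whose conn_id, sql and hook_params attributes are placeholders invented for illustration; the real wiring lives in GenericTransfer, refactored in the patches that follow.

    from airflow.exceptions import AirflowException
    from airflow.models import BaseOperator
    from airflow.providers.common.sql.triggers.sql import SQLExecuteQueryTrigger

    class DeferredQueryOperator(BaseOperator):
        """Hypothetical operator illustrating the trigger contract declared by the stub."""

        def __init__(self, conn_id: str, sql: str, hook_params: dict | None = None, **kwargs):
            super().__init__(**kwargs)
            self.conn_id = conn_id
            self.sql = sql
            self.hook_params = hook_params

        def execute(self, context):
            # Hand the blocking read over to the triggerer instead of
            # holding a worker slot while the query runs.
            self.defer(
                trigger=SQLExecuteQueryTrigger(
                    conn_id=self.conn_id,
                    sql=self.sql,
                    hook_params=self.hook_params,
                ),
                method_name="execute_complete",
            )

        def execute_complete(self, context, event=None):
            # The trigger yields a single TriggerEvent whose payload is either
            # {"status": "success", "results": [...]} or
            # {"status": "failure", "message": "..."}.
            if not event or event.get("status") == "failure":
                raise AirflowException(event.get("message") if event else "Trigger returned no event")
            return event.get("results")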
From 59eae357d9ec4198a198fcaa5afd279ec889b5fc Mon Sep 17 00:00:00 2001 From: David Blain Date: Mon, 16 Dec 2024 12:49:08 +0100 Subject: [PATCH 16/20] refactor: Added alias in SQLExecuteQueryTrigger definition --- providers/src/airflow/providers/common/sql/triggers/sql.pyi | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/providers/src/airflow/providers/common/sql/triggers/sql.pyi b/providers/src/airflow/providers/common/sql/triggers/sql.pyi index ba820099e5b5c..295d36e5eb1f5 100644 --- a/providers/src/airflow/providers/common/sql/triggers/sql.pyi +++ b/providers/src/airflow/providers/common/sql/triggers/sql.pyi @@ -31,7 +31,7 @@ Definition of the public interface for airflow.providers.common.sql.triggers.sql isort:skip_file """ -from airflow.triggers.base import BaseTrigger, TriggerEvent +from airflow.triggers.base import BaseTrigger, TriggerEvent as TriggerEvent from collections.abc import AsyncIterator from typing import Any From f0925073b314ebc82c43a89686da743525411f45 Mon Sep 17 00:00:00 2001 From: David Blain Date: Tue, 7 Jan 2025 13:06:34 +0100 Subject: [PATCH 17/20] refactor: Added unit test for GenericTransfer using deferred paginated reads --- .../sql}/operators/generic_transfer.py | 16 ++--- .../providers/common/sql/triggers/sql.py | 15 +---- .../sql}/operators/test_generic_transfer.py | 67 ++++++++++++++++++- tests_common/test_utils/compat.py | 2 +- tests_common/test_utils/mock_context.py | 4 +- 5 files changed, 79 insertions(+), 25 deletions(-) rename providers/src/airflow/providers/{standard => common/sql}/operators/generic_transfer.py (95%) rename providers/tests/{standard => common/sql}/operators/test_generic_transfer.py (75%) diff --git a/providers/src/airflow/providers/standard/operators/generic_transfer.py b/providers/src/airflow/providers/common/sql/operators/generic_transfer.py similarity index 95% rename from providers/src/airflow/providers/standard/operators/generic_transfer.py rename to providers/src/airflow/providers/common/sql/operators/generic_transfer.py index 1a9b5af9b6264..74ac34d44ec2f 100644 --- a/providers/src/airflow/providers/standard/operators/generic_transfer.py +++ b/providers/src/airflow/providers/common/sql/operators/generic_transfer.py @@ -52,7 +52,7 @@ class GenericTransfer(BaseOperator): :param preoperator: sql statement or list of statements to be executed prior to loading the data. (templated) :param insert_args: extra params for `insert_rows` method. - :param chunk_size: number of records to be read in paginated mode (optional). + :param page_size: number of records to be read in paginated mode (optional).
""" template_fields: Sequence[str] = ( @@ -81,7 +81,7 @@ def __init__( destination_hook_params: dict | None = None, preoperator: str | list[str] | None = None, insert_args: dict | None = None, - chunk_size: int | None = None, + page_size: int | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -93,7 +93,7 @@ def __init__( self.destination_hook_params = destination_hook_params self.preoperator = preoperator self.insert_args = insert_args or {} - self.chunk_size = chunk_size + self.page_size = page_size self._paginated_sql_statement_format = kwargs.get( "paginated_sql_statement_format", "{} LIMIT {} OFFSET {}" ) @@ -123,7 +123,7 @@ def destination_hook(self) -> DbApiHook: def get_paginated_sql(self, offset: int) -> str: """Format the paginated SQL statement using the current format.""" - return self._paginated_sql_statement_format.format(self.sql, self.chunk_size, offset) + return self._paginated_sql_statement_format.format(self.sql, self.page_size, offset) def render_template_fields( self, @@ -133,8 +133,8 @@ def render_template_fields( super().render_template_fields(context=context, jinja_env=jinja_env) # Make sure string are converted to integers - if isinstance(self.chunk_size, str): - self.chunk_size = int(self.chunk_size) + if isinstance(self.page_size, str): + self.page_size = int(self.page_size) commit_every = self.insert_args.get("commit_every") if isinstance(commit_every, str): self.insert_args["commit_every"] = int(commit_every) @@ -145,7 +145,7 @@ def execute(self, context: Context): self.log.info(self.preoperator) self.destination_hook.run(self.preoperator) - if self.chunk_size and isinstance(self.sql, str): + if self.page_size and isinstance(self.sql, str): self.defer( trigger=SQLExecuteQueryTrigger( conn_id=self.source_conn_id, @@ -184,7 +184,7 @@ def execute_complete( map_indexes=map_index, default=0, ) - + self.chunk_size + + self.page_size ) self.log.info("Offset increased to %d", offset) diff --git a/providers/src/airflow/providers/common/sql/triggers/sql.py b/providers/src/airflow/providers/common/sql/triggers/sql.py index 8c154b1a82f08..ee345722154e5 100644 --- a/providers/src/airflow/providers/common/sql/triggers/sql.py +++ b/providers/src/airflow/providers/common/sql/triggers/sql.py @@ -76,18 +76,9 @@ async def run(self) -> AsyncIterator[TriggerEvent]: self.log.info("Extracting data from %s", self.conn_id) self.log.info("Executing: \n %s", self.sql) - - get_records = getattr(hook, "get_records", None) - - if not callable(get_records): - raise RuntimeError( - f"Hook for connection {self.conn_id!r} " - f"({type(hook).__name__}) has no `get_records` method" - ) - else: - self.log.info("Reading records from %s", self.conn_id) - results = get_records(self.sql) - self.log.info("Reading records from %s done!", self.conn_id) + self.log.info("Reading records from %s", self.conn_id) + results = hook.get_records(self.sql) + self.log.info("Reading records from %s done!", self.conn_id) self.log.debug("results: %s", results) yield TriggerEvent({"status": "success", "results": results}) diff --git a/providers/tests/standard/operators/test_generic_transfer.py b/providers/tests/common/sql/operators/test_generic_transfer.py similarity index 75% rename from providers/tests/standard/operators/test_generic_transfer.py rename to providers/tests/common/sql/operators/test_generic_transfer.py index 4ea08e48891e6..4c26cc8f941e0 100644 --- a/providers/tests/standard/operators/test_generic_transfer.py +++ b/providers/tests/common/sql/operators/test_generic_transfer.py @@ -19,17 
+19,20 @@ import inspect from contextlib import closing -from datetime import datetime +from datetime import datetime, timedelta from unittest import mock +from unittest.mock import MagicMock import pytest - from airflow.exceptions import AirflowProviderDeprecationWarning +from airflow.models.connection import Connection from airflow.models.dag import DAG +from airflow.providers.common.sql.hooks.sql import DbApiHook from airflow.providers.postgres.hooks.postgres import PostgresHook from airflow.utils import timezone from tests_common.test_utils.compat import GenericTransfer +from tests_common.test_utils.operators.run_deferable import execute_operator from tests_common.test_utils.providers import get_provider_min_airflow_version pytestmark = pytest.mark.db_test @@ -38,6 +41,7 @@ DEFAULT_DATE_ISO = DEFAULT_DATE.isoformat() DEFAULT_DATE_DS = DEFAULT_DATE_ISO[:10] TEST_DAG_ID = "unit_test_dag" +counter = 0 @pytest.mark.backend("mysql") @@ -193,6 +197,65 @@ def test_templated_fields(self): assert operator.preoperator == "my_preoperator" assert operator.insert_args == {"commit_every": 5000, "executemany": True, "replace": True} + def test_paginated_read(self): + """ + This unit test is based on the example described in the medium article: + https://medium.com/apache-airflow/transfering-data-from-sap-hana-to-mssql-using-the-airflow-generictransfer-d29f147a9f1f + """ + + def create_get_records_side_effect(): + records = [ + [[1, 2], [11, 12], [3, 4], [13, 14]], + [[3, 4], [13, 14]], + ] + + def side_effect(sql: str): + if records: + return records.pop(0) + return [] + + return side_effect + + get_records_side_effect = create_get_records_side_effect() + + def get_hook(conn_id: str, hook_params: dict | None = None): + mocked_hook = MagicMock(conn_name_attr=conn_id, spec=DbApiHook) + mocked_hook.get_records.side_effect = get_records_side_effect + return mocked_hook + + def get_connection(conn_id: str): + mocked_hook = get_hook(conn_id=conn_id) + mocked_conn = MagicMock(conn_id=conn_id, spec=Connection) + mocked_conn.get_hook.return_value = mocked_hook + return mocked_conn + + with mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=get_connection): + with mock.patch("airflow.hooks.base.BaseHook.get_hook", side_effect=get_hook): + operator = GenericTransfer( + task_id="transfer_table", + source_conn_id="my_source_conn_id", + destination_conn_id="my_destination_conn_id", + sql="SELECT * FROM HR.EMPLOYEES", + destination_table="NEW_HR.EMPLOYEES", + page_size=1000, # Fetch data in chunks of 1000 rows for pagination + insert_args={ + "commit_every": 1000, # Number of rows inserted in each batch + "executemany": True, # Enable batch inserts + "fast_executemany": True, # Boost performance for MSSQL inserts + "replace": True, # Used for upserts/merges if needed + }, + execution_timeout=timedelta(hours=1), + ) + + results, events = execute_operator(operator) + + assert not results + assert len(events) == 3 + assert events[0].payload["results"] == [[1, 2], [11, 12], [3, 4], [13, 14]] + assert events[1].payload["results"] == [[3, 4], [13, 14]] + assert not events[2].payload["results"] + + def test_when_provider_min_airflow_version_is_3_0_or_higher_remove_obsolete_method(self): """ Once this test starts failing due to the fact that the minimum Airflow version is now 3.0.0 or higher diff --git a/tests_common/test_utils/compat.py b/tests_common/test_utils/compat.py index 3bd4b89dfc1c4..ae9e009a57a8a 100644 --- a/tests_common/test_utils/compat.py +++ b/tests_common/test_utils/compat.py @@ -43,8 
+43,8 @@ from airflow.models.baseoperator import BaseOperatorLink try: + from airflow.providers.common.sql.operators.generic_transfer import GenericTransfer from airflow.providers.standard.operators.bash import BashOperator - from airflow.providers.standard.operators.generic_transfer import GenericTransfer from airflow.providers.standard.operators.python import PythonOperator from airflow.providers.standard.sensors.bash import BashSensor from airflow.providers.standard.sensors.date_time import DateTimeSensor diff --git a/tests_common/test_utils/mock_context.py b/tests_common/test_utils/mock_context.py index 5a17e058f5f11..0db6aa99e5777 100644 --- a/tests_common/test_utils/mock_context.py +++ b/tests_common/test_utils/mock_context.py @@ -57,8 +57,8 @@ def xcom_pull( run_id: str | None = None, ) -> Any: if map_indexes: - return values.get(f"{task_ids or self.task_id}_{dag_id or self.dag_id}_{key}_{map_indexes}") - return values.get(f"{task_ids or self.task_id}_{dag_id or self.dag_id}_{key}") + return values.get(f"{task_ids or self.task_id}_{dag_id or self.dag_id}_{key}_{map_indexes}", default) + return values.get(f"{task_ids or self.task_id}_{dag_id or self.dag_id}_{key}", default) def xcom_push(self, key: str, value: Any, session: Session = NEW_SESSION, **kwargs) -> None: values[f"{self.task_id}_{self.dag_id}_{key}_{self.map_index}"] = value From 9e755250c8804aae9dab12c2b126a964a04850e7 Mon Sep 17 00:00:00 2001 From: David Blain Date: Wed, 8 Jan 2025 10:50:59 +0100 Subject: [PATCH 18/20] refactor: Moved the generic_transfer operator definition in provider.yaml from the standard provider to the common-sql provider and removed the standard provider's dependency on the common-sql provider, which is no longer needed --- providers/src/airflow/providers/common/sql/provider.yaml | 1 + providers/src/airflow/providers/standard/provider.yaml | 2 -- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/providers/src/airflow/providers/common/sql/provider.yaml b/providers/src/airflow/providers/common/sql/provider.yaml index 0760ea45732b7..f5cc6d60add66 100644 --- a/providers/src/airflow/providers/common/sql/provider.yaml +++ b/providers/src/airflow/providers/common/sql/provider.yaml @@ -91,6 +91,7 @@ operators: - integration-name: Common SQL python-modules: - airflow.providers.common.sql.operators.sql + - airflow.providers.common.sql.operators.generic_transfer dialects: - dialect-type: default diff --git a/providers/src/airflow/providers/standard/provider.yaml b/providers/src/airflow/providers/standard/provider.yaml index 0915b96cf7520..1afdd1d15b210 100644 --- a/providers/src/airflow/providers/standard/provider.yaml +++ b/providers/src/airflow/providers/standard/provider.yaml @@ -30,7 +30,6 @@ versions: dependencies: - apache-airflow>=2.9.0 - - apache-airflow-providers-common-sql>=1.20.0 integrations: - integration-name: Standard @@ -48,7 +47,6 @@ operators: - airflow.providers.standard.operators.weekday - airflow.providers.standard.operators.bash - airflow.providers.standard.operators.python - - airflow.providers.standard.operators.generic_transfer - airflow.providers.standard.operators.trigger_dagrun - airflow.providers.standard.operators.latest_only From 5c52e4ac266840023854d79b988f70f88f1e180c Mon Sep 17 00:00:00 2001 From: David Blain Date: Wed, 8 Jan 2025 11:12:03 +0100 Subject: [PATCH 19/20] refactor: Fixed typo in the name of the module that allows running deferrable operators within unit tests --- providers/tests/common/sql/operators/test_generic_transfer.py | 2 +-
providers/tests/common/sql/triggers/test_sql.py | 2 +- providers/tests/microsoft/azure/operators/test_msgraph.py | 2 +- providers/tests/microsoft/azure/sensors/test_msgraph.py | 2 +- providers/tests/microsoft/azure/triggers/test_msgraph.py | 2 +- .../operators/{run_deferable.py => run_deferrable.py} | 0 6 files changed, 5 insertions(+), 5 deletions(-) rename tests_common/test_utils/operators/{run_deferable.py => run_deferrable.py} (100%) diff --git a/providers/tests/common/sql/operators/test_generic_transfer.py b/providers/tests/common/sql/operators/test_generic_transfer.py index 4c26cc8f941e0..df43784d0b997 100644 --- a/providers/tests/common/sql/operators/test_generic_transfer.py +++ b/providers/tests/common/sql/operators/test_generic_transfer.py @@ -32,7 +32,7 @@ from airflow.utils import timezone from tests_common.test_utils.compat import GenericTransfer -from tests_common.test_utils.operators.run_deferable import execute_operator +from tests_common.test_utils.operators.run_deferrable import execute_operator from tests_common.test_utils.providers import get_provider_min_airflow_version pytestmark = pytest.mark.db_test diff --git a/providers/tests/common/sql/triggers/test_sql.py b/providers/tests/common/sql/triggers/test_sql.py index ba42c97c57def..75cd8b2c60579 100644 --- a/providers/tests/common/sql/triggers/test_sql.py +++ b/providers/tests/common/sql/triggers/test_sql.py @@ -26,7 +26,7 @@ from airflow.triggers.base import TriggerEvent from providers.tests.microsoft.azure.base import Base -from tests_common.test_utils.operators.run_deferable import run_trigger +from tests_common.test_utils.operators.run_deferrable import run_trigger from tests_common.test_utils.version_compat import AIRFLOW_V_2_9_PLUS pytestmark = pytest.mark.skipif(not AIRFLOW_V_2_9_PLUS, reason="Tests for Airflow 2.8.0+ only") diff --git a/providers/tests/microsoft/azure/operators/test_msgraph.py b/providers/tests/microsoft/azure/operators/test_msgraph.py index 670ae0ee7de79..16d3041cfc90c 100644 --- a/providers/tests/microsoft/azure/operators/test_msgraph.py +++ b/providers/tests/microsoft/azure/operators/test_msgraph.py @@ -35,7 +35,7 @@ mock_response, ) from tests_common.test_utils.mock_context import mock_context -from tests_common.test_utils.operators.run_deferable import execute_operator +from tests_common.test_utils.operators.run_deferrable import execute_operator from tests_common.test_utils.version_compat import AIRFLOW_V_2_10_PLUS if TYPE_CHECKING: diff --git a/providers/tests/microsoft/azure/sensors/test_msgraph.py b/providers/tests/microsoft/azure/sensors/test_msgraph.py index a44be23072373..78574c9d62715 100644 --- a/providers/tests/microsoft/azure/sensors/test_msgraph.py +++ b/providers/tests/microsoft/azure/sensors/test_msgraph.py @@ -25,7 +25,7 @@ from providers.tests.microsoft.azure.base import Base from providers.tests.microsoft.conftest import load_json, mock_json_response -from tests_common.test_utils.operators.run_deferable import execute_operator +from tests_common.test_utils.operators.run_deferrable import execute_operator from tests_common.test_utils.version_compat import AIRFLOW_V_2_10_PLUS diff --git a/providers/tests/microsoft/azure/triggers/test_msgraph.py b/providers/tests/microsoft/azure/triggers/test_msgraph.py index e822f6c1e2cf9..cc4f74ce7ffe1 100644 --- a/providers/tests/microsoft/azure/triggers/test_msgraph.py +++ b/providers/tests/microsoft/azure/triggers/test_msgraph.py @@ -40,7 +40,7 @@ mock_json_response, mock_response, ) -from tests_common.test_utils.operators.run_deferable 
import run_trigger +from tests_common.test_utils.operators.run_deferrable import run_trigger class TestMSGraphTrigger(Base): diff --git a/tests_common/test_utils/operators/run_deferable.py b/tests_common/test_utils/operators/run_deferrable.py similarity index 100% rename from tests_common/test_utils/operators/run_deferable.py rename to tests_common/test_utils/operators/run_deferrable.py From 73f5f1c79b03ff07d7f6faff681cc26c14787b88 Mon Sep 17 00:00:00 2001 From: David Blain Date: Wed, 8 Jan 2025 11:54:16 +0100 Subject: [PATCH 20/20] refactor: Reformatted the xcom_pull method of mock_context --- tests_common/test_utils/mock_context.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests_common/test_utils/mock_context.py b/tests_common/test_utils/mock_context.py index 0db6aa99e5777..3c8a101ce2058 100644 --- a/tests_common/test_utils/mock_context.py +++ b/tests_common/test_utils/mock_context.py @@ -57,7 +57,9 @@ def xcom_pull( run_id: str | None = None, ) -> Any: if map_indexes: - return values.get(f"{task_ids or self.task_id}_{dag_id or self.dag_id}_{key}_{map_indexes}", default) + return values.get( + f"{task_ids or self.task_id}_{dag_id or self.dag_id}_{key}_{map_indexes}", default + ) return values.get(f"{task_ids or self.task_id}_{dag_id or self.dag_id}_{key}", default) def xcom_push(self, key: str, value: Any, session: Session = NEW_SESSION, **kwargs) -> None:
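Taken together, the series leaves GenericTransfer able to run as a deferrable, paginated transfer. The following usage sketch is modelled on the unit test from PATCH 17; connection IDs and table names are placeholders. Setting page_size makes execute() defer to SQLExecuteQueryTrigger, which reads one page per deferral using the default paginated statement format "{} LIMIT {} OFFSET {}", while execute_complete() inserts each page, advances the offset through XCom and, as the unit test exercises, keeps consuming trigger events until an empty page arrives.

    from airflow.providers.common.sql.operators.generic_transfer import GenericTransfer

    transfer_employees = GenericTransfer(
        task_id="transfer_employees",
        source_conn_id="my_source_conn_id",
        destination_conn_id="my_destination_conn_id",
        sql="SELECT * FROM HR.EMPLOYEES",
        destination_table="NEW_HR.EMPLOYEES",
        page_size=1000,  # read pages of 1000 rows in deferred mode
        insert_args={
            "commit_every": 1000,  # rows per insert batch on the destination hook
            "executemany": True,  # enable batch inserts
        },
    )

    # First page executed on the triggerer:
    #   SELECT * FROM HR.EMPLOYEES LIMIT 1000 OFFSET 0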