Add deferred pagination mode to GenericTransfer #44809

Draft · wants to merge 26 commits into main from feature/paginated-generic-transfer

Commits (26)
8f3657f
refactor: Refactored the GenericTransfer operator to support paginate…
davidblain-infrabel Dec 10, 2024
9b36a14
refactor: updated provider dependencies
dabla Dec 10, 2024
eedcd52
refactor: Added TestSQLExecuteQueryTrigger and moved test code which …
davidblain-infrabel Dec 10, 2024
38f654a
refactor: Fixed static checks
davidblain-infrabel Dec 10, 2024
15383e3
refactor: Fixed static checks
davidblain-infrabel Dec 10, 2024
a93b49a
refactor: Fixed static checks
davidblain-infrabel Dec 10, 2024
8a1d2de
refactor: Reformatted GenericTransfer
davidblain-infrabel Dec 10, 2024
757ab68
refactor: Moved source and destination hooks into cached properties
davidblain-infrabel Dec 10, 2024
0744383
refactor: Moved imports to type checking block
davidblain-infrabel Dec 11, 2024
edff5f4
refactor: Fixed execute method of GenericTransfer
davidblain-infrabel Dec 11, 2024
f3b2893
refactor: Refactored get_hook method of GenericTransfer which checks …
davidblain-infrabel Dec 11, 2024
ac4df02
refactor: Remove white lines from mock_context
davidblain-infrabel Dec 11, 2024
0a771d8
refactor: Reformatted get_hook in GenericTransfer operator
davidblain-infrabel Dec 11, 2024
0e426dc
refactor: Added sql.pyi for SQLExecuteQueryTrigger
davidblain-infrabel Dec 16, 2024
4bc7d6f
refactor: Reformatted SQLExecuteQueryTrigger definition
davidblain-infrabel Dec 16, 2024
59eae35
refactor: Added alias in SQLExecuteQueryTrigger definition
davidblain-infrabel Dec 16, 2024
fff10ce
Merge branch 'main' into feature/paginated-generic-transfer
dabla Jan 2, 2025
1662d2f
Merge branch 'main' into feature/paginated-generic-transfer
dabla Jan 2, 2025
735a557
Merge branch 'main' into feature/paginated-generic-transfer
dabla Jan 7, 2025
f092507
refactor: Added unit test for GenericTransfer using deferred pageable…
davidblain-infrabel Jan 7, 2025
3550b84
Merge branch 'main' into feature/paginated-generic-transfer
dabla Jan 7, 2025
44a6965
Merge branch 'main' into feature/paginated-generic-transfer
dabla Jan 8, 2025
9e75525
refactor: Moved generic_transfer operator definition in provider.yaml…
davidblain-infrabel Jan 8, 2025
5c52e4a
refactor: Renamed typo of module which allows you to run deferrable o…
davidblain-infrabel Jan 8, 2025
73f5f1c
refactor: Reformatted xcom_pull method from mock_context
davidblain-infrabel Jan 8, 2025
abadb0b
Merge branch 'main' into feature/paginated-generic-transfer
dabla Jan 9, 2025
4 changes: 3 additions & 1 deletion generated/provider_dependencies.json
@@ -1286,7 +1286,9 @@
],
"devel-deps": [],
"plugins": [],
"cross-providers-deps": [],
"cross-providers-deps": [
"common.sql"
],
"excluded-python-versions": [],
"state": "ready"
},
215 changes: 215 additions & 0 deletions providers/src/airflow/providers/common/sql/operators/generic_transfer.py
@@ -0,0 +1,215 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from collections.abc import Sequence
from functools import cached_property
from typing import TYPE_CHECKING, Any

from airflow.exceptions import AirflowException
from airflow.hooks.base import BaseHook
from airflow.models import BaseOperator
from airflow.providers.common.sql.hooks.sql import DbApiHook
from airflow.providers.common.sql.triggers.sql import SQLExecuteQueryTrigger

if TYPE_CHECKING:
import jinja2

from airflow.utils.context import Context


class GenericTransfer(BaseOperator):
"""
    Move data from one connection to another.

    Both connections are assumed to provide the required methods in their
    respective hooks: the source hook must expose a `get_records` method and
    the destination hook an `insert_rows` method.

This is meant to be used on small-ish datasets that fit in memory.

:param sql: SQL query to execute against the source database. (templated)
:param destination_table: target table. (templated)
:param source_conn_id: source connection. (templated)
:param source_hook_params: source hook parameters.
:param destination_conn_id: destination connection. (templated)
:param destination_hook_params: destination hook parameters.
:param preoperator: sql statement or list of statements to be
executed prior to loading the data. (templated)
:param insert_args: extra params for `insert_rows` method.
    :param page_size: number of records to fetch per page when running in deferred, paginated mode (optional).
"""

template_fields: Sequence[str] = (
"source_conn_id",
"destination_conn_id",
"sql",
"destination_table",
"preoperator",
"insert_args",
)
template_ext: Sequence[str] = (
".sql",
".hql",
)
template_fields_renderers = {"preoperator": "sql"}
ui_color = "#b0f07c"

def __init__(
self,
*,
sql: str,
destination_table: str,
source_conn_id: str,
source_hook_params: dict | None = None,
destination_conn_id: str,
destination_hook_params: dict | None = None,
preoperator: str | list[str] | None = None,
insert_args: dict | None = None,
page_size: int | None = None,
**kwargs,
) -> None:
        # Pop the custom pagination format before BaseOperator.__init__, which rejects unknown kwargs.
        self._paginated_sql_statement_format = kwargs.pop(
            "paginated_sql_statement_format", "{} LIMIT {} OFFSET {}"
        )
        super().__init__(**kwargs)
self.sql = sql
self.destination_table = destination_table
self.source_conn_id = source_conn_id
self.source_hook_params = source_hook_params
self.destination_conn_id = destination_conn_id
self.destination_hook_params = destination_hook_params
self.preoperator = preoperator
self.insert_args = insert_args or {}
self.page_size = page_size

@classmethod
def get_hook(cls, conn_id: str, hook_params: dict | None = None) -> DbApiHook:
"""
Return DbApiHook for this connection id.

:param conn_id: connection id
:param hook_params: hook parameters
:return: DbApiHook for this connection
"""
connection = BaseHook.get_connection(conn_id)
hook = connection.get_hook(hook_params=hook_params)
if not isinstance(hook, DbApiHook):
raise RuntimeError(f"Hook for connection {conn_id!r} must be of type {DbApiHook.__name__}")
return hook

@cached_property
def source_hook(self) -> DbApiHook:
return self.get_hook(conn_id=self.source_conn_id, hook_params=self.source_hook_params)

@cached_property
def destination_hook(self) -> DbApiHook:
return self.get_hook(conn_id=self.destination_conn_id, hook_params=self.destination_hook_params)

def get_paginated_sql(self, offset: int) -> str:
"""Format the paginated SQL statement using the current format."""
return self._paginated_sql_statement_format.format(self.sql, self.page_size, offset)

def render_template_fields(
self,
context: Context,
jinja_env: jinja2.Environment | None = None,
) -> None:
super().render_template_fields(context=context, jinja_env=jinja_env)

        # Make sure strings are converted to integers
if isinstance(self.page_size, str):
self.page_size = int(self.page_size)
commit_every = self.insert_args.get("commit_every")
if isinstance(commit_every, str):
self.insert_args["commit_every"] = int(commit_every)

def execute(self, context: Context):
if self.preoperator:
self.log.info("Running preoperator")
self.log.info(self.preoperator)
self.destination_hook.run(self.preoperator)

if self.page_size and isinstance(self.sql, str):
self.defer(
trigger=SQLExecuteQueryTrigger(
conn_id=self.source_conn_id,
hook_params=self.source_hook_params,
sql=self.get_paginated_sql(0),
),
method_name=self.execute_complete.__name__,
)
else:
self.log.info("Extracting data from %s", self.source_conn_id)
self.log.info("Executing: \n %s", self.sql)

            # Non-paginated mode: read all records from the source in one go.
            results = self.source_hook.get_records(self.sql)

self.log.info("Inserting rows into %s", self.destination_conn_id)
self.destination_hook.insert_rows(table=self.destination_table, rows=results, **self.insert_args)

def execute_complete(
self,
context: Context,
event: dict[Any, Any] | None = None,
) -> Any:
if event:
if event.get("status") == "failure":
raise AirflowException(event.get("message"))

results = event.get("results")

if results:
map_index = context["ti"].map_index
offset = (
context["ti"].xcom_pull(
key="offset",
task_ids=self.task_id,
dag_id=self.dag_id,
map_indexes=map_index,
default=0,
)
+ self.page_size
)

self.log.info("Offset increased to %d", offset)
self.xcom_push(context=context, key="offset", value=offset)

self.log.info("Inserting %d rows into %s", len(results), self.destination_conn_id)
self.destination_hook.insert_rows(
table=self.destination_table, rows=results, **self.insert_args
)
self.log.info(
"Inserting %d rows into %s done!",
len(results),
self.destination_conn_id,
)

self.defer(
trigger=SQLExecuteQueryTrigger(
conn_id=self.source_conn_id,
hook_params=self.source_hook_params,
sql=self.get_paginated_sql(offset),
),
method_name=self.execute_complete.__name__,
)
else:
self.log.info(
"No more rows to fetch into %s; ending transfer.",
self.destination_table,
)
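
For context, here is a minimal usage sketch of the new deferred pagination mode (the DAG, connection IDs, and table names below are hypothetical, not part of this PR). With `page_size` set, `execute` defers to `SQLExecuteQueryTrigger`, and each `execute_complete` round inserts one page and re-defers with the offset advanced by `page_size`:

# A minimal sketch, assuming hypothetical connection IDs ("pg_source",
# "mysql_dest") and table names; only `page_size` differs from classic usage.
from datetime import datetime

from airflow import DAG
from airflow.providers.common.sql.operators.generic_transfer import GenericTransfer

with DAG(
    dag_id="example_paginated_transfer",
    start_date=datetime(2025, 1, 1),
    schedule=None,
) as dag:
    transfer = GenericTransfer(
        task_id="copy_orders",
        sql="SELECT * FROM orders",
        destination_table="orders_copy",
        source_conn_id="pg_source",
        destination_conn_id="mysql_dest",
        insert_args={"commit_every": 1000},
        # Enables deferred pagination: each page is fetched with the default
        # format "{} LIMIT {} OFFSET {}", e.g.
        # "SELECT * FROM orders LIMIT 1000 OFFSET 0" for the first page.
        page_size=1000,
    )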
6 changes: 6 additions & 0 deletions providers/src/airflow/providers/common/sql/provider.yaml
@@ -91,6 +91,7 @@ operators:
- integration-name: Common SQL
python-modules:
- airflow.providers.common.sql.operators.sql
- airflow.providers.common.sql.operators.generic_transfer

dialects:
- dialect-type: default
@@ -102,6 +103,11 @@ hooks:
- airflow.providers.common.sql.hooks.handlers
- airflow.providers.common.sql.hooks.sql

triggers:
- integration-name: Common SQL
python-modules:
- airflow.providers.common.sql.triggers.sql

sensors:
- integration-name: Common SQL
python-modules:
16 changes: 16 additions & 0 deletions providers/src/airflow/providers/common/sql/triggers/__init__.py
@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
87 changes: 87 additions & 0 deletions providers/src/airflow/providers/common/sql/triggers/sql.py
@@ -0,0 +1,87 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from typing import TYPE_CHECKING

from airflow.exceptions import AirflowException
from airflow.hooks.base import BaseHook
from airflow.providers.common.sql.hooks.sql import DbApiHook
from airflow.triggers.base import BaseTrigger, TriggerEvent

if TYPE_CHECKING:
from collections.abc import AsyncIterator
from typing import Any


class SQLExecuteQueryTrigger(BaseTrigger):
"""
A trigger that executes SQL code in async mode.

    :param sql: the SQL statement to be executed (str) or a list of SQL statements to execute
:param conn_id: the connection ID used to connect to the database
:param hook_params: hook parameters
"""

def __init__(
self,
sql: str | list[str],
conn_id: str,
hook_params: dict | None = None,
**kwargs,
):
super().__init__(**kwargs)
self.sql = sql
self.conn_id = conn_id
self.hook_params = hook_params

def serialize(self) -> tuple[str, dict[str, Any]]:
"""Serialize the SQLExecuteQueryTrigger arguments and classpath."""
return (
f"{self.__class__.__module__}.{self.__class__.__name__}",
{
"sql": self.sql,
"conn_id": self.conn_id,
"hook_params": self.hook_params,
},
)

async def run(self) -> AsyncIterator[TriggerEvent]:
try:
hook = BaseHook.get_hook(self.conn_id, hook_params=self.hook_params)

if not isinstance(hook, DbApiHook):
raise AirflowException(
f"You are trying to use `common-sql` with {hook.__class__.__name__},"
" but its provider does not support it. Please upgrade the provider to a version that"
" supports `common-sql`. The hook class should be a subclass of"
" `airflow.providers.common.sql.hooks.sql.DbApiHook`."
f" Got {hook.__class__.__name__} Hook with class hierarchy: {hook.__class__.mro()}"
)

self.log.info("Extracting data from %s", self.conn_id)
self.log.info("Executing: \n %s", self.sql)
self.log.info("Reading records from %s", self.conn_id)
results = hook.get_records(self.sql)
self.log.info("Reading records from %s done!", self.conn_id)

self.log.debug("results: %s", results)
yield TriggerEvent({"status": "success", "results": results})
except Exception as e:
self.log.exception("An error occurred: %s", e)
yield TriggerEvent({"status": "failure", "message": str(e)})
47 changes: 47 additions & 0 deletions providers/src/airflow/providers/common/sql/triggers/sql.pyi
@@ -0,0 +1,47 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
# This is automatically generated stub for the `common.sql` provider
#
# This file is generated automatically by the `update-common-sql-api stubs` pre-commit
# and the .pyi file represents part of the "public" API that the
# `common.sql` provider exposes to other providers.
#
# Any, potentially breaking change in the stubs will require deliberate manual action from the contributor
# making a change to the `common.sql` provider. Those stubs are also used by MyPy automatically when checking
# if only public API of the common.sql provider is used by all the other providers.
#
# You can read more in the README_API.md file
#
"""
Definition of the public interface for airflow.providers.common.sql.triggers.sql
isort:skip_file
"""
from airflow.triggers.base import BaseTrigger, TriggerEvent as TriggerEvent

from collections.abc import AsyncIterator
from typing import Any


class SQLExecuteQueryTrigger(BaseTrigger):
def __init__(
self, sql: str | list[str], conn_id: str, hook_params: dict | None = None, **kwargs,
) -> None: ...

def serialize(self) -> tuple[str, dict[str, Any]]: ...

async def run(self) -> AsyncIterator[TriggerEvent]: ...