diff --git a/docs/apache-airflow-providers-openlineage/guides/user.rst b/docs/apache-airflow-providers-openlineage/guides/user.rst index c4a12a7962960..64613f8fdc936 100644 --- a/docs/apache-airflow-providers-openlineage/guides/user.rst +++ b/docs/apache-airflow-providers-openlineage/guides/user.rst @@ -451,6 +451,30 @@ You can enable this automation by setting ``spark_inject_parent_job_info`` optio AIRFLOW__OPENLINEAGE__SPARK_INJECT_PARENT_JOB_INFO=true +Passing transport information to Spark jobs +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +OpenLineage integration can automatically inject Airflow's transport information into Spark application properties, +for :ref:`supported Operators `. +It allows Spark integration to send events to the same backend as Airflow integration without manual configuration. +See `Scheduling from Airflow `_. + +.. note:: + + If any of the ``spark.openlineage.transport*`` properties are manually specified in the Spark job configuration, the integration will refrain from injecting transport properties to ensure that manually provided values are preserved. + +You can enable this automation by setting ``spark_inject_transport_info`` option to ``true`` in Airflow configuration. + +.. code-block:: ini + [openlineage] + transport = {"type": "http", "url": "http://example.com:5000", "endpoint": "api/v1/lineage"} + spark_inject_transport_info = true + +``AIRFLOW__OPENLINEAGE__SPARK_INJECT_TRANSPORT_INFO`` environment variable is an equivalent. + +.. code-block:: ini + AIRFLOW__OPENLINEAGE__SPARK_INJECT_TRANSPORT_INFO=true + Troubleshooting =============== diff --git a/docs/exts/templates/openlineage.rst.jinja2 b/docs/exts/templates/openlineage.rst.jinja2 index af5798d5d51a9..7be5eb560cf21 100644 --- a/docs/exts/templates/openlineage.rst.jinja2 +++ b/docs/exts/templates/openlineage.rst.jinja2 @@ -34,12 +34,15 @@ See :ref:`automatic injection of parent job information =23.0.0", - "apache-airflow-providers-common-compat>=1.3.0", + "apache-airflow-providers-common-compat>=1.4.0", "apache-airflow-providers-common-sql>=1.20.0", "apache-airflow>=2.9.0", "asgiref>=3.5.2", @@ -974,7 +974,7 @@ }, "openlineage": { "deps": [ - "apache-airflow-providers-common-compat>=1.3.0", + "apache-airflow-providers-common-compat>=1.4.0", "apache-airflow-providers-common-sql>=1.20.0", "apache-airflow>=2.9.0", "attrs>=22.2", diff --git a/providers/src/airflow/providers/common/compat/__init__.py b/providers/src/airflow/providers/common/compat/__init__.py index 21133a52bb083..bee2112ac7343 100644 --- a/providers/src/airflow/providers/common/compat/__init__.py +++ b/providers/src/airflow/providers/common/compat/__init__.py @@ -29,7 +29,7 @@ __all__ = ["__version__"] -__version__ = "1.3.0" +__version__ = "1.4.0" if packaging.version.parse(packaging.version.parse(airflow_version).base_version) < packaging.version.parse( "2.9.0" diff --git a/providers/src/airflow/providers/common/compat/openlineage/utils/spark.py b/providers/src/airflow/providers/common/compat/openlineage/utils/spark.py index 48fef1f1e0124..cbc7997dd44cb 100644 --- a/providers/src/airflow/providers/common/compat/openlineage/utils/spark.py +++ b/providers/src/airflow/providers/common/compat/openlineage/utils/spark.py @@ -23,11 +23,15 @@ log = logging.getLogger(__name__) if TYPE_CHECKING: - from airflow.providers.openlineage.utils.spark import inject_parent_job_information_into_spark_properties + from airflow.providers.openlineage.utils.spark import ( + inject_parent_job_information_into_spark_properties, + 
inject_transport_information_into_spark_properties, + ) else: try: from airflow.providers.openlineage.utils.spark import ( inject_parent_job_information_into_spark_properties, + inject_transport_information_into_spark_properties, ) except ImportError: try: @@ -64,5 +68,63 @@ def inject_parent_job_information_into_spark_properties(properties: dict, contex } return {**properties, **ol_parent_job_properties} + try: + from airflow.providers.openlineage.plugins.listener import get_openlineage_listener + except ImportError: + + def inject_transport_information_into_spark_properties(properties: dict, context) -> dict: + log.warning( + "Could not import `airflow.providers.openlineage.plugins.listener`." + "Skipping the injection of OpenLineage transport information into Spark properties." + ) + return properties + + else: + + def inject_transport_information_into_spark_properties(properties: dict, context) -> dict: + if any(str(key).startswith("spark.openlineage.transport") for key in properties): + log.info( + "Some OpenLineage properties with transport information are already present " + "in Spark properties. Skipping the injection of OpenLineage " + "transport information into Spark properties." + ) + return properties + + transport = get_openlineage_listener().adapter.get_or_create_openlineage_client().transport + if transport.kind != "http": + log.info( + "OpenLineage transport type `%s` does not support automatic " + "injection of OpenLineage transport information into Spark properties.", + transport.kind, + ) + return {} + + transport_properties = { + "spark.openlineage.transport.type": "http", + "spark.openlineage.transport.url": transport.url, + "spark.openlineage.transport.endpoint": transport.endpoint, + # Timeout is converted to milliseconds, as required by Spark integration, + "spark.openlineage.transport.timeoutInMillis": str(int(transport.timeout * 1000)), + } + if transport.compression: + transport_properties["spark.openlineage.transport.compression"] = str( + transport.compression + ) + + if hasattr(transport.config.auth, "api_key") and transport.config.auth.get_bearer(): + transport_properties["spark.openlineage.transport.auth.type"] = "api_key" + transport_properties["spark.openlineage.transport.auth.apiKey"] = ( + transport.config.auth.get_bearer() + ) + + if hasattr(transport.config, "custom_headers") and transport.config.custom_headers: + for key, value in transport.config.custom_headers.items(): + transport_properties[f"spark.openlineage.transport.headers.{key}"] = value + + return {**properties, **transport_properties} + -__all__ = ["inject_parent_job_information_into_spark_properties"] +__all__ = [ + "inject_parent_job_information_into_spark_properties", + "inject_transport_information_into_spark_properties", +] diff --git a/providers/src/airflow/providers/common/compat/provider.yaml b/providers/src/airflow/providers/common/compat/provider.yaml index 34be19b27b665..2a2f96af1fef0 100644 --- a/providers/src/airflow/providers/common/compat/provider.yaml +++ b/providers/src/airflow/providers/common/compat/provider.yaml @@ -25,6 +25,7 @@ state: ready source-date-epoch: 1731569875 # note that those versions are maintained by release manager - do not update them manually versions: + - 1.4.0 - 1.3.0 - 1.2.2 - 1.2.1 diff --git a/providers/src/airflow/providers/google/cloud/openlineage/utils.py b/providers/src/airflow/providers/google/cloud/openlineage/utils.py index 2e00ca327b65b..3034365c67333 100644 --- a/providers/src/airflow/providers/google/cloud/openlineage/utils.py +++ 
b/providers/src/airflow/providers/google/cloud/openlineage/utils.py @@ -43,6 +43,7 @@ ) from airflow.providers.common.compat.openlineage.utils.spark import ( inject_parent_job_information_into_spark_properties, + inject_transport_information_into_spark_properties, ) from airflow.providers.google import __version__ as provider_version from airflow.providers.google.cloud.hooks.gcs import _parse_gcs_url @@ -453,7 +454,7 @@ def _replace_dataproc_job_properties(job: dict, job_type: str, new_properties: d def inject_openlineage_properties_into_dataproc_job( - job: dict, context: Context, inject_parent_job_info: bool + job: dict, context: Context, inject_parent_job_info: bool, inject_transport_info: bool ) -> dict: """ Inject OpenLineage properties into Spark job definition. @@ -466,18 +467,19 @@ def inject_openlineage_properties_into_dataproc_job( - OpenLineage provider is not accessible. - The job type is not supported. - Automatic parent job information injection is disabled. - - Any OpenLineage properties with parent job information are already present + - Any OpenLineage properties with respective information are already present in the Spark job definition. Args: job: The original Dataproc job definition. context: The Airflow context in which the job is running. inject_parent_job_info: Flag indicating whether to inject parent job information. + inject_transport_info: Flag indicating whether to inject transport information. Returns: The modified job definition with OpenLineage properties injected, if applicable. """ - if not inject_parent_job_info: + if not inject_parent_job_info and not inject_transport_info: log.debug("Automatic injection of OpenLineage information is disabled.") return job @@ -497,7 +499,17 @@ def inject_openlineage_properties_into_dataproc_job( properties = job[job_type].get("properties", {}) - properties = inject_parent_job_information_into_spark_properties(properties=properties, context=context) + if inject_parent_job_info: + log.debug("Injecting OpenLineage parent job information into Spark properties.") + properties = inject_parent_job_information_into_spark_properties( + properties=properties, context=context + ) + + if inject_transport_info: + log.debug("Injecting OpenLineage transport information into Spark properties.") + properties = inject_transport_information_into_spark_properties( + properties=properties, context=context + ) job_with_ol_config = _replace_dataproc_job_properties( job=job, job_type=job_type, new_properties=properties @@ -587,7 +599,7 @@ def _replace_dataproc_batch_properties(batch: dict | Batch, new_properties: dict def inject_openlineage_properties_into_dataproc_batch( - batch: dict | Batch, context: Context, inject_parent_job_info: bool + batch: dict | Batch, context: Context, inject_parent_job_info: bool, inject_transport_info: bool ) -> dict | Batch: """ Inject OpenLineage properties into Dataproc batch definition. @@ -600,18 +612,19 @@ def inject_openlineage_properties_into_dataproc_batch( - OpenLineage provider is not accessible. - The batch type is not supported. - Automatic parent job information injection is disabled. - - Any OpenLineage properties with parent job information are already present + - Any OpenLineage properties with respective information are already present in the Spark job configuration. Args: batch: The original Dataproc batch definition. context: The Airflow context in which the job is running. inject_parent_job_info: Flag indicating whether to inject parent job information. 
+ inject_transport_info: Flag indicating whether to inject transport information. Returns: The modified batch definition with OpenLineage properties injected, if applicable. """ - if not inject_parent_job_info: + if not inject_parent_job_info and not inject_transport_info: log.debug("Automatic injection of OpenLineage information is disabled.") return batch @@ -631,14 +644,24 @@ def inject_openlineage_properties_into_dataproc_batch( properties = _extract_dataproc_batch_properties(batch) - properties = inject_parent_job_information_into_spark_properties(properties=properties, context=context) + if inject_parent_job_info: + log.debug("Injecting OpenLineage parent job information into Spark properties.") + properties = inject_parent_job_information_into_spark_properties( + properties=properties, context=context + ) + + if inject_transport_info: + log.debug("Injecting OpenLineage transport information into Spark properties.") + properties = inject_transport_information_into_spark_properties( + properties=properties, context=context + ) batch_with_ol_config = _replace_dataproc_batch_properties(batch=batch, new_properties=properties) return batch_with_ol_config def inject_openlineage_properties_into_dataproc_workflow_template( - template: dict, context: Context, inject_parent_job_info: bool + template: dict, context: Context, inject_parent_job_info: bool, inject_transport_info: bool ) -> dict: """ Inject OpenLineage properties into Spark jobs in Workflow Template. @@ -658,11 +681,12 @@ def inject_openlineage_properties_into_dataproc_workflow_template( template: The original Dataproc Workflow Template definition. context: The Airflow context in which the job is running. inject_parent_job_info: Flag indicating whether to inject parent job information. + inject_transport_info: Flag indicating whether to inject transport information. Returns: The modified Workflow Template definition with OpenLineage properties injected, if applicable. 
""" - if not inject_parent_job_info: + if not inject_parent_job_info and not inject_transport_info: log.debug("Automatic injection of OpenLineage information is disabled.") return template @@ -688,9 +712,17 @@ def inject_openlineage_properties_into_dataproc_workflow_template( properties = single_job_definition[job_type].get("properties", {}) - properties = inject_parent_job_information_into_spark_properties( - properties=properties, context=context - ) + if inject_parent_job_info: + log.debug("Injecting OpenLineage parent job information into Spark properties.") + properties = inject_parent_job_information_into_spark_properties( + properties=properties, context=context + ) + + if inject_transport_info: + log.debug("Injecting OpenLineage transport information into Spark properties.") + properties = inject_transport_information_into_spark_properties( + properties=properties, context=context + ) job_with_ol_config = _replace_dataproc_job_properties( job=single_job_definition, job_type=job_type, new_properties=properties diff --git a/providers/src/airflow/providers/google/cloud/operators/dataproc.py b/providers/src/airflow/providers/google/cloud/operators/dataproc.py index 1d5ced10283c9..71c50abe5bba3 100644 --- a/providers/src/airflow/providers/google/cloud/operators/dataproc.py +++ b/providers/src/airflow/providers/google/cloud/operators/dataproc.py @@ -1829,6 +1829,9 @@ def __init__( openlineage_inject_parent_job_info: bool = conf.getboolean( "openlineage", "spark_inject_parent_job_info", fallback=False ), + openlineage_inject_transport_info: bool = conf.getboolean( + "openlineage", "spark_inject_transport_info", fallback=False + ), **kwargs, ) -> None: super().__init__(**kwargs) @@ -1849,17 +1852,19 @@ def __init__( self.cancel_on_kill = cancel_on_kill self.operation_name: str | None = None self.openlineage_inject_parent_job_info = openlineage_inject_parent_job_info + self.openlineage_inject_transport_info = openlineage_inject_transport_info def execute(self, context: Context): self.log.info("Instantiating Inline Template") hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain) project_id = self.project_id or hook.project_id - if self.openlineage_inject_parent_job_info: + if self.openlineage_inject_parent_job_info or self.openlineage_inject_transport_info: self.log.info("Automatic injection of OpenLineage information into Spark properties is enabled.") self.template = inject_openlineage_properties_into_dataproc_workflow_template( template=self.template, context=context, inject_parent_job_info=self.openlineage_inject_parent_job_info, + inject_transport_info=self.openlineage_inject_transport_info, ) operation = hook.instantiate_inline_workflow_template( @@ -1982,6 +1987,9 @@ def __init__( openlineage_inject_parent_job_info: bool = conf.getboolean( "openlineage", "spark_inject_parent_job_info", fallback=False ), + openlineage_inject_transport_info: bool = conf.getboolean( + "openlineage", "spark_inject_transport_info", fallback=False + ), **kwargs, ) -> None: super().__init__(**kwargs) @@ -2004,14 +2012,18 @@ def __init__( self.job_id: str | None = None self.wait_timeout = wait_timeout self.openlineage_inject_parent_job_info = openlineage_inject_parent_job_info + self.openlineage_inject_transport_info = openlineage_inject_transport_info def execute(self, context: Context): self.log.info("Submitting job") self.hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain) - if self.openlineage_inject_parent_job_info: + 
if self.openlineage_inject_parent_job_info or self.openlineage_inject_transport_info: self.log.info("Automatic injection of OpenLineage information into Spark properties is enabled.") self.job = inject_openlineage_properties_into_dataproc_job( - job=self.job, context=context, inject_parent_job_info=self.openlineage_inject_parent_job_info + job=self.job, + context=context, + inject_parent_job_info=self.openlineage_inject_parent_job_info, + inject_transport_info=self.openlineage_inject_transport_info, ) job_object = self.hook.submit_job( project_id=self.project_id, @@ -2442,6 +2454,9 @@ def __init__( openlineage_inject_parent_job_info: bool = conf.getboolean( "openlineage", "spark_inject_parent_job_info", fallback=False ), + openlineage_inject_transport_info: bool = conf.getboolean( + "openlineage", "spark_inject_transport_info", fallback=False + ), **kwargs, ): super().__init__(**kwargs) @@ -2464,6 +2479,7 @@ def __init__( self.deferrable = deferrable self.polling_interval_seconds = polling_interval_seconds self.openlineage_inject_parent_job_info = openlineage_inject_parent_job_info + self.openlineage_inject_transport_info = openlineage_inject_transport_info def execute(self, context: Context): if self.asynchronous and self.deferrable: @@ -2486,12 +2502,13 @@ def execute(self, context: Context): else: self.log.info("Starting batch. The batch ID will be generated since it was not provided.") - if self.openlineage_inject_parent_job_info: + if self.openlineage_inject_parent_job_info or self.openlineage_inject_transport_info: self.log.info("Automatic injection of OpenLineage information into Spark properties is enabled.") self.batch = inject_openlineage_properties_into_dataproc_batch( batch=self.batch, context=context, inject_parent_job_info=self.openlineage_inject_parent_job_info, + inject_transport_info=self.openlineage_inject_transport_info, ) try: diff --git a/providers/src/airflow/providers/google/provider.yaml b/providers/src/airflow/providers/google/provider.yaml index b253967472d53..c67c8432f4cec 100644 --- a/providers/src/airflow/providers/google/provider.yaml +++ b/providers/src/airflow/providers/google/provider.yaml @@ -101,7 +101,7 @@ versions: dependencies: - apache-airflow>=2.9.0 - - apache-airflow-providers-common-compat>=1.3.0 + - apache-airflow-providers-common-compat>=1.4.0 - apache-airflow-providers-common-sql>=1.20.0 - asgiref>=3.5.2 - dill>=0.2.3 diff --git a/providers/src/airflow/providers/openlineage/conf.py b/providers/src/airflow/providers/openlineage/conf.py index 53a4be746f697..c7300d298e245 100644 --- a/providers/src/airflow/providers/openlineage/conf.py +++ b/providers/src/airflow/providers/openlineage/conf.py @@ -83,6 +83,12 @@ def spark_inject_parent_job_info() -> bool: return conf.getboolean(_CONFIG_SECTION, "spark_inject_parent_job_info", fallback="False") +@cache +def spark_inject_transport_info() -> bool: + """[openlineage] spark_inject_transport_info.""" + return conf.getboolean(_CONFIG_SECTION, "spark_inject_transport_info", fallback="False") + + @cache def custom_extractors() -> set[str]: """[openlineage] extractors.""" diff --git a/providers/src/airflow/providers/openlineage/provider.yaml b/providers/src/airflow/providers/openlineage/provider.yaml index 71115b099d47f..60b71fe4651cf 100644 --- a/providers/src/airflow/providers/openlineage/provider.yaml +++ b/providers/src/airflow/providers/openlineage/provider.yaml @@ -54,7 +54,7 @@ versions: dependencies: - apache-airflow>=2.9.0 - apache-airflow-providers-common-sql>=1.20.0 - - 
apache-airflow-providers-common-compat>=1.3.0 + - apache-airflow-providers-common-compat>=1.4.0 - attrs>=22.2 - openlineage-integration-common>=1.24.2 - openlineage-python>=1.24.2 @@ -192,4 +192,12 @@ config: type: boolean default: "False" example: ~ - version_added: 1.15.0 + version_added: 2.0.0 + spark_inject_transport_info: + description: | + Automatically inject OpenLineage's transport information into Spark application properties + for supported Operators. + type: boolean + default: "False" + example: ~ + version_added: 2.1.0 diff --git a/providers/src/airflow/providers/openlineage/utils/spark.py b/providers/src/airflow/providers/openlineage/utils/spark.py index 80395c0919965..c2991e651804c 100644 --- a/providers/src/airflow/providers/openlineage/utils/spark.py +++ b/providers/src/airflow/providers/openlineage/utils/spark.py @@ -20,6 +20,7 @@ import logging from typing import TYPE_CHECKING +from airflow.providers.openlineage.plugins.listener import get_openlineage_listener from airflow.providers.openlineage.plugins.macros import ( lineage_job_name, lineage_job_namespace, @@ -50,6 +51,39 @@ def _get_parent_job_information_as_spark_properties(context: Context) -> dict: } +def _get_transport_information_as_spark_properties() -> dict: + """Retrieve transport information as Spark properties.""" + transport = get_openlineage_listener().adapter.get_or_create_openlineage_client().transport + if transport.kind != "http": + log.info( + "OpenLineage transport type `%s` does not support automatic " + "injection of OpenLineage transport information into Spark properties.", + transport.kind, + ) + return {} + + properties = { + "spark.openlineage.transport.type": transport.kind, + "spark.openlineage.transport.url": transport.url, + "spark.openlineage.transport.endpoint": transport.endpoint, + "spark.openlineage.transport.timeoutInMillis": str( + int(transport.timeout * 1000) # convert to milliseconds, as required by Spark integration + ), + } + if transport.compression: + properties["spark.openlineage.transport.compression"] = str(transport.compression) + + if hasattr(transport.config.auth, "api_key") and transport.config.auth.get_bearer(): + properties["spark.openlineage.transport.auth.type"] = "api_key" + properties["spark.openlineage.transport.auth.apiKey"] = transport.config.auth.get_bearer() + + if hasattr(transport.config, "custom_headers") and transport.config.custom_headers: + for key, value in transport.config.custom_headers.items(): + properties[f"spark.openlineage.transport.headers.{key}"] = value + + return properties + + def _is_parent_job_information_present_in_spark_properties(properties: dict) -> bool: """ Check if any parent job information is present in Spark properties. @@ -63,6 +97,19 @@ def _is_parent_job_information_present_in_spark_properties(properties: dict) -> return any(str(key).startswith("spark.openlineage.parent") for key in properties) +def _is_transport_information_present_in_spark_properties(properties: dict) -> bool: + """ + Check if any transport information is present in Spark properties. + + Args: + properties: Spark properties. + + Returns: + True if transport information is present, False otherwise. + """ + return any(str(key).startswith("spark.openlineage.transport") for key in properties) + + def inject_parent_job_information_into_spark_properties(properties: dict, context: Context) -> dict: """ Inject parent job information into Spark properties if not already present. 
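A minimal standalone sketch (plain Python, no Airflow imports; the dictionary values are made-up examples) of the merge semantics added by the hunk that follows: any pre-existing ``spark.openlineage.transport*`` key short-circuits the injection, otherwise the transport properties are appended to whatever the user already set.

.. code-block:: python

    def _has_transport_properties(properties: dict) -> bool:
        # Same check as _is_transport_information_present_in_spark_properties below.
        return any(str(key).startswith("spark.openlineage.transport") for key in properties)


    def inject(properties: dict, transport_properties: dict) -> dict:
        # User-provided transport settings win; nothing is injected over them.
        if _has_transport_properties(properties):
            return properties
        return {**properties, **transport_properties}


    # Non-transport keys are preserved and the transport keys are appended.
    print(
        inject(
            {"spark.sql.shuffle.partitions": "1"},
            {
                "spark.openlineage.transport.type": "http",
                "spark.openlineage.transport.url": "http://example.com:5000",
            },
        )
    )
    # A single pre-existing transport key skips the injection entirely.
    print(
        inject(
            {"spark.openlineage.transport.type": "console"},
            {"spark.openlineage.transport.type": "http"},
        )
    )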
@@ -82,5 +129,26 @@ def inject_parent_job_information_into_spark_properties(properties: dict, contex ) return properties - ol_parent_job_properties = _get_parent_job_information_as_spark_properties(context) - return {**properties, **ol_parent_job_properties} + return {**properties, **_get_parent_job_information_as_spark_properties(context)} + + +def inject_transport_information_into_spark_properties(properties: dict, context: Context) -> dict: + """ + Inject transport information into Spark properties if not already present. + + Args: + properties: Spark properties. + context: The context containing task instance information. + + Returns: + Modified Spark properties with OpenLineage transport information properties injected, if applicable. + """ + if _is_transport_information_present_in_spark_properties(properties): + log.info( + "Some OpenLineage properties with transport information are already present " + "in Spark properties. Skipping the injection of OpenLineage " + "transport information into Spark properties." + ) + return properties + + return {**properties, **_get_transport_information_as_spark_properties()} diff --git a/providers/tests/google/cloud/openlineage/test_utils.py b/providers/tests/google/cloud/openlineage/test_utils.py index 8fa9c90e0e71f..12a640f7a3035 100644 --- a/providers/tests/google/cloud/openlineage/test_utils.py +++ b/providers/tests/google/cloud/openlineage/test_utils.py @@ -24,6 +24,7 @@ from google.cloud.bigquery.table import Table from google.cloud.dataproc_v1 import Batch, RuntimeConfig from openlineage.client.facet_v2 import column_lineage_dataset +from openlineage.client.transport import HttpConfig, HttpTransport from airflow.providers.common.compat.openlineage.facet import ( ColumnLineageDatasetFacet, @@ -75,6 +76,72 @@ "tableReference": {"projectId": TEST_PROJECT_ID, "datasetId": TEST_DATASET, "tableId": TEST_TABLE_ID} } TEST_EMPTY_TABLE: Table = Table.from_api_repr(TEST_EMPTY_TABLE_API_REPR) +EXAMPLE_BATCH = { + "spark_batch": { + "jar_file_uris": ["file:///usr/lib/spark/examples/jars/spark-examples.jar"], + "main_class": "org.apache.spark.examples.SparkPi", + }, + "runtime_config": {"properties": {"existingProperty": "value"}}, +} +EXAMPLE_TEMPLATE = { + "id": "test-workflow", + "placement": { + "cluster_selector": { + "zone": "europe-central2-c", + "cluster_labels": {"key": "value"}, + } + }, + "jobs": [ + { + "step_id": "job_1", + "pyspark_job": { + "main_python_file_uri": "gs://bucket1/spark_job.py", + "properties": { + "spark.sql.shuffle.partitions": "1", + }, + }, + } + ], +} +EXAMPLE_CONTEXT = { + "ti": MagicMock( + dag_id="dag_id", + task_id="task_id", + try_number=1, + map_index=1, + logical_date=dt.datetime(2024, 11, 11), + ) +} +OPENLINEAGE_HTTP_TRANSPORT_EXAMPLE_CONFIG = { + "url": "https://some-custom.url", + "endpoint": "/api/custom", + "timeout": 123, + "compression": "gzip", + "custom_headers": { + "key1": "val1", + "key2": "val2", + }, + "auth": { + "type": "api_key", + "apiKey": "secret_123", + }, +} +OPENLINEAGE_HTTP_TRANSPORT_EXAMPLE_SPARK_PROPERTIES = { + "spark.openlineage.transport.type": "http", + "spark.openlineage.transport.url": "https://some-custom.url", + "spark.openlineage.transport.endpoint": "/api/custom", + "spark.openlineage.transport.auth.type": "api_key", + "spark.openlineage.transport.auth.apiKey": "Bearer secret_123", + "spark.openlineage.transport.compression": "gzip", + "spark.openlineage.transport.headers.key1": "val1", + "spark.openlineage.transport.headers.key2": "val2", + 
"spark.openlineage.transport.timeoutInMillis": "123000", +} +OPENLINEAGE_PARENT_JOB_EXAMPLE_SPARK_PROPERTIES = { + "spark.openlineage.parentJobName": "dag_id.task_id", + "spark.openlineage.parentJobNamespace": "default", + "spark.openlineage.parentRunId": "01931885-2800-7be7-aa8d-aaa15c337267", +} def read_file_json(file): @@ -507,7 +574,7 @@ def test_replace_dataproc_job_properties_key_error(): def test_inject_openlineage_properties_into_dataproc_job_provider_not_accessible(mock_is_accessible): mock_is_accessible.return_value = False job = {"sparkJob": {"properties": {"existingProperty": "value"}}} - result = inject_openlineage_properties_into_dataproc_job(job, None, True) + result = inject_openlineage_properties_into_dataproc_job(job, None, True, True) assert result == job @@ -519,43 +586,69 @@ def test_inject_openlineage_properties_into_dataproc_job_unsupported_job_type( mock_is_accessible.return_value = True mock_extract_job_type.return_value = None job = {"unsupportedJob": {"properties": {"existingProperty": "value"}}} - result = inject_openlineage_properties_into_dataproc_job(job, None, True) + result = inject_openlineage_properties_into_dataproc_job(job, None, True, True) assert result == job @patch("airflow.providers.google.cloud.openlineage.utils._is_openlineage_provider_accessible") @patch("airflow.providers.google.cloud.openlineage.utils._extract_supported_job_type_from_dataproc_job") -def test_inject_openlineage_properties_into_dataproc_job_no_inject_parent_job_info( +def test_inject_openlineage_properties_into_dataproc_job_no_injection( mock_extract_job_type, mock_is_accessible ): mock_is_accessible.return_value = True mock_extract_job_type.return_value = "sparkJob" inject_parent_job_info = False job = {"sparkJob": {"properties": {"existingProperty": "value"}}} - result = inject_openlineage_properties_into_dataproc_job(job, None, inject_parent_job_info) + result = inject_openlineage_properties_into_dataproc_job(job, None, inject_parent_job_info, False) assert result == job @patch("airflow.providers.google.cloud.openlineage.utils._is_openlineage_provider_accessible") -def test_inject_openlineage_properties_into_dataproc_job(mock_is_ol_accessible): +def test_inject_openlineage_properties_into_dataproc_job_parent_info_only(mock_is_ol_accessible): mock_is_ol_accessible.return_value = True - context = { - "ti": MagicMock( - dag_id="dag_id", - task_id="task_id", - try_number=1, - map_index=1, - logical_date=dt.datetime(2024, 11, 11), - ) + expected_properties = { + "existingProperty": "value", + **OPENLINEAGE_PARENT_JOB_EXAMPLE_SPARK_PROPERTIES, } + job = {"sparkJob": {"properties": {"existingProperty": "value"}}} + result = inject_openlineage_properties_into_dataproc_job(job, EXAMPLE_CONTEXT, True, False) + assert result == {"sparkJob": {"properties": expected_properties}} + + +@patch("airflow.providers.openlineage.plugins.listener._openlineage_listener") +@patch("airflow.providers.google.cloud.openlineage.utils._is_openlineage_provider_accessible") +def test_inject_openlineage_properties_into_dataproc_job_transport_info_only( + mock_is_ol_accessible, mock_ol_listener +): + mock_is_ol_accessible.return_value = True + mock_ol_listener.adapter.get_or_create_openlineage_client.return_value.transport = HttpTransport( + HttpConfig.from_dict(OPENLINEAGE_HTTP_TRANSPORT_EXAMPLE_CONFIG) + ) expected_properties = { "existingProperty": "value", - "spark.openlineage.parentJobName": "dag_id.task_id", - "spark.openlineage.parentJobNamespace": "default", - "spark.openlineage.parentRunId": 
"01931885-2800-7be7-aa8d-aaa15c337267", + **OPENLINEAGE_HTTP_TRANSPORT_EXAMPLE_SPARK_PROPERTIES, } job = {"sparkJob": {"properties": {"existingProperty": "value"}}} - result = inject_openlineage_properties_into_dataproc_job(job, context, True) + result = inject_openlineage_properties_into_dataproc_job(job, EXAMPLE_CONTEXT, False, True) + assert result == {"sparkJob": {"properties": expected_properties}} + + +@patch("airflow.providers.openlineage.plugins.listener._openlineage_listener") +@patch("airflow.providers.google.cloud.openlineage.utils._is_openlineage_provider_accessible") +def test_inject_openlineage_properties_into_dataproc_job_all_injections( + mock_is_ol_accessible, mock_ol_listener +): + mock_is_ol_accessible.return_value = True + mock_ol_listener.adapter.get_or_create_openlineage_client.return_value.transport = HttpTransport( + HttpConfig.from_dict(OPENLINEAGE_HTTP_TRANSPORT_EXAMPLE_CONFIG) + ) + expected_properties = { + "existingProperty": "value", + **OPENLINEAGE_PARENT_JOB_EXAMPLE_SPARK_PROPERTIES, + **OPENLINEAGE_HTTP_TRANSPORT_EXAMPLE_SPARK_PROPERTIES, + } + job = {"sparkJob": {"properties": {"existingProperty": "value"}}} + result = inject_openlineage_properties_into_dataproc_job(job, EXAMPLE_CONTEXT, True, True) assert result == {"sparkJob": {"properties": expected_properties}} @@ -577,7 +670,7 @@ def test_is_dataproc_batch_of_supported_type(batch, expected): assert _is_dataproc_batch_of_supported_type(batch) == expected -def test__extract_dataproc_batch_properties_batch_object_with_runtime_object(): +def test_extract_dataproc_batch_properties_batch_object_with_runtime_object(): properties = {"key1": "value1", "key2": "value2"} mock_runtime_config = RuntimeConfig(properties=properties) mock_batch = Batch(runtime_config=mock_runtime_config) @@ -749,15 +842,8 @@ def test_replace_dataproc_batch_properties_with_empty_dict(): @patch("airflow.providers.google.cloud.openlineage.utils._is_openlineage_provider_accessible") def test_inject_openlineage_properties_into_dataproc_batch_provider_not_accessible(mock_is_accessible): mock_is_accessible.return_value = False - batch = { - "spark_batch": { - "jar_file_uris": ["file:///usr/lib/spark/examples/jars/spark-examples.jar"], - "main_class": "org.apache.spark.examples.SparkPi", - }, - "runtime_config": {"properties": {"existingProperty": "value"}}, - } - result = inject_openlineage_properties_into_dataproc_batch(batch, None, True) - assert result == batch + result = inject_openlineage_properties_into_dataproc_batch(EXAMPLE_BATCH, None, True, True) + assert result == EXAMPLE_BATCH @patch("airflow.providers.google.cloud.openlineage.utils._is_openlineage_provider_accessible") @@ -774,62 +860,73 @@ def test_inject_openlineage_properties_into_dataproc_batch_unsupported_batch_typ }, "runtime_config": {"properties": {"existingProperty": "value"}}, } - result = inject_openlineage_properties_into_dataproc_batch(batch, None, True) + result = inject_openlineage_properties_into_dataproc_batch(batch, None, True, True) assert result == batch @patch("airflow.providers.google.cloud.openlineage.utils._is_openlineage_provider_accessible") @patch("airflow.providers.google.cloud.openlineage.utils._is_dataproc_batch_of_supported_type") -def test_inject_openlineage_properties_into_dataproc_batch_no_inject_parent_job_info( +def test_inject_openlineage_properties_into_dataproc_batch_no_injection( mock_valid_job_type, mock_is_accessible ): mock_is_accessible.return_value = True mock_valid_job_type.return_value = True - inject_parent_job_info = False - 
batch = { - "spark_batch": { - "jar_file_uris": ["file:///usr/lib/spark/examples/jars/spark-examples.jar"], - "main_class": "org.apache.spark.examples.SparkPi", - }, - "runtime_config": {"properties": {"existingProperty": "value"}}, - } - result = inject_openlineage_properties_into_dataproc_batch(batch, None, inject_parent_job_info) - assert result == batch + result = inject_openlineage_properties_into_dataproc_batch(EXAMPLE_BATCH, None, False, False) + assert result == EXAMPLE_BATCH @patch("airflow.providers.google.cloud.openlineage.utils._is_openlineage_provider_accessible") -def test_inject_openlineage_properties_into_dataproc_batch(mock_is_ol_accessible): +def test_inject_openlineage_properties_into_dataproc_batch_parent_info_only(mock_is_ol_accessible): mock_is_ol_accessible.return_value = True - context = { - "ti": MagicMock( - dag_id="dag_id", - task_id="task_id", - try_number=1, - map_index=1, - logical_date=dt.datetime(2024, 11, 11), - ) + expected_properties = { + "existingProperty": "value", + **OPENLINEAGE_PARENT_JOB_EXAMPLE_SPARK_PROPERTIES, } - batch = { - "spark_batch": { - "jar_file_uris": ["file:///usr/lib/spark/examples/jars/spark-examples.jar"], - "main_class": "org.apache.spark.examples.SparkPi", - }, - "runtime_config": {"properties": {"existingProperty": "value"}}, + expected_batch = { + **EXAMPLE_BATCH, + "runtime_config": {"properties": expected_properties}, } + result = inject_openlineage_properties_into_dataproc_batch(EXAMPLE_BATCH, EXAMPLE_CONTEXT, True, False) + assert result == expected_batch + + +@patch("airflow.providers.openlineage.plugins.listener._openlineage_listener") +@patch("airflow.providers.google.cloud.openlineage.utils._is_openlineage_provider_accessible") +def test_inject_openlineage_properties_into_dataproc_batch_transport_info_only( + mock_is_ol_accessible, mock_ol_listener +): + mock_is_ol_accessible.return_value = True + mock_ol_listener.adapter.get_or_create_openlineage_client.return_value.transport = HttpTransport( + HttpConfig.from_dict(OPENLINEAGE_HTTP_TRANSPORT_EXAMPLE_CONFIG) + ) + expected_properties = {"existingProperty": "value", **OPENLINEAGE_HTTP_TRANSPORT_EXAMPLE_SPARK_PROPERTIES} + expected_batch = { + **EXAMPLE_BATCH, + "runtime_config": {"properties": expected_properties}, + } + result = inject_openlineage_properties_into_dataproc_batch(EXAMPLE_BATCH, EXAMPLE_CONTEXT, False, True) + assert result == expected_batch + + +@patch("airflow.providers.openlineage.plugins.listener._openlineage_listener") +@patch("airflow.providers.google.cloud.openlineage.utils._is_openlineage_provider_accessible") +def test_inject_openlineage_properties_into_dataproc_batch_all_injections( + mock_is_ol_accessible, mock_ol_listener +): + mock_is_ol_accessible.return_value = True + mock_ol_listener.adapter.get_or_create_openlineage_client.return_value.transport = HttpTransport( + HttpConfig.from_dict(OPENLINEAGE_HTTP_TRANSPORT_EXAMPLE_CONFIG) + ) expected_properties = { "existingProperty": "value", - "spark.openlineage.parentJobName": "dag_id.task_id", - "spark.openlineage.parentJobNamespace": "default", - "spark.openlineage.parentRunId": "01931885-2800-7be7-aa8d-aaa15c337267", + **OPENLINEAGE_PARENT_JOB_EXAMPLE_SPARK_PROPERTIES, + **OPENLINEAGE_HTTP_TRANSPORT_EXAMPLE_SPARK_PROPERTIES, } expected_batch = { - "spark_batch": { - "jar_file_uris": ["file:///usr/lib/spark/examples/jars/spark-examples.jar"], - "main_class": "org.apache.spark.examples.SparkPi", - }, + **EXAMPLE_BATCH, "runtime_config": {"properties": expected_properties}, } - result = 
inject_openlineage_properties_into_dataproc_batch(batch, context, True) + result = inject_openlineage_properties_into_dataproc_batch(EXAMPLE_BATCH, EXAMPLE_CONTEXT, True, True) assert result == expected_batch @@ -867,45 +964,195 @@ def test_inject_openlineage_properties_into_dataproc_workflow_template_provider_ ): mock_is_accessible.return_value = False template = {"workflow": "template"} # It does not matter what the dict is, we should return it unmodified - result = inject_openlineage_properties_into_dataproc_workflow_template(template, None, True) + result = inject_openlineage_properties_into_dataproc_workflow_template(template, None, True, True) assert result == template @patch("airflow.providers.google.cloud.openlineage.utils._is_openlineage_provider_accessible") @patch("airflow.providers.google.cloud.openlineage.utils._extract_supported_job_type_from_dataproc_job") -def test_inject_openlineage_properties_into_dataproc_workflow_template_no_inject_parent_job_info( +def test_inject_openlineage_properties_into_dataproc_workflow_template_no_injection( mock_extract_job_type, mock_is_accessible ): mock_is_accessible.return_value = True mock_extract_job_type.return_value = "sparkJob" - inject_parent_job_info = False template = {"workflow": "template"} # It does not matter what the dict is, we should return it unmodified + result = inject_openlineage_properties_into_dataproc_workflow_template(template, None, False, False) + assert result == template + + +@patch("airflow.providers.google.cloud.openlineage.utils._is_openlineage_provider_accessible") +def test_inject_openlineage_properties_into_dataproc_workflow_template_parent_info_only( + mock_is_ol_accessible, +): + mock_is_ol_accessible.return_value = True + template = { + **EXAMPLE_TEMPLATE, + "jobs": [ + { + "step_id": "job_1", + "pyspark_job": { + "main_python_file_uri": "gs://bucket1/spark_job.py", + "properties": { + "spark.sql.shuffle.partitions": "1", + }, + }, + }, + { + "step_id": "job_2", + "pyspark_job": { + "main_python_file_uri": "gs://bucket2/spark_job.py", + "properties": { + "spark.sql.shuffle.partitions": "1", + "spark.openlineage.parentJobNamespace": "test", + }, + }, + }, + { + "step_id": "job_3", + "hive_job": { + "main_python_file_uri": "gs://bucket3/hive_job.py", + "properties": { + "spark.sql.shuffle.partitions": "1", + }, + }, + }, + ], + } + expected_template = { + **EXAMPLE_TEMPLATE, + "jobs": [ + { + "step_id": "job_1", + "pyspark_job": { + "main_python_file_uri": "gs://bucket1/spark_job.py", + "properties": { # Injected properties + "spark.sql.shuffle.partitions": "1", + "spark.openlineage.parentJobName": "dag_id.task_id", + "spark.openlineage.parentJobNamespace": "default", + "spark.openlineage.parentRunId": "01931885-2800-7be7-aa8d-aaa15c337267", + }, + }, + }, + { + "step_id": "job_2", + "pyspark_job": { # Not modified because it's already present + "main_python_file_uri": "gs://bucket2/spark_job.py", + "properties": { + "spark.sql.shuffle.partitions": "1", + "spark.openlineage.parentJobNamespace": "test", + }, + }, + }, + { + "step_id": "job_3", + "hive_job": { # Not modified because it's unsupported job type + "main_python_file_uri": "gs://bucket3/hive_job.py", + "properties": { + "spark.sql.shuffle.partitions": "1", + }, + }, + }, + ], + } result = inject_openlineage_properties_into_dataproc_workflow_template( - template, None, inject_parent_job_info + template, EXAMPLE_CONTEXT, True, False ) - assert result == template + assert result == expected_template 
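The expected values in the transport tests below come from the ``OPENLINEAGE_HTTP_TRANSPORT_EXAMPLE_CONFIG`` / ``OPENLINEAGE_HTTP_TRANSPORT_EXAMPLE_SPARK_PROPERTIES`` pair defined earlier in this file. A small standalone sketch (values copied from those fixtures) of the derivations applied to the HTTP transport config:

.. code-block:: python

    config = {
        "timeout": 123,
        "custom_headers": {"key1": "val1", "key2": "val2"},
        "auth": {"type": "api_key", "apiKey": "secret_123"},
    }

    derived = {
        # seconds -> milliseconds, rendered as a string for the Spark integration
        "spark.openlineage.transport.timeoutInMillis": str(int(config["timeout"] * 1000)),
        # the api_key auth is passed on as a bearer token
        "spark.openlineage.transport.auth.apiKey": f"Bearer {config['auth']['apiKey']}",
    }
    # each custom header becomes its own spark.openlineage.transport.headers.<name> property
    derived.update(
        {f"spark.openlineage.transport.headers.{k}": v for k, v in config["custom_headers"].items()}
    )

    assert derived["spark.openlineage.transport.timeoutInMillis"] == "123000"
    assert derived["spark.openlineage.transport.auth.apiKey"] == "Bearer secret_123"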
+@patch("airflow.providers.openlineage.plugins.listener._openlineage_listener") @patch("airflow.providers.google.cloud.openlineage.utils._is_openlineage_provider_accessible") -def test_inject_openlineage_properties_into_dataproc_workflow_template(mock_is_ol_accessible): +def test_inject_openlineage_properties_into_dataproc_workflow_template_transport_info_only( + mock_is_ol_accessible, mock_ol_listener +): mock_is_ol_accessible.return_value = True - context = { - "ti": MagicMock( - dag_id="dag_id", - task_id="task_id", - try_number=1, - map_index=1, - logical_date=dt.datetime(2024, 11, 11), - ) + mock_ol_listener.adapter.get_or_create_openlineage_client.return_value.transport = HttpTransport( + HttpConfig.from_dict(OPENLINEAGE_HTTP_TRANSPORT_EXAMPLE_CONFIG) + ) + template = { + **EXAMPLE_TEMPLATE, + "jobs": [ + { + "step_id": "job_1", + "pyspark_job": { + "main_python_file_uri": "gs://bucket1/spark_job.py", + "properties": { + "spark.sql.shuffle.partitions": "1", + }, + }, + }, + { + "step_id": "job_2", + "pyspark_job": { + "main_python_file_uri": "gs://bucket2/spark_job.py", + "properties": { + "spark.sql.shuffle.partitions": "1", + "spark.openlineage.transport.type": "console", + }, + }, + }, + { + "step_id": "job_3", + "hive_job": { + "main_python_file_uri": "gs://bucket3/hive_job.py", + "properties": { + "spark.sql.shuffle.partitions": "1", + }, + }, + }, + ], + } + expected_template = { + **EXAMPLE_TEMPLATE, + "jobs": [ + { + "step_id": "job_1", + "pyspark_job": { + "main_python_file_uri": "gs://bucket1/spark_job.py", + "properties": { # Injected properties + "spark.sql.shuffle.partitions": "1", + **OPENLINEAGE_HTTP_TRANSPORT_EXAMPLE_SPARK_PROPERTIES, + }, + }, + }, + { + "step_id": "job_2", + "pyspark_job": { # Not modified because it's already present + "main_python_file_uri": "gs://bucket2/spark_job.py", + "properties": { + "spark.sql.shuffle.partitions": "1", + "spark.openlineage.transport.type": "console", + }, + }, + }, + { + "step_id": "job_3", + "hive_job": { # Not modified because it's unsupported job type + "main_python_file_uri": "gs://bucket3/hive_job.py", + "properties": { + "spark.sql.shuffle.partitions": "1", + }, + }, + }, + ], } + result = inject_openlineage_properties_into_dataproc_workflow_template( + template, EXAMPLE_CONTEXT, False, True + ) + assert result == expected_template + + +@patch("airflow.providers.openlineage.plugins.listener._openlineage_listener") +@patch("airflow.providers.google.cloud.openlineage.utils._is_openlineage_provider_accessible") +def test_inject_openlineage_properties_into_dataproc_workflow_template_all_injections( + mock_is_ol_accessible, mock_ol_listener +): + mock_is_ol_accessible.return_value = True + mock_ol_listener.adapter.get_or_create_openlineage_client.return_value.transport = HttpTransport( + HttpConfig.from_dict(OPENLINEAGE_HTTP_TRANSPORT_EXAMPLE_CONFIG) + ) template = { - "id": "test-workflow", - "placement": { - "cluster_selector": { - "zone": "europe-central2-c", - "cluster_labels": {"key": "value"}, - } - }, + **EXAMPLE_TEMPLATE, "jobs": [ { "step_id": "job_1", @@ -922,7 +1169,7 @@ def test_inject_openlineage_properties_into_dataproc_workflow_template(mock_is_o "main_python_file_uri": "gs://bucket2/spark_job.py", "properties": { "spark.sql.shuffle.partitions": "1", - "spark.openlineage.parentJobNamespace": "test", + "spark.openlineage.transport.type": "console", }, }, }, @@ -935,6 +1182,16 @@ def test_inject_openlineage_properties_into_dataproc_workflow_template(mock_is_o }, }, }, + { + "step_id": "job_4", + "pyspark_job": 
{ + "main_python_file_uri": "gs://bucket2/spark_job.py", + "properties": { + "spark.sql.shuffle.partitions": "1", + "spark.openlineage.parentJobNamespace": "test", + }, + }, + }, ], } expected_template = { @@ -952,19 +1209,19 @@ def test_inject_openlineage_properties_into_dataproc_workflow_template(mock_is_o "main_python_file_uri": "gs://bucket1/spark_job.py", "properties": { # Injected properties "spark.sql.shuffle.partitions": "1", - "spark.openlineage.parentJobName": "dag_id.task_id", - "spark.openlineage.parentJobNamespace": "default", - "spark.openlineage.parentRunId": "01931885-2800-7be7-aa8d-aaa15c337267", + **OPENLINEAGE_PARENT_JOB_EXAMPLE_SPARK_PROPERTIES, + **OPENLINEAGE_HTTP_TRANSPORT_EXAMPLE_SPARK_PROPERTIES, }, }, }, { "step_id": "job_2", - "pyspark_job": { # Not modified because it's already present + "pyspark_job": { # Only parent info injected "main_python_file_uri": "gs://bucket2/spark_job.py", "properties": { "spark.sql.shuffle.partitions": "1", - "spark.openlineage.parentJobNamespace": "test", + "spark.openlineage.transport.type": "console", + **OPENLINEAGE_PARENT_JOB_EXAMPLE_SPARK_PROPERTIES, }, }, }, @@ -977,7 +1234,20 @@ def test_inject_openlineage_properties_into_dataproc_workflow_template(mock_is_o }, }, }, + { + "step_id": "job_4", + "pyspark_job": { + "main_python_file_uri": "gs://bucket2/spark_job.py", + "properties": { # Only transport info injected + "spark.sql.shuffle.partitions": "1", + "spark.openlineage.parentJobNamespace": "test", + **OPENLINEAGE_HTTP_TRANSPORT_EXAMPLE_SPARK_PROPERTIES, + }, + }, + }, ], } - result = inject_openlineage_properties_into_dataproc_workflow_template(template, context, True) + result = inject_openlineage_properties_into_dataproc_workflow_template( + template, EXAMPLE_CONTEXT, True, True + ) assert result == expected_template diff --git a/providers/tests/google/cloud/operators/test_dataproc.py b/providers/tests/google/cloud/operators/test_dataproc.py index f79a0bdba0ce9..3acc52b39ea95 100644 --- a/providers/tests/google/cloud/operators/test_dataproc.py +++ b/providers/tests/google/cloud/operators/test_dataproc.py @@ -27,6 +27,7 @@ from google.api_core.retry_async import AsyncRetry from google.cloud import dataproc from google.cloud.dataproc_v1 import Batch, Cluster, JobStatus +from openlineage.client.transport import HttpConfig, HttpTransport, KafkaConfig, KafkaTransport from airflow import __version__ as AIRFLOW_VERSION from airflow.exceptions import ( @@ -399,6 +400,45 @@ "main_class": "org.apache.spark.examples.SparkPi", }, } +EXAMPLE_CONTEXT = { + "ti": MagicMock( + dag_id="dag_id", + task_id="task_id", + try_number=1, + map_index=1, + logical_date=dt.datetime(2024, 11, 11), + ) +} +OPENLINEAGE_HTTP_TRANSPORT_EXAMPLE_CONFIG = { + "url": "https://some-custom.url", + "endpoint": "/api/custom", + "timeout": 123, + "compression": "gzip", + "custom_headers": { + "key1": "val1", + "key2": "val2", + }, + "auth": { + "type": "api_key", + "apiKey": "secret_123", + }, +} +OPENLINEAGE_HTTP_TRANSPORT_EXAMPLE_SPARK_PROPERTIES = { + "spark.openlineage.transport.type": "http", + "spark.openlineage.transport.url": "https://some-custom.url", + "spark.openlineage.transport.endpoint": "/api/custom", + "spark.openlineage.transport.auth.type": "api_key", + "spark.openlineage.transport.auth.apiKey": "Bearer secret_123", + "spark.openlineage.transport.compression": "gzip", + "spark.openlineage.transport.headers.key1": "val1", + "spark.openlineage.transport.headers.key2": "val2", + "spark.openlineage.transport.timeoutInMillis": "123000", +} 
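The transport-related operator tests below patch ``airflow.providers.openlineage.plugins.listener._openlineage_listener`` because the injection helper resolves the transport via ``get_openlineage_listener().adapter.get_or_create_openlineage_client().transport``. A sketch of that wiring (assuming the patched module attribute is what ``get_openlineage_listener()`` returns, and reusing the ``OPENLINEAGE_HTTP_TRANSPORT_EXAMPLE_CONFIG`` fixture defined above):

.. code-block:: python

    from unittest import mock

    from openlineage.client.transport import HttpConfig, HttpTransport

    with mock.patch("airflow.providers.openlineage.plugins.listener._openlineage_listener") as listener:
        # Attach a real HttpTransport so the helper reads kind/url/endpoint/timeout/
        # compression/auth from a genuine transport object rather than a MagicMock.
        listener.adapter.get_or_create_openlineage_client.return_value.transport = HttpTransport(
            HttpConfig.from_dict(OPENLINEAGE_HTTP_TRANSPORT_EXAMPLE_CONFIG)
        )
        ...  # operator.execute(...) now sees this transport when injecting properties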
+OPENLINEAGE_PARENT_JOB_EXAMPLE_SPARK_PROPERTIES = { + "spark.openlineage.parentJobName": "dag_id.task_id", + "spark.openlineage.parentJobNamespace": "default", + "spark.openlineage.parentRunId": "01931885-2800-7be7-aa8d-aaa15c337267", +} def assert_warning(msg: str, warnings): @@ -1571,11 +1611,171 @@ def test_execute_openlineage_parent_job_info_injection(self, mock_hook, mock_ol_ metadata=METADATA, ) + @mock.patch("airflow.providers.openlineage.plugins.listener._openlineage_listener") + @mock.patch("airflow.providers.google.cloud.openlineage.utils._is_openlineage_provider_accessible") + @mock.patch(DATAPROC_PATH.format("DataprocHook")) + def test_execute_openlineage_http_transport_info_injection( + self, mock_hook, mock_ol_accessible, mock_ol_listener + ): + mock_ol_accessible.return_value = True + mock_ol_listener.adapter.get_or_create_openlineage_client.return_value.transport = HttpTransport( + HttpConfig.from_dict(OPENLINEAGE_HTTP_TRANSPORT_EXAMPLE_CONFIG) + ) + job_config = { + "placement": {"cluster_name": CLUSTER_NAME}, + "pyspark_job": { + "main_python_file_uri": "gs://example/wordcount.py", + "properties": { + "spark.sql.shuffle.partitions": "1", + }, + }, + } + expected_config = { + "placement": {"cluster_name": CLUSTER_NAME}, + "pyspark_job": { + "main_python_file_uri": "gs://example/wordcount.py", + "properties": { + "spark.sql.shuffle.partitions": "1", + **OPENLINEAGE_HTTP_TRANSPORT_EXAMPLE_SPARK_PROPERTIES, + }, + }, + } + + op = DataprocSubmitJobOperator( + task_id=TASK_ID, + region=GCP_REGION, + project_id=GCP_PROJECT, + request_id=REQUEST_ID, + retry=RETRY, + timeout=TIMEOUT, + metadata=METADATA, + job=job_config, + openlineage_inject_transport_info=True, + ) + op.execute(context=EXAMPLE_CONTEXT) + + mock_hook.return_value.submit_job.assert_called_once_with( + project_id=GCP_PROJECT, + region=GCP_REGION, + job=expected_config, + request_id=REQUEST_ID, + retry=RETRY, + timeout=TIMEOUT, + metadata=METADATA, + ) + + @mock.patch("airflow.providers.openlineage.plugins.listener._openlineage_listener") + @mock.patch("airflow.providers.google.cloud.openlineage.utils._is_openlineage_provider_accessible") + @mock.patch(DATAPROC_PATH.format("DataprocHook")) + def test_execute_openlineage_all_info_injection(self, mock_hook, mock_ol_accessible, mock_ol_listener): + mock_ol_accessible.return_value = True + mock_ol_listener.adapter.get_or_create_openlineage_client.return_value.transport = HttpTransport( + HttpConfig.from_dict(OPENLINEAGE_HTTP_TRANSPORT_EXAMPLE_CONFIG) + ) + job_config = { + "placement": {"cluster_name": CLUSTER_NAME}, + "pyspark_job": { + "main_python_file_uri": "gs://example/wordcount.py", + "properties": { + "spark.sql.shuffle.partitions": "1", + }, + }, + } + expected_config = { + "placement": {"cluster_name": CLUSTER_NAME}, + "pyspark_job": { + "main_python_file_uri": "gs://example/wordcount.py", + "properties": { + "spark.sql.shuffle.partitions": "1", + **OPENLINEAGE_PARENT_JOB_EXAMPLE_SPARK_PROPERTIES, + **OPENLINEAGE_HTTP_TRANSPORT_EXAMPLE_SPARK_PROPERTIES, + }, + }, + } + + op = DataprocSubmitJobOperator( + task_id=TASK_ID, + region=GCP_REGION, + project_id=GCP_PROJECT, + request_id=REQUEST_ID, + retry=RETRY, + timeout=TIMEOUT, + metadata=METADATA, + job=job_config, + openlineage_inject_parent_job_info=True, + openlineage_inject_transport_info=True, + ) + op.execute(context=EXAMPLE_CONTEXT) + + mock_hook.return_value.submit_job.assert_called_once_with( + project_id=GCP_PROJECT, + region=GCP_REGION, + job=expected_config, + request_id=REQUEST_ID, + retry=RETRY, + 
timeout=TIMEOUT, + metadata=METADATA, + ) + + @mock.patch("airflow.providers.openlineage.plugins.listener._openlineage_listener") + @mock.patch("airflow.providers.google.cloud.openlineage.utils._is_openlineage_provider_accessible") + @mock.patch(DATAPROC_PATH.format("DataprocHook")) + def test_execute_openlineage_unsupported_transport_info_injection( + self, mock_hook, mock_ol_accessible, mock_ol_listener + ): + mock_ol_accessible.return_value = True + kafka_config = KafkaConfig( + topic="my_topic", + config={ + "bootstrap.servers": "localhost:9092,another.host:9092", + "acks": "all", + "retries": "3", + }, + flush=True, + messageKey="some", + ) + mock_ol_listener.adapter.get_or_create_openlineage_client.return_value.transport = KafkaTransport( + kafka_config + ) + job_config = { + "placement": {"cluster_name": CLUSTER_NAME}, + "pyspark_job": { + "main_python_file_uri": "gs://example/wordcount.py", + "properties": { + "spark.sql.shuffle.partitions": "1", + }, + }, + } + + op = DataprocSubmitJobOperator( + task_id=TASK_ID, + region=GCP_REGION, + project_id=GCP_PROJECT, + request_id=REQUEST_ID, + retry=RETRY, + timeout=TIMEOUT, + metadata=METADATA, + job=job_config, + openlineage_inject_transport_info=True, + ) + op.execute(context=EXAMPLE_CONTEXT) + + mock_hook.return_value.submit_job.assert_called_once_with( + project_id=GCP_PROJECT, + region=GCP_REGION, + job=job_config, + request_id=REQUEST_ID, + retry=RETRY, + timeout=TIMEOUT, + metadata=METADATA, + ) + @mock.patch("airflow.providers.google.cloud.openlineage.utils._is_openlineage_provider_accessible") @mock.patch(DATAPROC_PATH.format("DataprocHook")) def test_execute_openlineage_parent_job_info_injection_skipped_when_already_present( self, mock_hook, mock_ol_accessible ): + mock_ol_accessible.return_value = True job_config = { "placement": {"cluster_name": CLUSTER_NAME}, "pyspark_job": { @@ -1587,7 +1787,49 @@ def test_execute_openlineage_parent_job_info_injection_skipped_when_already_pres }, } + op = DataprocSubmitJobOperator( + task_id=TASK_ID, + region=GCP_REGION, + project_id=GCP_PROJECT, + request_id=REQUEST_ID, + retry=RETRY, + timeout=TIMEOUT, + metadata=METADATA, + job=job_config, + openlineage_inject_parent_job_info=True, + ) + op.execute(context=EXAMPLE_CONTEXT) + + mock_hook.return_value.submit_job.assert_called_once_with( + project_id=GCP_PROJECT, + region=GCP_REGION, + job=job_config, + request_id=REQUEST_ID, + retry=RETRY, + timeout=TIMEOUT, + metadata=METADATA, + ) + + @mock.patch("airflow.providers.openlineage.plugins.listener._openlineage_listener") + @mock.patch("airflow.providers.google.cloud.openlineage.utils._is_openlineage_provider_accessible") + @mock.patch(DATAPROC_PATH.format("DataprocHook")) + def test_execute_openlineage_transport_info_injection_skipped_when_already_present( + self, mock_hook, mock_ol_accessible, mock_ol_listener + ): mock_ol_accessible.return_value = True + mock_ol_listener.adapter.get_or_create_openlineage_client.return_value.transport = HttpTransport( + HttpConfig.from_dict(OPENLINEAGE_HTTP_TRANSPORT_EXAMPLE_CONFIG) + ) + job_config = { + "placement": {"cluster_name": CLUSTER_NAME}, + "pyspark_job": { + "main_python_file_uri": "gs://example/wordcount.py", + "properties": { + "spark.sql.shuffle.partitions": "1", + "spark.openlineage.transport.type": "console", + }, + }, + } op = DataprocSubmitJobOperator( task_id=TASK_ID, @@ -1598,9 +1840,9 @@ def test_execute_openlineage_parent_job_info_injection_skipped_when_already_pres timeout=TIMEOUT, metadata=METADATA, job=job_config, - 
openlineage_inject_parent_job_info=True, + openlineage_inject_transport_info=True, ) - op.execute(context=self.mock_context) + op.execute(context=EXAMPLE_CONTEXT) mock_hook.return_value.submit_job.assert_called_once_with( project_id=GCP_PROJECT, @@ -1617,6 +1859,7 @@ def test_execute_openlineage_parent_job_info_injection_skipped_when_already_pres def test_execute_openlineage_parent_job_info_injection_skipped_by_default_unless_enabled( self, mock_hook, mock_ol_accessible ): + mock_ol_accessible.return_value = True job_config = { "placement": {"cluster_name": CLUSTER_NAME}, "pyspark_job": { @@ -1627,7 +1870,48 @@ def test_execute_openlineage_parent_job_info_injection_skipped_by_default_unless }, } + op = DataprocSubmitJobOperator( + task_id=TASK_ID, + region=GCP_REGION, + project_id=GCP_PROJECT, + request_id=REQUEST_ID, + retry=RETRY, + timeout=TIMEOUT, + metadata=METADATA, + job=job_config, + # not passing openlineage_inject_parent_job_info, should be False by default + ) + op.execute(context=EXAMPLE_CONTEXT) + + mock_hook.return_value.submit_job.assert_called_once_with( + project_id=GCP_PROJECT, + region=GCP_REGION, + job=job_config, + request_id=REQUEST_ID, + retry=RETRY, + timeout=TIMEOUT, + metadata=METADATA, + ) + + @mock.patch("airflow.providers.openlineage.plugins.listener._openlineage_listener") + @mock.patch("airflow.providers.google.cloud.openlineage.utils._is_openlineage_provider_accessible") + @mock.patch(DATAPROC_PATH.format("DataprocHook")) + def test_execute_openlineage_transport_info_injection_skipped_by_default_unless_enabled( + self, mock_hook, mock_ol_accessible, mock_ol_listener + ): mock_ol_accessible.return_value = True + mock_ol_listener.adapter.get_or_create_openlineage_client.return_value.transport = HttpTransport( + HttpConfig.from_dict(OPENLINEAGE_HTTP_TRANSPORT_EXAMPLE_CONFIG) + ) + job_config = { + "placement": {"cluster_name": CLUSTER_NAME}, + "pyspark_job": { + "main_python_file_uri": "gs://example/wordcount.py", + "properties": { + "spark.sql.shuffle.partitions": "1", + }, + }, + } op = DataprocSubmitJobOperator( task_id=TASK_ID, @@ -1638,9 +1922,9 @@ def test_execute_openlineage_parent_job_info_injection_skipped_by_default_unless timeout=TIMEOUT, metadata=METADATA, job=job_config, - # not passing openlineage_inject_parent_job_info, should be False by default + # not passing openlineage_inject_transport_info, should be False by default ) - op.execute(context=self.mock_context) + op.execute(context=EXAMPLE_CONTEXT) mock_hook.return_value.submit_job.assert_called_once_with( project_id=GCP_PROJECT, @@ -1657,6 +1941,7 @@ def test_execute_openlineage_parent_job_info_injection_skipped_by_default_unless def test_execute_openlineage_parent_job_info_injection_skipped_when_ol_not_accessible( self, mock_hook, mock_ol_accessible ): + mock_ol_accessible.return_value = False job_config = { "placement": {"cluster_name": CLUSTER_NAME}, "pyspark_job": { @@ -1667,7 +1952,48 @@ def test_execute_openlineage_parent_job_info_injection_skipped_when_ol_not_acces }, } + op = DataprocSubmitJobOperator( + task_id=TASK_ID, + region=GCP_REGION, + project_id=GCP_PROJECT, + request_id=REQUEST_ID, + retry=RETRY, + timeout=TIMEOUT, + metadata=METADATA, + job=job_config, + openlineage_inject_parent_job_info=True, + ) + op.execute(context=EXAMPLE_CONTEXT) + + mock_hook.return_value.submit_job.assert_called_once_with( + project_id=GCP_PROJECT, + region=GCP_REGION, + job=job_config, + request_id=REQUEST_ID, + retry=RETRY, + timeout=TIMEOUT, + metadata=METADATA, + ) + + 
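For reference, a hypothetical DAG snippet exercising the two flags these operator tests cover. The project, region, cluster and bucket names are made up; when the flags are not passed they fall back to the ``[openlineage]`` ``spark_inject_parent_job_info`` and ``spark_inject_transport_info`` options.

.. code-block:: python

    from airflow.providers.google.cloud.operators.dataproc import DataprocSubmitJobOperator

    submit_spark_job = DataprocSubmitJobOperator(
        task_id="submit_spark_job",
        project_id="my-project",
        region="europe-central2",
        job={
            "placement": {"cluster_name": "my-cluster"},
            "pyspark_job": {"main_python_file_uri": "gs://my-bucket/job.py"},
        },
        openlineage_inject_parent_job_info=True,  # adds spark.openlineage.parent* properties
        openlineage_inject_transport_info=True,  # adds spark.openlineage.transport.* properties
    )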
@mock.patch("airflow.providers.openlineage.plugins.listener._openlineage_listener") + @mock.patch("airflow.providers.google.cloud.openlineage.utils._is_openlineage_provider_accessible") + @mock.patch(DATAPROC_PATH.format("DataprocHook")) + def test_execute_openlineage_transport_info_injection_skipped_when_ol_not_accessible( + self, mock_hook, mock_ol_accessible, mock_ol_listener + ): mock_ol_accessible.return_value = False + mock_ol_listener.adapter.get_or_create_openlineage_client.return_value.transport = HttpTransport( + HttpConfig.from_dict(OPENLINEAGE_HTTP_TRANSPORT_EXAMPLE_CONFIG) + ) + job_config = { + "placement": {"cluster_name": CLUSTER_NAME}, + "pyspark_job": { + "main_python_file_uri": "gs://example/wordcount.py", + "properties": { + "spark.sql.shuffle.partitions": "1", + }, + }, + } op = DataprocSubmitJobOperator( task_id=TASK_ID, @@ -1678,9 +2004,9 @@ def test_execute_openlineage_parent_job_info_injection_skipped_when_ol_not_acces timeout=TIMEOUT, metadata=METADATA, job=job_config, - openlineage_inject_parent_job_info=True, + openlineage_inject_transport_info=True, ) - op.execute(context=self.mock_context) + op.execute(context=EXAMPLE_CONTEXT) mock_hook.return_value.submit_job.assert_called_once_with( project_id=GCP_PROJECT, @@ -2360,23 +2686,8 @@ def test_wait_for_operation_on_execute(self, mock_hook): @mock.patch(DATAPROC_PATH.format("DataprocHook")) def test_execute_openlineage_parent_job_info_injection(self, mock_hook, mock_ol_accessible): mock_ol_accessible.return_value = True - context = { - "ti": MagicMock( - dag_id="dag_id", - task_id="task_id", - try_number=1, - map_index=1, - logical_date=dt.datetime(2024, 11, 11), - ) - } template = { - "id": "test-workflow", - "placement": { - "cluster_selector": { - "zone": "europe-central2-c", - "cluster_labels": {"key": "value"}, - } - }, + **WORKFLOW_TEMPLATE, "jobs": [ { "step_id": "job_1", @@ -2407,23 +2718,9 @@ def test_execute_openlineage_parent_job_info_injection(self, mock_hook, mock_ol_ }, }, ], - "parameters": [ - { - "name": "ZONE", - "fields": [ - "placement.clusterSelector.zone", - ], - } - ], } expected_template = { - "id": "test-workflow", - "placement": { - "cluster_selector": { - "zone": "europe-central2-c", - "cluster_labels": {"key": "value"}, - } - }, + **WORKFLOW_TEMPLATE, "jobs": [ { "step_id": "job_1", @@ -2457,14 +2754,6 @@ def test_execute_openlineage_parent_job_info_injection(self, mock_hook, mock_ol_ }, }, ], - "parameters": [ - { - "name": "ZONE", - "fields": [ - "placement.clusterSelector.zone", - ], - } - ], } op = DataprocInstantiateInlineWorkflowTemplateOperator( @@ -2480,7 +2769,7 @@ def test_execute_openlineage_parent_job_info_injection(self, mock_hook, mock_ol_ impersonation_chain=IMPERSONATION_CHAIN, openlineage_inject_parent_job_info=True, ) - op.execute(context=context) + op.execute(context=EXAMPLE_CONTEXT) mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN) mock_hook.return_value.instantiate_inline_workflow_template.assert_called_once_with( template=expected_template, @@ -2498,15 +2787,8 @@ def test_execute_openlineage_parent_job_info_injection_skipped_by_default_unless self, mock_hook, mock_ol_accessible ): mock_ol_accessible.return_value = True - template = { - "id": "test-workflow", - "placement": { - "cluster_selector": { - "zone": "europe-central2-c", - "cluster_labels": {"key": "value"}, - } - }, + **WORKFLOW_TEMPLATE, "jobs": [ { "step_id": "job_1", @@ -2530,7 +2812,7 @@ def 
test_execute_openlineage_parent_job_info_injection_skipped_by_default_unless impersonation_chain=IMPERSONATION_CHAIN, # not passing openlineage_inject_parent_job_info, should be False by default ) - op.execute(context=MagicMock()) + op.execute(context=EXAMPLE_CONTEXT) mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN) mock_hook.return_value.instantiate_inline_workflow_template.assert_called_once_with( template=template, @@ -2550,13 +2832,7 @@ def test_execute_openlineage_parent_job_info_injection_skipped_when_ol_not_acces mock_ol_accessible.return_value = False template = { - "id": "test-workflow", - "placement": { - "cluster_selector": { - "zone": "europe-central2-c", - "cluster_labels": {"key": "value"}, - } - }, + **WORKFLOW_TEMPLATE, "jobs": [ { "step_id": "job_1", @@ -2580,7 +2856,332 @@ def test_execute_openlineage_parent_job_info_injection_skipped_when_ol_not_acces impersonation_chain=IMPERSONATION_CHAIN, openlineage_inject_parent_job_info=True, ) - op.execute(context=MagicMock()) + op.execute(context=EXAMPLE_CONTEXT) + mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN) + mock_hook.return_value.instantiate_inline_workflow_template.assert_called_once_with( + template=template, + region=GCP_REGION, + project_id=GCP_PROJECT, + request_id=REQUEST_ID, + retry=RETRY, + timeout=TIMEOUT, + metadata=METADATA, + ) + + @mock.patch("airflow.providers.openlineage.plugins.listener._openlineage_listener") + @mock.patch("airflow.providers.google.cloud.openlineage.utils._is_openlineage_provider_accessible") + @mock.patch(DATAPROC_PATH.format("DataprocHook")) + def test_execute_openlineage_transport_info_injection( + self, mock_hook, mock_ol_accessible, mock_ol_listener + ): + mock_ol_accessible.return_value = True + mock_ol_listener.adapter.get_or_create_openlineage_client.return_value.transport = HttpTransport( + HttpConfig.from_dict(OPENLINEAGE_HTTP_TRANSPORT_EXAMPLE_CONFIG) + ) + + template = { + **WORKFLOW_TEMPLATE, + "jobs": [ + { + "step_id": "job_1", + "pyspark_job": { + "main_python_file_uri": "gs://bucket1/spark_job.py", + "properties": { + "spark.sql.shuffle.partitions": "1", + }, + }, + }, + { + "step_id": "job_2", + "pyspark_job": { + "main_python_file_uri": "gs://bucket2/spark_job.py", + "properties": { + "spark.sql.shuffle.partitions": "1", + "spark.openlineage.transport.type": "console", + }, + }, + }, + { + "step_id": "job_3", + "hive_job": { + "main_python_file_uri": "gs://bucket3/hive_job.py", + "properties": { + "spark.sql.shuffle.partitions": "1", + }, + }, + }, + ], + } + expected_template = { + **WORKFLOW_TEMPLATE, + "jobs": [ + { + "step_id": "job_1", + "pyspark_job": { + "main_python_file_uri": "gs://bucket1/spark_job.py", + "properties": { # Injected properties + "spark.sql.shuffle.partitions": "1", + **OPENLINEAGE_HTTP_TRANSPORT_EXAMPLE_SPARK_PROPERTIES, + }, + }, + }, + { + "step_id": "job_2", + "pyspark_job": { # Not modified because it's already present + "main_python_file_uri": "gs://bucket2/spark_job.py", + "properties": { + "spark.sql.shuffle.partitions": "1", + "spark.openlineage.transport.type": "console", + }, + }, + }, + { + "step_id": "job_3", + "hive_job": { # Not modified because it's unsupported job type + "main_python_file_uri": "gs://bucket3/hive_job.py", + "properties": { + "spark.sql.shuffle.partitions": "1", + }, + }, + }, + ], + } + + op = DataprocInstantiateInlineWorkflowTemplateOperator( + task_id=TASK_ID, + template=template, + region=GCP_REGION, + 
project_id=GCP_PROJECT, + request_id=REQUEST_ID, + retry=RETRY, + timeout=TIMEOUT, + metadata=METADATA, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + openlineage_inject_transport_info=True, + ) + op.execute(context=EXAMPLE_CONTEXT) + mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN) + mock_hook.return_value.instantiate_inline_workflow_template.assert_called_once_with( + template=expected_template, + region=GCP_REGION, + project_id=GCP_PROJECT, + request_id=REQUEST_ID, + retry=RETRY, + timeout=TIMEOUT, + metadata=METADATA, + ) + + @mock.patch("airflow.providers.openlineage.plugins.listener._openlineage_listener") + @mock.patch("airflow.providers.google.cloud.openlineage.utils._is_openlineage_provider_accessible") + @mock.patch(DATAPROC_PATH.format("DataprocHook")) + def test_execute_openlineage_all_info_injection(self, mock_hook, mock_ol_accessible, mock_ol_listener): + mock_ol_accessible.return_value = True + mock_ol_listener.adapter.get_or_create_openlineage_client.return_value.transport = HttpTransport( + HttpConfig.from_dict(OPENLINEAGE_HTTP_TRANSPORT_EXAMPLE_CONFIG) + ) + template = { + **WORKFLOW_TEMPLATE, + "jobs": [ + { + "step_id": "job_1", + "pyspark_job": { + "main_python_file_uri": "gs://bucket1/spark_job.py", + "properties": { + "spark.sql.shuffle.partitions": "1", + }, + }, + }, + { + "step_id": "job_2", + "pyspark_job": { + "main_python_file_uri": "gs://bucket2/spark_job.py", + "properties": { + "spark.sql.shuffle.partitions": "1", + "spark.openlineage.transport.type": "console", + }, + }, + }, + { + "step_id": "job_3", + "hive_job": { + "main_python_file_uri": "gs://bucket3/hive_job.py", + "properties": { + "spark.sql.shuffle.partitions": "1", + }, + }, + }, + { + "step_id": "job_4", + "pyspark_job": { + "main_python_file_uri": "gs://bucket2/spark_job.py", + "properties": { + "spark.sql.shuffle.partitions": "1", + "spark.openlineage.parentJobNamespace": "test", + }, + }, + }, + ], + } + expected_template = { + **WORKFLOW_TEMPLATE, + "jobs": [ + { + "step_id": "job_1", + "pyspark_job": { # Injected all properties + "main_python_file_uri": "gs://bucket1/spark_job.py", + "properties": { + "spark.sql.shuffle.partitions": "1", + **OPENLINEAGE_PARENT_JOB_EXAMPLE_SPARK_PROPERTIES, + **OPENLINEAGE_HTTP_TRANSPORT_EXAMPLE_SPARK_PROPERTIES, + }, + }, + }, + { + "step_id": "job_2", + "pyspark_job": { # Transport not added because it's already present + "main_python_file_uri": "gs://bucket2/spark_job.py", + "properties": { + "spark.sql.shuffle.partitions": "1", + "spark.openlineage.transport.type": "console", + **OPENLINEAGE_PARENT_JOB_EXAMPLE_SPARK_PROPERTIES, + }, + }, + }, + { + "step_id": "job_3", + "hive_job": { # Not modified because it's unsupported job type + "main_python_file_uri": "gs://bucket3/hive_job.py", + "properties": { + "spark.sql.shuffle.partitions": "1", + }, + }, + }, + { + "step_id": "job_4", + "pyspark_job": { # Parent not added because it's already present + "main_python_file_uri": "gs://bucket2/spark_job.py", + "properties": { + "spark.sql.shuffle.partitions": "1", + "spark.openlineage.parentJobNamespace": "test", + **OPENLINEAGE_HTTP_TRANSPORT_EXAMPLE_SPARK_PROPERTIES, + }, + }, + }, + ], + } + + op = DataprocInstantiateInlineWorkflowTemplateOperator( + task_id=TASK_ID, + template=template, + region=GCP_REGION, + project_id=GCP_PROJECT, + request_id=REQUEST_ID, + retry=RETRY, + timeout=TIMEOUT, + metadata=METADATA, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + 
openlineage_inject_parent_job_info=True, + openlineage_inject_transport_info=True, + ) + op.execute(context=EXAMPLE_CONTEXT) + mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN) + mock_hook.return_value.instantiate_inline_workflow_template.assert_called_once_with( + template=expected_template, + region=GCP_REGION, + project_id=GCP_PROJECT, + request_id=REQUEST_ID, + retry=RETRY, + timeout=TIMEOUT, + metadata=METADATA, + ) + + @mock.patch("airflow.providers.openlineage.plugins.listener._openlineage_listener") + @mock.patch("airflow.providers.google.cloud.openlineage.utils._is_openlineage_provider_accessible") + @mock.patch(DATAPROC_PATH.format("DataprocHook")) + def test_execute_openlineage_transport_info_injection_skipped_by_default_unless_enabled( + self, mock_hook, mock_ol_accessible, mock_ol_listener + ): + mock_ol_accessible.return_value = True + mock_ol_listener.adapter.get_or_create_openlineage_client.return_value.transport = HttpTransport( + HttpConfig("https://some-custom.url") + ) + + template = { + **WORKFLOW_TEMPLATE, + "jobs": [ + { + "step_id": "job_1", + "pyspark_job": { + "main_python_file_uri": "gs://bucket1/spark_job.py", + }, + } + ], + } + + op = DataprocInstantiateInlineWorkflowTemplateOperator( + task_id=TASK_ID, + template=template, + region=GCP_REGION, + project_id=GCP_PROJECT, + request_id=REQUEST_ID, + retry=RETRY, + timeout=TIMEOUT, + metadata=METADATA, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + # not passing openlineage_inject_transport_info, should be False by default + ) + op.execute(context=EXAMPLE_CONTEXT) + mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN) + mock_hook.return_value.instantiate_inline_workflow_template.assert_called_once_with( + template=template, + region=GCP_REGION, + project_id=GCP_PROJECT, + request_id=REQUEST_ID, + retry=RETRY, + timeout=TIMEOUT, + metadata=METADATA, + ) + + @mock.patch("airflow.providers.openlineage.plugins.listener._openlineage_listener") + @mock.patch("airflow.providers.google.cloud.openlineage.utils._is_openlineage_provider_accessible") + @mock.patch(DATAPROC_PATH.format("DataprocHook")) + def test_execute_openlineage_transport_info_injection_skipped_when_ol_not_accessible( + self, mock_hook, mock_ol_accessible, mock_ol_listener + ): + mock_ol_accessible.return_value = False + mock_ol_listener.adapter.get_or_create_openlineage_client.return_value.transport = HttpTransport( + HttpConfig("https://some-custom.url") + ) + + template = { + **WORKFLOW_TEMPLATE, + "jobs": [ + { + "step_id": "job_1", + "pyspark_job": { + "main_python_file_uri": "gs://bucket1/spark_job.py", + }, + } + ], + } + + op = DataprocInstantiateInlineWorkflowTemplateOperator( + task_id=TASK_ID, + template=template, + region=GCP_REGION, + project_id=GCP_PROJECT, + request_id=REQUEST_ID, + retry=RETRY, + timeout=TIMEOUT, + metadata=METADATA, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + openlineage_inject_transport_info=True, + ) + op.execute(context=EXAMPLE_CONTEXT) mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN) mock_hook.return_value.instantiate_inline_workflow_template.assert_called_once_with( template=template, @@ -2847,30 +3448,102 @@ def test_execute_batch_already_exists_cancelled(self, mock_hook): @mock.patch(DATAPROC_PATH.format("Batch.to_dict")) @mock.patch(DATAPROC_PATH.format("DataprocHook")) def 
test_execute_openlineage_parent_job_info_injection(self, mock_hook, to_dict_mock, mock_ol_accessible): + mock_ol_accessible.return_value = True expected_batch = { - "spark_batch": { - "jar_file_uris": ["file:///usr/lib/spark/examples/jars/spark-examples.jar"], - "main_class": "org.apache.spark.examples.SparkPi", - }, + **BATCH, + "runtime_config": {"properties": OPENLINEAGE_PARENT_JOB_EXAMPLE_SPARK_PROPERTIES}, + } + + op = DataprocCreateBatchOperator( + task_id=TASK_ID, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + region=GCP_REGION, + project_id=GCP_PROJECT, + batch=BATCH, + batch_id=BATCH_ID, + request_id=REQUEST_ID, + retry=RETRY, + timeout=TIMEOUT, + metadata=METADATA, + openlineage_inject_parent_job_info=True, + ) + mock_hook.return_value.wait_for_operation.return_value = Batch(state=Batch.State.SUCCEEDED) + op.execute(context=EXAMPLE_CONTEXT) + mock_hook.return_value.create_batch.assert_called_once_with( + region=GCP_REGION, + project_id=GCP_PROJECT, + batch=expected_batch, + batch_id=BATCH_ID, + request_id=REQUEST_ID, + retry=RETRY, + timeout=TIMEOUT, + metadata=METADATA, + ) + + @mock.patch("airflow.providers.openlineage.plugins.listener._openlineage_listener") + @mock.patch("airflow.providers.google.cloud.openlineage.utils._is_openlineage_provider_accessible") + @mock.patch(DATAPROC_PATH.format("Batch.to_dict")) + @mock.patch(DATAPROC_PATH.format("DataprocHook")) + def test_execute_openlineage_transport_info_injection( + self, mock_hook, to_dict_mock, mock_ol_accessible, mock_ol_listener + ): + mock_ol_accessible.return_value = True + mock_ol_listener.adapter.get_or_create_openlineage_client.return_value.transport = HttpTransport( + HttpConfig.from_dict(OPENLINEAGE_HTTP_TRANSPORT_EXAMPLE_CONFIG) + ) + expected_batch = { + **BATCH, + "runtime_config": {"properties": OPENLINEAGE_HTTP_TRANSPORT_EXAMPLE_SPARK_PROPERTIES}, + } + + op = DataprocCreateBatchOperator( + task_id=TASK_ID, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + region=GCP_REGION, + project_id=GCP_PROJECT, + batch=BATCH, + batch_id=BATCH_ID, + request_id=REQUEST_ID, + retry=RETRY, + timeout=TIMEOUT, + metadata=METADATA, + openlineage_inject_transport_info=True, + ) + mock_hook.return_value.wait_for_operation.return_value = Batch(state=Batch.State.SUCCEEDED) + op.execute(context=EXAMPLE_CONTEXT) + mock_hook.return_value.create_batch.assert_called_once_with( + region=GCP_REGION, + project_id=GCP_PROJECT, + batch=expected_batch, + batch_id=BATCH_ID, + request_id=REQUEST_ID, + retry=RETRY, + timeout=TIMEOUT, + metadata=METADATA, + ) + + @mock.patch("airflow.providers.openlineage.plugins.listener._openlineage_listener") + @mock.patch("airflow.providers.google.cloud.openlineage.utils._is_openlineage_provider_accessible") + @mock.patch(DATAPROC_PATH.format("Batch.to_dict")) + @mock.patch(DATAPROC_PATH.format("DataprocHook")) + def test_execute_openlineage_all_info_injection( + self, mock_hook, to_dict_mock, mock_ol_accessible, mock_ol_listener + ): + mock_ol_accessible.return_value = True + mock_ol_listener.adapter.get_or_create_openlineage_client.return_value.transport = HttpTransport( + HttpConfig.from_dict(OPENLINEAGE_HTTP_TRANSPORT_EXAMPLE_CONFIG) + ) + expected_batch = { + **BATCH, "runtime_config": { "properties": { - "spark.openlineage.parentJobName": "dag_id.task_id", - "spark.openlineage.parentJobNamespace": "default", - "spark.openlineage.parentRunId": "01931885-2800-7be7-aa8d-aaa15c337267", + **OPENLINEAGE_PARENT_JOB_EXAMPLE_SPARK_PROPERTIES, + 
**OPENLINEAGE_HTTP_TRANSPORT_EXAMPLE_SPARK_PROPERTIES, } }, } - context = { - "ti": MagicMock( - dag_id="dag_id", - task_id="task_id", - try_number=1, - map_index=1, - logical_date=dt.datetime(2024, 11, 11), - ) - } - - mock_ol_accessible.return_value = True op = DataprocCreateBatchOperator( task_id=TASK_ID, @@ -2885,9 +3558,10 @@ def test_execute_openlineage_parent_job_info_injection(self, mock_hook, to_dict_ timeout=TIMEOUT, metadata=METADATA, openlineage_inject_parent_job_info=True, + openlineage_inject_transport_info=True, ) mock_hook.return_value.wait_for_operation.return_value = Batch(state=Batch.State.SUCCEEDED) - op.execute(context=context) + op.execute(context=EXAMPLE_CONTEXT) mock_hook.return_value.create_batch.assert_called_once_with( region=GCP_REGION, project_id=GCP_PROJECT, @@ -2905,18 +3579,15 @@ def test_execute_openlineage_parent_job_info_injection(self, mock_hook, to_dict_ def test_execute_openlineage_parent_job_info_injection_skipped_when_already_present( self, mock_hook, to_dict_mock, mock_ol_accessible ): + mock_ol_accessible.return_value = True batch = { - "spark_batch": { - "jar_file_uris": ["file:///usr/lib/spark/examples/jars/spark-examples.jar"], - "main_class": "org.apache.spark.examples.SparkPi", - }, + **BATCH, "runtime_config": { "properties": { "spark.openlineage.parentJobName": "dag_id.task_id", } }, } - mock_ol_accessible.return_value = True op = DataprocCreateBatchOperator( task_id=TASK_ID, @@ -2933,7 +3604,54 @@ def test_execute_openlineage_parent_job_info_injection_skipped_when_already_pres openlineage_inject_parent_job_info=True, ) mock_hook.return_value.wait_for_operation.return_value = Batch(state=Batch.State.SUCCEEDED) - op.execute(context=MagicMock()) + op.execute(context=EXAMPLE_CONTEXT) + mock_hook.return_value.create_batch.assert_called_once_with( + region=GCP_REGION, + project_id=GCP_PROJECT, + batch=batch, + batch_id=BATCH_ID, + request_id=REQUEST_ID, + retry=RETRY, + timeout=TIMEOUT, + metadata=METADATA, + ) + + @mock.patch("airflow.providers.openlineage.plugins.listener._openlineage_listener") + @mock.patch("airflow.providers.google.cloud.openlineage.utils._is_openlineage_provider_accessible") + @mock.patch(DATAPROC_PATH.format("Batch.to_dict")) + @mock.patch(DATAPROC_PATH.format("DataprocHook")) + def test_execute_openlineage_transport_info_injection_skipped_when_already_present( + self, mock_hook, to_dict_mock, mock_ol_accessible, mock_ol_listener + ): + mock_ol_accessible.return_value = True + mock_ol_listener.adapter.get_or_create_openlineage_client.return_value.transport = HttpTransport( + HttpConfig.from_dict(OPENLINEAGE_HTTP_TRANSPORT_EXAMPLE_CONFIG) + ) + batch = { + **BATCH, + "runtime_config": { + "properties": { + "spark.openlineage.transport.type": "console", + } + }, + } + + op = DataprocCreateBatchOperator( + task_id=TASK_ID, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + region=GCP_REGION, + project_id=GCP_PROJECT, + batch=batch, + batch_id=BATCH_ID, + request_id=REQUEST_ID, + retry=RETRY, + timeout=TIMEOUT, + metadata=METADATA, + openlineage_inject_transport_info=True, + ) + mock_hook.return_value.wait_for_operation.return_value = Batch(state=Batch.State.SUCCEEDED) + op.execute(context=EXAMPLE_CONTEXT) mock_hook.return_value.create_batch.assert_called_once_with( region=GCP_REGION, project_id=GCP_PROJECT, @@ -2951,14 +3669,11 @@ def test_execute_openlineage_parent_job_info_injection_skipped_when_already_pres def test_execute_openlineage_parent_job_info_injection_skipped_by_default_unless_enabled( self, 
mock_hook, to_dict_mock, mock_ol_accessible ): + mock_ol_accessible.return_value = True batch = { - "spark_batch": { - "jar_file_uris": ["file:///usr/lib/spark/examples/jars/spark-examples.jar"], - "main_class": "org.apache.spark.examples.SparkPi", - }, + **BATCH, "runtime_config": {"properties": {}}, } - mock_ol_accessible.return_value = True op = DataprocCreateBatchOperator( task_id=TASK_ID, @@ -2975,7 +3690,50 @@ def test_execute_openlineage_parent_job_info_injection_skipped_by_default_unless # not passing openlineage_inject_parent_job_info, should be False by default ) mock_hook.return_value.wait_for_operation.return_value = Batch(state=Batch.State.SUCCEEDED) - op.execute(context=MagicMock()) + op.execute(context=EXAMPLE_CONTEXT) + mock_hook.return_value.create_batch.assert_called_once_with( + region=GCP_REGION, + project_id=GCP_PROJECT, + batch=batch, + batch_id=BATCH_ID, + request_id=REQUEST_ID, + retry=RETRY, + timeout=TIMEOUT, + metadata=METADATA, + ) + + @mock.patch("airflow.providers.openlineage.plugins.listener._openlineage_listener") + @mock.patch("airflow.providers.google.cloud.openlineage.utils._is_openlineage_provider_accessible") + @mock.patch(DATAPROC_PATH.format("Batch.to_dict")) + @mock.patch(DATAPROC_PATH.format("DataprocHook")) + def test_execute_openlineage_transport_info_injection_skipped_by_default_unless_enabled( + self, mock_hook, to_dict_mock, mock_ol_accessible, mock_ol_listener + ): + mock_ol_accessible.return_value = True + mock_ol_listener.adapter.get_or_create_openlineage_client.return_value.transport = HttpTransport( + HttpConfig.from_dict(OPENLINEAGE_HTTP_TRANSPORT_EXAMPLE_CONFIG) + ) + batch = { + **BATCH, + "runtime_config": {"properties": {}}, + } + + op = DataprocCreateBatchOperator( + task_id=TASK_ID, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + region=GCP_REGION, + project_id=GCP_PROJECT, + batch=batch, + batch_id=BATCH_ID, + request_id=REQUEST_ID, + retry=RETRY, + timeout=TIMEOUT, + metadata=METADATA, + # not passing openlineage_inject_transport_info, should be False by default + ) + mock_hook.return_value.wait_for_operation.return_value = Batch(state=Batch.State.SUCCEEDED) + op.execute(context=EXAMPLE_CONTEXT) mock_hook.return_value.create_batch.assert_called_once_with( region=GCP_REGION, project_id=GCP_PROJECT, @@ -2993,14 +3751,11 @@ def test_execute_openlineage_parent_job_info_injection_skipped_by_default_unless def test_execute_openlineage_parent_job_info_injection_skipped_when_ol_not_accessible( self, mock_hook, to_dict_mock, mock_ol_accessible ): + mock_ol_accessible.return_value = False batch = { - "spark_batch": { - "jar_file_uris": ["file:///usr/lib/spark/examples/jars/spark-examples.jar"], - "main_class": "org.apache.spark.examples.SparkPi", - }, + **BATCH, "runtime_config": {"properties": {}}, } - mock_ol_accessible.return_value = False op = DataprocCreateBatchOperator( task_id=TASK_ID, @@ -3017,7 +3772,50 @@ def test_execute_openlineage_parent_job_info_injection_skipped_when_ol_not_acces openlineage_inject_parent_job_info=True, ) mock_hook.return_value.wait_for_operation.return_value = Batch(state=Batch.State.SUCCEEDED) - op.execute(context=MagicMock()) + op.execute(context=EXAMPLE_CONTEXT) + mock_hook.return_value.create_batch.assert_called_once_with( + region=GCP_REGION, + project_id=GCP_PROJECT, + batch=batch, + batch_id=BATCH_ID, + request_id=REQUEST_ID, + retry=RETRY, + timeout=TIMEOUT, + metadata=METADATA, + ) + + @mock.patch("airflow.providers.openlineage.plugins.listener._openlineage_listener") + 
@mock.patch("airflow.providers.google.cloud.openlineage.utils._is_openlineage_provider_accessible") + @mock.patch(DATAPROC_PATH.format("Batch.to_dict")) + @mock.patch(DATAPROC_PATH.format("DataprocHook")) + def test_execute_openlineage_transport_info_injection_skipped_when_ol_not_accessible( + self, mock_hook, to_dict_mock, mock_ol_accessible, mock_ol_listener + ): + mock_ol_accessible.return_value = False + mock_ol_listener.adapter.get_or_create_openlineage_client.return_value.transport = HttpTransport( + HttpConfig.from_dict(OPENLINEAGE_HTTP_TRANSPORT_EXAMPLE_CONFIG) + ) + batch = { + **BATCH, + "runtime_config": {"properties": {}}, + } + + op = DataprocCreateBatchOperator( + task_id=TASK_ID, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + region=GCP_REGION, + project_id=GCP_PROJECT, + batch=batch, + batch_id=BATCH_ID, + request_id=REQUEST_ID, + retry=RETRY, + timeout=TIMEOUT, + metadata=METADATA, + openlineage_inject_transport_info=True, + ) + mock_hook.return_value.wait_for_operation.return_value = Batch(state=Batch.State.SUCCEEDED) + op.execute(context=EXAMPLE_CONTEXT) mock_hook.return_value.create_batch.assert_called_once_with( region=GCP_REGION, project_id=GCP_PROJECT, diff --git a/providers/tests/openlineage/test_conf.py b/providers/tests/openlineage/test_conf.py index 507d07735364e..dcf20719f0906 100644 --- a/providers/tests/openlineage/test_conf.py +++ b/providers/tests/openlineage/test_conf.py @@ -37,6 +37,7 @@ namespace, selective_enable, spark_inject_parent_job_info, + spark_inject_transport_info, transport, ) @@ -63,6 +64,7 @@ _CONFIG_OPTION_INCLUDE_FULL_TASK_INFO = "include_full_task_info" _CONFIG_OPTION_DEBUG_MODE = "debug_mode" _CONFIG_OPTION_SPARK_INJECT_PARENT_JOB_INFO = "spark_inject_parent_job_info" +_CONFIG_OPTION_SPARK_INJECT_TRANSPORT_INFO = "spark_inject_transport_info" _BOOL_PARAMS = ( ("1", True), @@ -641,3 +643,30 @@ def test_spark_inject_parent_job_info_empty_conf_option(): @conf_vars({(_CONFIG_SECTION, _CONFIG_OPTION_SPARK_INJECT_PARENT_JOB_INFO): None}) def test_spark_inject_parent_job_info_do_not_fail_if_conf_option_missing(): assert spark_inject_parent_job_info() is False + + +@pytest.mark.parametrize( + ("var_string", "expected"), + _BOOL_PARAMS, +) +def test_spark_inject_transport_info(var_string, expected): + with conf_vars({(_CONFIG_SECTION, _CONFIG_OPTION_SPARK_INJECT_TRANSPORT_INFO): var_string}): + result = spark_inject_transport_info() + assert result is expected + + +@conf_vars({(_CONFIG_SECTION, _CONFIG_OPTION_SPARK_INJECT_TRANSPORT_INFO): "asdadawlaksnd"}) +def test_spark_inject_transport_info_not_working_for_random_string(): + with pytest.raises(AirflowConfigException): + spark_inject_transport_info() + + +@conf_vars({(_CONFIG_SECTION, _CONFIG_OPTION_SPARK_INJECT_TRANSPORT_INFO): ""}) +def test_spark_inject_transport_info_empty_conf_option(): + with pytest.raises(AirflowConfigException): + spark_inject_transport_info() + + +@conf_vars({(_CONFIG_SECTION, _CONFIG_OPTION_SPARK_INJECT_TRANSPORT_INFO): None}) +def test_spark_inject_transport_info_do_not_fail_if_conf_option_missing(): + assert spark_inject_transport_info() is False diff --git a/providers/tests/openlineage/utils/test_spark.py b/providers/tests/openlineage/utils/test_spark.py index 17eba0b55b96b..49ab0d13528e2 100644 --- a/providers/tests/openlineage/utils/test_spark.py +++ b/providers/tests/openlineage/utils/test_spark.py @@ -18,14 +18,18 @@ from __future__ import annotations import datetime as dt -from unittest.mock import MagicMock +from unittest.mock import 
MagicMock, patch import pytest +from openlineage.client.transport import HttpConfig, HttpTransport, KafkaConfig, KafkaTransport from airflow.providers.openlineage.utils.spark import ( _get_parent_job_information_as_spark_properties, + _get_transport_information_as_spark_properties, _is_parent_job_information_present_in_spark_properties, + _is_transport_information_present_in_spark_properties, inject_parent_job_information_into_spark_properties, + inject_transport_information_into_spark_properties, ) EXAMPLE_CONTEXT = { @@ -37,11 +41,36 @@ logical_date=dt.datetime(2024, 11, 11), ) } +EXAMPLE_HTTP_TRANSPORT_CONFIG = { + "url": "https://some-custom.url", + "endpoint": "/api/custom", + "timeout": 123, + "compression": "gzip", + "custom_headers": { + "key1": "val1", + "key2": "val2", + }, + "auth": { + "type": "api_key", + "apiKey": "secret_123", + }, +} EXAMPLE_PARENT_JOB_SPARK_PROPERTIES = { "spark.openlineage.parentJobName": "dag_id.task_id", "spark.openlineage.parentJobNamespace": "default", "spark.openlineage.parentRunId": "01931885-2800-7be7-aa8d-aaa15c337267", } +EXAMPLE_TRANSPORT_SPARK_PROPERTIES = { + "spark.openlineage.transport.type": "http", + "spark.openlineage.transport.url": "https://some-custom.url", + "spark.openlineage.transport.endpoint": "/api/custom", + "spark.openlineage.transport.auth.type": "api_key", + "spark.openlineage.transport.auth.apiKey": "Bearer secret_123", + "spark.openlineage.transport.compression": "gzip", + "spark.openlineage.transport.headers.key1": "val1", + "spark.openlineage.transport.headers.key2": "val2", + "spark.openlineage.transport.timeoutInMillis": "123000", +} def test_get_parent_job_information_as_spark_properties(): @@ -49,6 +78,34 @@ def test_get_parent_job_information_as_spark_properties(): assert result == EXAMPLE_PARENT_JOB_SPARK_PROPERTIES +@patch("airflow.providers.openlineage.plugins.listener._openlineage_listener") +def test_get_transport_information_as_spark_properties(mock_ol_listener): + mock_ol_listener.adapter.get_or_create_openlineage_client.return_value.transport = HttpTransport( + HttpConfig.from_dict(EXAMPLE_HTTP_TRANSPORT_CONFIG) + ) + result = _get_transport_information_as_spark_properties() + assert result == EXAMPLE_TRANSPORT_SPARK_PROPERTIES + + +@patch("airflow.providers.openlineage.plugins.listener._openlineage_listener") +def test_get_transport_information_as_spark_properties_unsupported_transport_type(mock_ol_listener): + kafka_config = KafkaConfig( + topic="my_topic", + config={ + "bootstrap.servers": "localhost:9092,another.host:9092", + "acks": "all", + "retries": "3", + }, + flush=True, + messageKey="some", + ) + mock_ol_listener.adapter.get_or_create_openlineage_client.return_value.transport = KafkaTransport( + kafka_config + ) + result = _get_transport_information_as_spark_properties() + assert result == {} + + @pytest.mark.parametrize( "properties, expected", [ @@ -80,12 +137,49 @@ def test_get_parent_job_information_as_spark_properties(): }, True, ), + ( + {}, + False, + ), ], ) def test_is_parent_job_information_present_in_spark_properties(properties, expected): assert _is_parent_job_information_present_in_spark_properties(properties) is expected +@pytest.mark.parametrize( + "properties, expected", + [ + ( + {"spark.openlineage.transport": "example_namespace"}, + True, + ), + ( + {"spark.openlineage.transport.type": "some_job_name"}, + True, + ), + ( + {"spark.openlineage.transport.urlParams.value1": "some_run_id"}, + True, + ), + ( + {"spark.openlineage.transportWhatever": "some_value", "some.other.property": 
"value"}, + True, + ), + ( + {"some.other.property": "value"}, + False, + ), + ( + {}, + False, + ), + ], +) +def test_is_transport_information_present_in_spark_properties(properties, expected): + assert _is_transport_information_present_in_spark_properties(properties) is expected + + @pytest.mark.parametrize( "properties, should_inject", [ @@ -127,3 +221,42 @@ def test_inject_parent_job_information_into_spark_properties(properties, should_ result = inject_parent_job_information_into_spark_properties(properties, EXAMPLE_CONTEXT) expected = {**properties, **EXAMPLE_PARENT_JOB_SPARK_PROPERTIES} if should_inject else properties assert result == expected + + +@pytest.mark.parametrize( + "properties, should_inject", + [ + ( + {"spark.openlineage.transport": "example_namespace"}, + False, + ), + ( + {"spark.openlineage.transport.type": "some_job_name"}, + False, + ), + ( + {"spark.openlineage.transport.url": "some_run_id"}, + False, + ), + ( + {"spark.openlineage.transportWhatever": "some_value", "some.other.property": "value"}, + False, + ), + ( + {"some.other.property": "value"}, + True, + ), + ( + {}, + True, + ), + ], +) +@patch("airflow.providers.openlineage.plugins.listener._openlineage_listener") +def test_inject_transport_information_into_spark_properties(mock_ol_listener, properties, should_inject): + mock_ol_listener.adapter.get_or_create_openlineage_client.return_value.transport = HttpTransport( + HttpConfig.from_dict(EXAMPLE_HTTP_TRANSPORT_CONFIG) + ) + result = inject_transport_information_into_spark_properties(properties, EXAMPLE_CONTEXT) + expected = {**properties, **EXAMPLE_TRANSPORT_SPARK_PROPERTIES} if should_inject else properties + assert result == expected