From 751c84cd098a3b78ef8e3edb2f27629d9edfc412 Mon Sep 17 00:00:00 2001 From: Avinash-1394 <43074786+Avinash-1394@users.noreply.github.com> Date: Thu, 16 Mar 2023 22:02:10 -0300 Subject: [PATCH 01/75] Submitted python model successfully --- dbt/adapters/athena/connections.py | 11 ++ dbt/adapters/athena/impl.py | 19 ++- dbt/adapters/athena/python_submissions.py | 74 +++++++++ .../models/incremental/incremental.sql | 15 +- .../models/table/create_table_as.sql | 140 +++++++++++++----- .../materializations/models/table/table.sql | 7 +- 6 files changed, 217 insertions(+), 49 deletions(-) create mode 100644 dbt/adapters/athena/python_submissions.py diff --git a/dbt/adapters/athena/connections.py b/dbt/adapters/athena/connections.py index fb7d5f1b..3b6c3fc1 100644 --- a/dbt/adapters/athena/connections.py +++ b/dbt/adapters/athena/connections.py @@ -47,6 +47,7 @@ class AthenaCredentials(Credentials): num_retries: Optional[int] = 5 s3_data_dir: Optional[str] = None s3_data_naming: Optional[str] = "schema_table_unique" + spark_work_group: Optional[str] = None lf_tags: Optional[Dict[str, str]] = None @property @@ -70,8 +71,18 @@ def _connection_keys(self) -> Tuple[str, ...]: "s3_data_dir", "s3_data_naming", "lf_tags", + "spark_work_group", ) + def get_region_name(self) -> str: + return self.region_name + + def get_profile_name(self) -> str: + return self.aws_profile_name + + def get_spark_work_group(self) -> str: + return self.spark_work_group + class AthenaCursor(Cursor): def __init__(self, **kwargs): diff --git a/dbt/adapters/athena/impl.py b/dbt/adapters/athena/impl.py index f13d38de..467a621e 100755 --- a/dbt/adapters/athena/impl.py +++ b/dbt/adapters/athena/impl.py @@ -1,7 +1,7 @@ import posixpath as path from itertools import chain from threading import Lock -from typing import Any, Dict, Iterator, List, Optional, Set, Tuple, Union +from typing import Any, Dict, Iterator, List, Optional, Set, Tuple, Union, Type from urllib.parse import urlparse from uuid import uuid4 @@ -10,11 +10,13 @@ from dbt.adapters.athena import AthenaConnectionManager from dbt.adapters.athena.config import get_boto3_config +from dbt.adapters.athena.python_submissions import AthenaPythonJobHelper from dbt.adapters.athena.relation import AthenaRelation, AthenaSchemaSearchMap from dbt.adapters.athena.utils import clean_sql_comment -from dbt.adapters.base import Column, available +from dbt.adapters.base import PythonJobHelper, Column, available from dbt.adapters.base.relation import BaseRelation, InformationSchema from dbt.adapters.sql import SQLAdapter +from dbt.contracts.connection import AdapterResponse from dbt.contracts.graph.manifest import Manifest from dbt.contracts.graph.nodes import CompiledNode from dbt.events import AdapterLogger @@ -617,6 +619,19 @@ def persist_docs_to_glue( glue_client.update_table(DatabaseName=relation.schema, TableInput=updated_table) + def generate_python_submission_response(self, submission_result: Any) -> AdapterResponse: + if submission_result is not None: + return AdapterResponse(_message="OK") + return AdapterResponse(_message="ERROR") + + @property + def default_python_submission_method(self) -> str: + return "athena_helper" + + @property + def python_submission_helpers(self) -> Dict[str, Type[PythonJobHelper]]: + return {"athena_helper": AthenaPythonJobHelper} + @available def list_schemas(self, database: str) -> List[str]: conn = self.connections.get_thread_connection() diff --git a/dbt/adapters/athena/python_submissions.py b/dbt/adapters/athena/python_submissions.py new file mode 
100644 index 00000000..b1e7b4ee --- /dev/null +++ b/dbt/adapters/athena/python_submissions.py @@ -0,0 +1,74 @@ +from functools import cached_property +from typing import Any, Dict + +import boto3 + +from dbt.adapters.athena.connections import AthenaCredentials +from dbt.adapters.base import PythonJobHelper + +DEFAULT_POLLING_INTERVAL = 10 +SUBMISSION_LANGUAGE = "python" +DEFAULT_TIMEOUT = 60 * 60 * 24 + + +class AthenaPythonJobHelper(PythonJobHelper): + def __init__(self, parsed_model: Dict, credentials: AthenaCredentials) -> None: + self.identifier = parsed_model["alias"] + self.schema = parsed_model["schema"] + self.parsed_model = parsed_model + self.timeout = self.get_timeout() + self.polling_interval = DEFAULT_POLLING_INTERVAL + self.region_name = credentials.get_region_name() + self.profile_name = credentials.get_profile_name() + self.spark_work_group = credentials.get_spark_work_group() + + @cached_property + def session_id(self) -> str: + if self._list_sessions() is None: + return self._start_session().get("SessionId") + return self._list_sessions().get("SessionId") + + @cached_property + def athena_client(self) -> Any: + return boto3.client("athena") + + def get_timeout(self) -> int: + timeout = self.parsed_model["config"].get("timeout", DEFAULT_TIMEOUT) + if timeout <= 0: + raise ValueError("Timeout must be a positive integer") + return timeout + + def _list_sessions(self) -> dict: + try: + response = self.athena_client.list_sessions( + WorkGroup=self.spark_work_group, MaxResults=1, StateFilter="IDLE" + ) + return response.get("Sessions")[0] + except Exception: + return None + + def _start_session(self) -> dict: + try: + response = self.athena_client.start_session( + WorkGroup=self.spark_work_group, + EngineConfiguration={"CoordinatorDpuSize": 1, "MaxConcurrentDpus": 1, "DefaultExecutorDpuSize": 18}, + ) + return response + except Exception: + return None + + def submit(self, compiled_code: str) -> dict: + try: + response = self.athena_client.start_calculation_execution( + SessionId=self.session_id, CodeBlock=compiled_code.strip() + ) + return response + except Exception: + return None + + def _terminate_session(self) -> dict: + try: + response = self.athena_client.terminate_session(SessionId=self.session_id) + return response + except Exception: + return None diff --git a/dbt/include/athena/macros/materializations/models/incremental/incremental.sql b/dbt/include/athena/macros/materializations/models/incremental/incremental.sql index 865af83d..c2400c15 100644 --- a/dbt/include/athena/macros/materializations/models/incremental/incremental.sql +++ b/dbt/include/athena/macros/materializations/models/incremental/incremental.sql @@ -1,4 +1,5 @@ -{% materialization incremental, adapter='athena' -%} +{% materialization incremental, adapter='athena', supported_languages=['sql', 'python'] -%} + {%- set language = model['language'] -%} {% set raw_strategy = config.get('incremental_strategy') or 'insert_overwrite' %} {% set table_type = config.get('table_type', default='hive') | lower %} @@ -24,16 +25,16 @@ {% set to_drop = [] %} {% if existing_relation is none %} - {% set build_sql = create_table_as(False, target_relation, sql) -%} + {% set build_sql = create_table_as(False, target_relation, sql, language) -%} {% elif existing_relation.is_view or should_full_refresh() %} {% do drop_relation(existing_relation) %} - {% set build_sql = create_table_as(False, target_relation, sql) -%} + {% set build_sql = create_table_as(False, target_relation, sql, language) -%} {% elif partitioned_by is not 
none and strategy == 'insert_overwrite' %} {% set tmp_relation = make_temp_relation(target_relation) %} {% if tmp_relation is not none %} {% do drop_relation(tmp_relation) %} {% endif %} - {% do run_query(create_table_as(True, tmp_relation, sql)) %} + {% do run_query(create_table_as(True, tmp_relation, sql, language)) %} {% do delete_overlapping_partitions(target_relation, tmp_relation, partitioned_by) %} {% set build_sql = incremental_insert(on_schema_change, tmp_relation, target_relation, existing_relation) %} {% do to_drop.append(tmp_relation) %} @@ -42,7 +43,7 @@ {% if tmp_relation is not none %} {% do drop_relation(tmp_relation) %} {% endif %} - {% do run_query(create_table_as(True, tmp_relation, sql)) %} + {% do run_query(create_table_as(True, tmp_relation, sql, language)) %} {% set build_sql = incremental_insert(on_schema_change, tmp_relation, target_relation, existing_relation) %} {% do to_drop.append(tmp_relation) %} {% elif strategy == 'merge' and table_type == 'iceberg' %} @@ -58,12 +59,12 @@ {% if tmp_relation is not none %} {% do drop_relation(tmp_relation) %} {% endif %} - {% do run_query(create_table_as(True, tmp_relation, sql)) %} + {% do run_query(create_table_as(True, tmp_relation, sql, language)) %} {% set build_sql = iceberg_merge(on_schema_change, tmp_relation, target_relation, unique_key, existing_relation) %} {% do to_drop.append(tmp_relation) %} {% endif %} - {% call statement("main") %} + {% call statement("main", language=language) %} {{ build_sql }} {% endcall %} diff --git a/dbt/include/athena/macros/materializations/models/table/create_table_as.sql b/dbt/include/athena/macros/materializations/models/table/create_table_as.sql index 49b3ce4e..4f42cb51 100644 --- a/dbt/include/athena/macros/materializations/models/table/create_table_as.sql +++ b/dbt/include/athena/macros/materializations/models/table/create_table_as.sql @@ -1,49 +1,57 @@ -{% macro athena__create_table_as(temporary, relation, sql) -%} - {%- set materialized = config.get('materialized', default='table') -%} - {%- set external_location = config.get('external_location', default=none) -%} - {%- set partitioned_by = config.get('partitioned_by', default=none) -%} - {%- set bucketed_by = config.get('bucketed_by', default=none) -%} - {%- set bucket_count = config.get('bucket_count', default=none) -%} - {%- set field_delimiter = config.get('field_delimiter', default=none) -%} - {%- set table_type = config.get('table_type', default='hive') | lower -%} - {%- set format = config.get('format', default='parquet') -%} - {%- set write_compression = config.get('write_compression', default=none) -%} - {%- set s3_data_dir = config.get('s3_data_dir', default=target.s3_data_dir) -%} - {%- set s3_data_naming = config.get('s3_data_naming', default=target.s3_data_naming) -%} - {%- set extra_table_properties = config.get('table_properties', default=none) -%} +{% macro athena__create_table_as(temporary, relation, compiled_code, language='sql') -%} + {%- if language == 'sql' -%} + {%- set materialized = config.get('materialized', default='table') -%} + {%- set external_location = config.get('external_location', default=none) -%} + {%- set partitioned_by = config.get('partitioned_by', default=none) -%} + {%- set bucketed_by = config.get('bucketed_by', default=none) -%} + {%- set bucket_count = config.get('bucket_count', default=none) -%} + {%- set field_delimiter = config.get('field_delimiter', default=none) -%} + {%- set table_type = config.get('table_type', default='hive') | lower -%} + {%- set format = config.get('format', 
default='parquet') -%} + {%- set write_compression = config.get('write_compression', default=none) -%} + {%- set s3_data_dir = config.get('s3_data_dir', default=target.s3_data_dir) -%} + {%- set s3_data_naming = config.get('s3_data_naming', default=target.s3_data_naming) -%} + {%- set extra_table_properties = config.get('table_properties', default=none) -%} +<<<<<<< HEAD {%- set location_property = 'external_location' -%} {%- set partition_property = 'partitioned_by' -%} {%- set work_group_output_location = adapter.get_work_group_output_location() -%} {%- set location = adapter.s3_table_location(s3_data_dir, s3_data_naming, relation.schema, relation.identifier, external_location, temporary) -%} +======= + {%- set location_property = 'external_location' -%} + {%- set partition_property = 'partitioned_by' -%} + {%- set location = adapter.s3_table_location(s3_data_dir, s3_data_naming, relation.schema, relation.identifier, external_location, temporary) -%} +>>>>>>> 10dd892 (Submitted python model successfully) - {%- if materialized == 'table_hive_ha' -%} - {%- set location = location.replace('__ha', '') -%} - {%- endif %} + {%- if materialized == 'table_hive_ha' -%} + {%- set location = location.replace('__ha', '') -%} + {%- endif %} - {%- if table_type == 'iceberg' -%} - {%- set location_property = 'location' -%} - {%- set partition_property = 'partitioning' -%} - {%- if bucketed_by is not none or bucket_count is not none -%} - {%- set ignored_bucket_iceberg -%} - bucketed_by or bucket_count cannot be used with Iceberg tables. You have to use the bucket function - when partitioning. Will be ignored - {%- endset -%} - {%- set bucketed_by = none -%} - {%- set bucket_count = none -%} - {% do log(ignored_bucket_iceberg) %} - {%- endif -%} - {%- if s3_data_naming in ['table', 'table_schema'] or external_location is not none -%} - {%- set error_unique_location_iceberg -%} - You need to have an unique table location when creating Iceberg table. Right now we are building tables in - a destructive way but in the near future we will be using the RENAME feature to provide near-zero downtime. - {%- endset -%} - {% do exceptions.raise_compiler_error(error_unique_location_iceberg) %} - {%- endif -%} - {%- endif %} + {%- if table_type == 'iceberg' -%} + {%- set location_property = 'location' -%} + {%- set partition_property = 'partitioning' -%} + {%- if bucketed_by is not none or bucket_count is not none -%} + {%- set ignored_bucket_iceberg -%} + bucketed_by or bucket_count cannot be used with Iceberg tables. You have to use the bucket function + when partitioning. Will be ignored + {%- endset -%} + {%- set bucketed_by = none -%} + {%- set bucket_count = none -%} + {% do log(ignored_bucket_iceberg) %} + {%- endif -%} + {%- if s3_data_naming in ['table', 'table_schema'] or external_location is not none -%} + {%- set error_unique_location_iceberg -%} + You need to have an unique table location when creating Iceberg table. Right now we are building tables in + a destructive way but in the near future we will be using the RENAME feature to provide near-zero downtime. 
+ {%- endset -%} + {% do exceptions.raise_compiler_error(error_unique_location_iceberg) %} + {%- endif -%} + {%- endif %} - {% do adapter.delete_from_s3(location) %} + {% do adapter.delete_from_s3(location) %} +<<<<<<< HEAD create table {{ relation }} with ( table_type='{{ table_type }}', @@ -77,3 +85,61 @@ as {{ sql }} {% endmacro %} +======= + create table {{ relation }} + with ( + table_type='{{ table_type }}', + is_external={%- if table_type == 'iceberg' -%}false{%- else -%}true{%- endif %}, + {{ location_property }}='{{ location }}', + {%- if partitioned_by is not none %} + {{ partition_property }}=ARRAY{{ partitioned_by | tojson | replace('\"', '\'') }}, + {%- endif %} + {%- if bucketed_by is not none %} + bucketed_by=ARRAY{{ bucketed_by | tojson | replace('\"', '\'') }}, + {%- endif %} + {%- if bucket_count is not none %} + bucket_count={{ bucket_count }}, + {%- endif %} + {%- if field_delimiter is not none %} + field_delimiter='{{ field_delimiter }}', + {%- endif %} + {%- if write_compression is not none %} + write_compression='{{ write_compression }}', + {%- endif %} + format='{{ format }}' + {%- if extra_table_properties is not none -%} + {%- for prop_name, prop_value in extra_table_properties.items() -%} + , + {{ prop_name }}={{ prop_value }} + {%- endfor -%} + {% endif %} + ) + as + {{ compiled_code }} + {%- elif language == 'python' -%} + {{ athena__py_create_table_as(compiled_code=compiled_code, target_relation=relation, temporary=temporary) }} + {%- else -%} + {% do exceptions.raise_compiler_error("athena__create_table_as macro doesn't support the provided language, it got %s" % language) %} + {%- endif -%} +{%- endmacro -%} + +{%- macro athena__py_create_table_as(compiled_code, target_relation, temporary) -%} +{{ compiled_code }} +def materialize(session, df, target_relation): + # make sure pandas exists + import importlib.util + package_name = 'pandas' + if importlib.util.find_spec(package_name): + import pandas + if isinstance(df, pandas.core.frame.DataFrame): + # session.write_pandas does not have overwrite function + df = session.createDataFrame(df) + df.write.mode("overwrite").save_as_table('{{ target_relation.identifier }}', create_temp_table={{temporary}}) + +def main(session): + dbt = dbtObj(session.table) + df = model(dbt, session) + materialize(session, df, dbt.this) + return "OK" +{%- endmacro -%} +>>>>>>> 10dd892 (Submitted python model successfully) diff --git a/dbt/include/athena/macros/materializations/models/table/table.sql b/dbt/include/athena/macros/materializations/models/table/table.sql index ce995e99..34646220 100644 --- a/dbt/include/athena/macros/materializations/models/table/table.sql +++ b/dbt/include/athena/macros/materializations/models/table/table.sql @@ -1,5 +1,6 @@ -{% materialization table, adapter='athena' -%} +{% materialization table, adapter='athena', supported_languages=['sql', 'python'] -%} {%- set identifier = model['alias'] -%} + {%- set language = model['language'] -%} {%- set lf_tags = config.get('lf_tags', default=none) -%} {%- set lf_tags_columns = config.get('lf_tags_columns', default=none) -%} @@ -18,8 +19,8 @@ {%- endif -%} -- build model - {% call statement('main') -%} - {{ create_table_as(False, target_relation, sql) }} + {% call statement('main', language=language) -%} + {{ create_table_as(False, target_relation, compiled_code, language) }} {%- endcall %} {% if table_type != 'iceberg' %} From 773700077854a3c4631afb262d4c81f4e1d62648 Mon Sep 17 00:00:00 2001 From: Avinash-1394 <43074786+Avinash-1394@users.noreply.github.com> 
Date: Sat, 25 Mar 2023 15:47:01 -0300 Subject: [PATCH 02/75] Rebase and resolved conflicts --- .../models/table/create_table_as.sql | 44 +------------------ 1 file changed, 1 insertion(+), 43 deletions(-) diff --git a/dbt/include/athena/macros/materializations/models/table/create_table_as.sql b/dbt/include/athena/macros/materializations/models/table/create_table_as.sql index 4f42cb51..f7a96997 100644 --- a/dbt/include/athena/macros/materializations/models/table/create_table_as.sql +++ b/dbt/include/athena/macros/materializations/models/table/create_table_as.sql @@ -13,16 +13,10 @@ {%- set s3_data_naming = config.get('s3_data_naming', default=target.s3_data_naming) -%} {%- set extra_table_properties = config.get('table_properties', default=none) -%} -<<<<<<< HEAD - {%- set location_property = 'external_location' -%} - {%- set partition_property = 'partitioned_by' -%} - {%- set work_group_output_location = adapter.get_work_group_output_location() -%} - {%- set location = adapter.s3_table_location(s3_data_dir, s3_data_naming, relation.schema, relation.identifier, external_location, temporary) -%} -======= {%- set location_property = 'external_location' -%} {%- set partition_property = 'partitioned_by' -%} + {%- set work_group_output_location = adapter.get_work_group_output_location() -%} {%- set location = adapter.s3_table_location(s3_data_dir, s3_data_naming, relation.schema, relation.identifier, external_location, temporary) -%} ->>>>>>> 10dd892 (Submitted python model successfully) {%- if materialized == 'table_hive_ha' -%} {%- set location = location.replace('__ha', '') -%} @@ -51,41 +45,6 @@ {% do adapter.delete_from_s3(location) %} -<<<<<<< HEAD - create table {{ relation }} - with ( - table_type='{{ table_type }}', - is_external={%- if table_type == 'iceberg' -%}false{%- else -%}true{%- endif %}, - {%- if work_group_output_location is none -%} - {{ location_property }}='{{ location }}', - {%- endif %} - {%- if partitioned_by is not none %} - {{ partition_property }}=ARRAY{{ partitioned_by | tojson | replace('\"', '\'') }}, - {%- endif %} - {%- if bucketed_by is not none %} - bucketed_by=ARRAY{{ bucketed_by | tojson | replace('\"', '\'') }}, - {%- endif %} - {%- if bucket_count is not none %} - bucket_count={{ bucket_count }}, - {%- endif %} - {%- if field_delimiter is not none %} - field_delimiter='{{ field_delimiter }}', - {%- endif %} - {%- if write_compression is not none %} - write_compression='{{ write_compression }}', - {%- endif %} - format='{{ format }}' - {%- if extra_table_properties is not none -%} - {%- for prop_name, prop_value in extra_table_properties.items() -%} - , - {{ prop_name }}={{ prop_value }} - {%- endfor -%} - {% endif %} - ) - as - {{ sql }} -{% endmacro %} -======= create table {{ relation }} with ( table_type='{{ table_type }}', @@ -142,4 +101,3 @@ def main(session): materialize(session, df, dbt.this) return "OK" {%- endmacro -%} ->>>>>>> 10dd892 (Submitted python model successfully) From c1d1ab24e93f7a029b03bc9bf04ad13387df6d95 Mon Sep 17 00:00:00 2001 From: Avinash-1394 <43074786+Avinash-1394@users.noreply.github.com> Date: Sun, 26 Mar 2023 01:17:03 -0300 Subject: [PATCH 03/75] Execution successful but table not saved --- dbt/adapters/athena/impl.py | 10 +-- dbt/adapters/athena/python_submissions.py | 73 ++++++++++++++++--- .../models/table/create_table_as.sql | 26 ++++--- .../materializations/models/table/table.sql | 2 +- 4 files changed, 81 insertions(+), 30 deletions(-) diff --git a/dbt/adapters/athena/impl.py b/dbt/adapters/athena/impl.py index 
467a621e..330077f6 100755 --- a/dbt/adapters/athena/impl.py +++ b/dbt/adapters/athena/impl.py @@ -1,7 +1,7 @@ import posixpath as path from itertools import chain from threading import Lock -from typing import Any, Dict, Iterator, List, Optional, Set, Tuple, Union, Type +from typing import Any, Dict, Iterator, List, Optional, Set, Tuple, Type, Union from urllib.parse import urlparse from uuid import uuid4 @@ -13,7 +13,7 @@ from dbt.adapters.athena.python_submissions import AthenaPythonJobHelper from dbt.adapters.athena.relation import AthenaRelation, AthenaSchemaSearchMap from dbt.adapters.athena.utils import clean_sql_comment -from dbt.adapters.base import PythonJobHelper, Column, available +from dbt.adapters.base import Column, PythonJobHelper, available from dbt.adapters.base.relation import BaseRelation, InformationSchema from dbt.adapters.sql import SQLAdapter from dbt.contracts.connection import AdapterResponse @@ -620,9 +620,9 @@ def persist_docs_to_glue( glue_client.update_table(DatabaseName=relation.schema, TableInput=updated_table) def generate_python_submission_response(self, submission_result: Any) -> AdapterResponse: - if submission_result is not None: - return AdapterResponse(_message="OK") - return AdapterResponse(_message="ERROR") + if submission_result is None: + return AdapterResponse(_message="ERROR") + return AdapterResponse(_message="OK") @property def default_python_submission_method(self) -> str: diff --git a/dbt/adapters/athena/python_submissions.py b/dbt/adapters/athena/python_submissions.py index b1e7b4ee..8afc8f0f 100644 --- a/dbt/adapters/athena/python_submissions.py +++ b/dbt/adapters/athena/python_submissions.py @@ -1,3 +1,4 @@ +import time from functools import cached_property from typing import Any, Dict @@ -5,11 +6,15 @@ from dbt.adapters.athena.connections import AthenaCredentials from dbt.adapters.base import PythonJobHelper +from dbt.events import AdapterLogger +from dbt.exceptions import DbtRuntimeError -DEFAULT_POLLING_INTERVAL = 10 +DEFAULT_POLLING_INTERVAL = 2 SUBMISSION_LANGUAGE = "python" DEFAULT_TIMEOUT = 60 * 60 * 24 +logger = AdapterLogger("Athena") + class AthenaPythonJobHelper(PythonJobHelper): def __init__(self, parsed_model: Dict, credentials: AthenaCredentials) -> None: @@ -43,32 +48,76 @@ def _list_sessions(self) -> dict: response = self.athena_client.list_sessions( WorkGroup=self.spark_work_group, MaxResults=1, StateFilter="IDLE" ) + if len(response.get("Sessions")) == 0 or response.get("Sessions") is None: + return None return response.get("Sessions")[0] except Exception: - return None + raise def _start_session(self) -> dict: try: response = self.athena_client.start_session( WorkGroup=self.spark_work_group, - EngineConfiguration={"CoordinatorDpuSize": 1, "MaxConcurrentDpus": 1, "DefaultExecutorDpuSize": 18}, + EngineConfiguration={"CoordinatorDpuSize": 1, "MaxConcurrentDpus": 2, "DefaultExecutorDpuSize": 1}, ) + if response["State"] != "IDLE": + self._poll_until_session_creation(response["SessionId"]) return response except Exception: - return None + raise def submit(self, compiled_code: str) -> dict: try: - response = self.athena_client.start_calculation_execution( - SessionId=self.session_id, CodeBlock=compiled_code.strip() - ) - return response + calculation_execution_id = self.athena_client.start_calculation_execution( + SessionId=self.session_id, CodeBlock=compiled_code.lstrip() + )["CalculationExecutionId"] + logger.debug(f"Submitted calculation execution id {calculation_execution_id}") + execution_status = 
self._poll_until_execution_completion(calculation_execution_id) + logger.debug(f"Received execution status {execution_status}") + if execution_status == "COMPLETED": + result_s3_uri = self.athena_client.get_calculation_execution( + CalculationExecutionId=calculation_execution_id + )["Result"]["ResultS3Uri"] + return result_s3_uri + else: + raise DbtRuntimeError(f"python model run ended in state {execution_status}") except Exception: - return None + raise def _terminate_session(self) -> dict: try: - response = self.athena_client.terminate_session(SessionId=self.session_id) - return response + self.athena_client.terminate_session(SessionId=self.session_id) except Exception: - return None + raise + + def _poll_until_execution_completion(self, calculation_execution_id): + polling_interval = self.polling_interval + while True: + execution_status = self.athena_client.get_calculation_execution_status( + CalculationExecutionId=calculation_execution_id + )["Status"]["State"] + if execution_status in ["COMPLETED", "FAILED", "CANCELLED"]: + return execution_status + time.sleep(polling_interval) + polling_interval *= 2 + if polling_interval > self.timeout: + raise DbtRuntimeError( + f"Execution {calculation_execution_id} did not complete within {self.timeout} seconds." + ) + + def _poll_until_session_creation(self, session_id): + polling_interval = self.polling_interval + while True: + creation_status = self.athena_client.get_session_status(SessionId=session_id)["Status"]["State"] + if creation_status in ["FAILED", "TERMINATED", "DEGRADED"]: + raise DbtRuntimeError(f"Unable to create session: {session_id}. Got status: {creation_status}.") + elif creation_status == "IDLE": + return creation_status + time.sleep(polling_interval) + polling_interval *= 2 + if polling_interval > self.timeout: + raise DbtRuntimeError(f"Session {session_id} did not create within {self.timeout} seconds.") + + def __del__(self) -> None: + logger.debug(f"Terminating session: {self.session_id}") + self._terminate_session() diff --git a/dbt/include/athena/macros/materializations/models/table/create_table_as.sql b/dbt/include/athena/macros/materializations/models/table/create_table_as.sql index f7a96997..978b4c92 100644 --- a/dbt/include/athena/macros/materializations/models/table/create_table_as.sql +++ b/dbt/include/athena/macros/materializations/models/table/create_table_as.sql @@ -51,10 +51,10 @@ is_external={%- if table_type == 'iceberg' -%}false{%- else -%}true{%- endif %}, {{ location_property }}='{{ location }}', {%- if partitioned_by is not none %} - {{ partition_property }}=ARRAY{{ partitioned_by | tojson | replace('\"', '\'') }}, + {{ partition_property }}=ARRAY{{ partitioned_by | join("', '") | replace('"', "'") | prepend("'") | append("'") }}, {%- endif %} {%- if bucketed_by is not none %} - bucketed_by=ARRAY{{ bucketed_by | tojson | replace('\"', '\'') }}, + bucketed_by=ARRAY{{ bucketed_by | join("', '") | replace('"', "'") | prepend("'") | append("'") }}, {%- endif %} {%- if bucket_count is not none %} bucket_count={{ bucket_count }}, @@ -76,24 +76,26 @@ as {{ compiled_code }} {%- elif language == 'python' -%} - {{ athena__py_create_table_as(compiled_code=compiled_code, target_relation=relation, temporary=temporary) }} + {{ athena__py_create_table_as(compiled_code=compiled_code, target_relation=relation, temporary=temporary) | trim }} {%- else -%} {% do exceptions.raise_compiler_error("athena__create_table_as macro doesn't support the provided language, it got %s" % language) %} {%- endif -%} {%- endmacro -%} {%- 
macro athena__py_create_table_as(compiled_code, target_relation, temporary) -%} -{{ compiled_code }} +{{ compiled_code | trim }} def materialize(session, df, target_relation): - # make sure pandas exists - import importlib.util - package_name = 'pandas' - if importlib.util.find_spec(package_name): - import pandas + import pandas + try: if isinstance(df, pandas.core.frame.DataFrame): - # session.write_pandas does not have overwrite function - df = session.createDataFrame(df) - df.write.mode("overwrite").save_as_table('{{ target_relation.identifier }}', create_temp_table={{temporary}}) + df = spark.createDataFrame(df) + df.write.saveAsTable( + name="{{ target_relation.schema}}.{{ target_relation.identifier }}", + format="parquet", + mode="overwrite" + ) + except Exception: + raise def main(session): dbt = dbtObj(session.table) diff --git a/dbt/include/athena/macros/materializations/models/table/table.sql b/dbt/include/athena/macros/materializations/models/table/table.sql index 34646220..79c7ba9c 100644 --- a/dbt/include/athena/macros/materializations/models/table/table.sql +++ b/dbt/include/athena/macros/materializations/models/table/table.sql @@ -23,7 +23,7 @@ {{ create_table_as(False, target_relation, compiled_code, language) }} {%- endcall %} - {% if table_type != 'iceberg' %} + {% if table_type != 'iceberg' and language != 'python' %} {{ set_table_classification(target_relation) }} {% endif %} From fd6a135e7be4a3edf49000c7b3a38ebe9a3f392c Mon Sep 17 00:00:00 2001 From: Avinash-1394 <43074786+Avinash-1394@users.noreply.github.com> Date: Mon, 27 Mar 2023 10:02:12 -0300 Subject: [PATCH 04/75] Add incremental model support --- .../models/incremental/incremental.sql | 10 +++++----- .../models/table/create_table_as.sql | 19 +++++++++---------- 2 files changed, 14 insertions(+), 15 deletions(-) diff --git a/dbt/include/athena/macros/materializations/models/incremental/incremental.sql b/dbt/include/athena/macros/materializations/models/incremental/incremental.sql index c2400c15..6a6acb45 100644 --- a/dbt/include/athena/macros/materializations/models/incremental/incremental.sql +++ b/dbt/include/athena/macros/materializations/models/incremental/incremental.sql @@ -25,16 +25,16 @@ {% set to_drop = [] %} {% if existing_relation is none %} - {% set build_sql = create_table_as(False, target_relation, sql, language) -%} + {% set build_sql = create_table_as(False, target_relation, compiled_code, language) -%} {% elif existing_relation.is_view or should_full_refresh() %} {% do drop_relation(existing_relation) %} - {% set build_sql = create_table_as(False, target_relation, sql, language) -%} + {% set build_sql = create_table_as(False, target_relation, compiled_code, language) -%} {% elif partitioned_by is not none and strategy == 'insert_overwrite' %} {% set tmp_relation = make_temp_relation(target_relation) %} {% if tmp_relation is not none %} {% do drop_relation(tmp_relation) %} {% endif %} - {% do run_query(create_table_as(True, tmp_relation, sql, language)) %} + {% do run_query(create_table_as(True, tmp_relation, compiled_code, language)) %} {% do delete_overlapping_partitions(target_relation, tmp_relation, partitioned_by) %} {% set build_sql = incremental_insert(on_schema_change, tmp_relation, target_relation, existing_relation) %} {% do to_drop.append(tmp_relation) %} @@ -43,7 +43,7 @@ {% if tmp_relation is not none %} {% do drop_relation(tmp_relation) %} {% endif %} - {% do run_query(create_table_as(True, tmp_relation, sql, language)) %} + {% do run_query(create_table_as(True, tmp_relation, 
compiled_code, language)) %} {% set build_sql = incremental_insert(on_schema_change, tmp_relation, target_relation, existing_relation) %} {% do to_drop.append(tmp_relation) %} {% elif strategy == 'merge' and table_type == 'iceberg' %} @@ -59,7 +59,7 @@ {% if tmp_relation is not none %} {% do drop_relation(tmp_relation) %} {% endif %} - {% do run_query(create_table_as(True, tmp_relation, sql, language)) %} + {% do run_query(create_table_as(True, tmp_relation, compiled_code, language)) %} {% set build_sql = iceberg_merge(on_schema_change, tmp_relation, target_relation, unique_key, existing_relation) %} {% do to_drop.append(tmp_relation) %} {% endif %} diff --git a/dbt/include/athena/macros/materializations/models/table/create_table_as.sql b/dbt/include/athena/macros/materializations/models/table/create_table_as.sql index 978b4c92..cf58d47c 100644 --- a/dbt/include/athena/macros/materializations/models/table/create_table_as.sql +++ b/dbt/include/athena/macros/materializations/models/table/create_table_as.sql @@ -76,30 +76,29 @@ as {{ compiled_code }} {%- elif language == 'python' -%} - {{ athena__py_create_table_as(compiled_code=compiled_code, target_relation=relation, temporary=temporary) | trim }} + {{ athena__py_create_table_as(compiled_code=compiled_code, target_relation=relation, temporary=temporary) }} {%- else -%} {% do exceptions.raise_compiler_error("athena__create_table_as macro doesn't support the provided language, it got %s" % language) %} {%- endif -%} {%- endmacro -%} {%- macro athena__py_create_table_as(compiled_code, target_relation, temporary) -%} -{{ compiled_code | trim }} -def materialize(session, df, target_relation): +{{ compiled_code }} +def materialize(spark_session, df, target_relation): import pandas try: if isinstance(df, pandas.core.frame.DataFrame): - df = spark.createDataFrame(df) + df = spark_session.createDataFrame(df) df.write.saveAsTable( name="{{ target_relation.schema}}.{{ target_relation.identifier }}", format="parquet", - mode="overwrite" + mode="overwrite", ) + return "OK" except Exception: raise -def main(session): - dbt = dbtObj(session.table) - df = model(dbt, session) - materialize(session, df, dbt.this) - return "OK" +dbt = dbtObj(spark.table) +df = model(dbt, spark) +materialize(spark, df, dbt.this) {%- endmacro -%} From 7913646da68b3ed5c6efffc68edf161021187962 Mon Sep 17 00:00:00 2001 From: Avinash-1394 <43074786+Avinash-1394@users.noreply.github.com> Date: Mon, 3 Apr 2023 20:47:24 -0300 Subject: [PATCH 05/75] Fixed location and incremental model rerun --- .env.example | 1 + dbt/adapters/athena/connections.py | 9 -- dbt/adapters/athena/python_submissions.py | 30 +++-- .../macros/adapters/python_submissions.sql | 35 ++++++ .../models/incremental/incremental.sql | 45 +++++++- .../models/table/create_table_as.sql | 106 +++++++----------- tests/conftest.py | 1 + tests/unit/constants.py | 1 + tests/unit/test_python_submissions.py | 56 +++++++++ 9 files changed, 194 insertions(+), 90 deletions(-) create mode 100644 dbt/include/athena/macros/adapters/python_submissions.sql create mode 100644 tests/unit/test_python_submissions.py diff --git a/.env.example b/.env.example index cfda7881..520991b3 100644 --- a/.env.example +++ b/.env.example @@ -4,3 +4,4 @@ DBT_TEST_ATHENA_DATABASE= DBT_TEST_ATHENA_SCHEMA= DBT_TEST_ATHENA_WORK_GROUND= DBT_TEST_ATHENA_AWS_PROFILE_NAME= +DBT_TEST_ATHENA_SPARK_WORK_GROUP= diff --git a/dbt/adapters/athena/connections.py b/dbt/adapters/athena/connections.py index 421b00a9..a4ed8911 100644 --- a/dbt/adapters/athena/connections.py 
+++ b/dbt/adapters/athena/connections.py @@ -74,15 +74,6 @@ def _connection_keys(self) -> Tuple[str, ...]: "spark_work_group", ) - def get_region_name(self) -> str: - return self.region_name - - def get_profile_name(self) -> str: - return self.aws_profile_name - - def get_spark_work_group(self) -> str: - return self.spark_work_group - class AthenaCursor(Cursor): def __init__(self, **kwargs): diff --git a/dbt/adapters/athena/python_submissions.py b/dbt/adapters/athena/python_submissions.py index 8afc8f0f..8ef7b066 100644 --- a/dbt/adapters/athena/python_submissions.py +++ b/dbt/adapters/athena/python_submissions.py @@ -1,5 +1,6 @@ import time -from functools import cached_property +from datetime import datetime, timedelta, timezone +from functools import lru_cache from typing import Any, Dict import boto3 @@ -11,7 +12,7 @@ DEFAULT_POLLING_INTERVAL = 2 SUBMISSION_LANGUAGE = "python" -DEFAULT_TIMEOUT = 60 * 60 * 24 +DEFAULT_TIMEOUT = 60 * 60 * 2 logger = AdapterLogger("Athena") @@ -23,17 +24,20 @@ def __init__(self, parsed_model: Dict, credentials: AthenaCredentials) -> None: self.parsed_model = parsed_model self.timeout = self.get_timeout() self.polling_interval = DEFAULT_POLLING_INTERVAL - self.region_name = credentials.get_region_name() - self.profile_name = credentials.get_profile_name() - self.spark_work_group = credentials.get_spark_work_group() + self.region_name = credentials.region_name + self.profile_name = credentials.aws_profile_name + self.spark_work_group = credentials.spark_work_group - @cached_property + @property + @lru_cache() def session_id(self) -> str: - if self._list_sessions() is None: + session_info = self._list_sessions() + if session_info is None: return self._start_session().get("SessionId") - return self._list_sessions().get("SessionId") + return session_info.get("SessionId") - @cached_property + @property + @lru_cache() def athena_client(self) -> Any: return boto3.client("athena") @@ -86,7 +90,12 @@ def submit(self, compiled_code: str) -> dict: def _terminate_session(self) -> dict: try: - self.athena_client.terminate_session(SessionId=self.session_id) + session_status = self.athena_client.get_session_status(SessionId=self.session_id)["Status"] + if session_status["State"] in ["IDLE", "BUSY"] and ( + session_status["StartDateTime"] - datetime.now(tz=timezone.utc) > timedelta(seconds=self.timeout) + ): + logger.debug(f"Terminating session: {self.session_id}") + self.athena_client.terminate_session(SessionId=self.session_id) except Exception: raise @@ -119,5 +128,4 @@ def _poll_until_session_creation(self, session_id): raise DbtRuntimeError(f"Session {session_id} did not create within {self.timeout} seconds.") def __del__(self) -> None: - logger.debug(f"Terminating session: {self.session_id}") self._terminate_session() diff --git a/dbt/include/athena/macros/adapters/python_submissions.sql b/dbt/include/athena/macros/adapters/python_submissions.sql new file mode 100644 index 00000000..c9c02d24 --- /dev/null +++ b/dbt/include/athena/macros/adapters/python_submissions.sql @@ -0,0 +1,35 @@ +{%- macro athena__py_save_table_as(compiled_code, target_relation, format, location, mode="overwrite") -%} +{{ compiled_code }} +def materialize(spark_session, df, target_relation): + import pandas + try: + if isinstance(df, pandas.core.frame.DataFrame): + df = spark_session.createDataFrame(df) + df.write \ + .format("{{ format }}") \ + .option("path", "{{ location }}") \ + .mode("{{ mode }}") \ + .saveAsTable( + name="{{ target_relation.schema}}.{{ target_relation.identifier }}", + ) 
+ return "OK" + except Exception: + raise + +dbt = dbtObj(spark.table) +df = model(dbt, spark) +materialize(spark, df, dbt.this) +{%- endmacro -%} + +{%- macro athena__py_execute_query(query) -%} +def execute_query(spark_session): + import pandas + try: + spark_session.sql("""{{ query }}""") + return "OK" + except Exception: + raise + +dbt = dbtObj(spark.table) +execute_query(spark) +{%- endmacro -%} diff --git a/dbt/include/athena/macros/materializations/models/incremental/incremental.sql b/dbt/include/athena/macros/materializations/models/incremental/incremental.sql index b6724cc2..309bf87e 100644 --- a/dbt/include/athena/macros/materializations/models/incremental/incremental.sql +++ b/dbt/include/athena/macros/materializations/models/incremental/incremental.sql @@ -13,6 +13,8 @@ {% set existing_relation = load_relation(this) %} {% set tmp_relation = make_temp_relation(this) %} + {{ log("temporary relation is" ~ tmp_relation.schema ~ tmp_relation.identifier)}} + -- If no partitions are used with insert_overwrite, we fall back to append mode. {% if partitioned_by is none and strategy == 'insert_overwrite' %} {% set strategy = 'append' %} @@ -25,16 +27,28 @@ {% set to_drop = [] %} {% if existing_relation is none %} - {% set build_sql = create_table_as(False, target_relation, compiled_code, language) -%} + {% call statement('save_table', language=language) -%} + {{ create_table_as(False, target_relation, compiled_code, language) }} + {%- endcall %} + {% set build_sql = None %} {% elif existing_relation.is_view or should_full_refresh() %} {% do drop_relation(existing_relation) %} - {% set build_sql = create_table_as(False, target_relation, compiled_code, language) -%} + {% call statement('save_table', language=language) -%} + {{ create_table_as(False, target_relation, compiled_code, language) }} + {%- endcall %} + {% set build_sql = None %} {% elif partitioned_by is not none and strategy == 'insert_overwrite' %} {% set tmp_relation = make_temp_relation(target_relation) %} {% if tmp_relation is not none %} {% do drop_relation(tmp_relation) %} {% endif %} - {% do run_query(create_table_as(True, tmp_relation, compiled_code, language)) %} + {% if language == 'sql' %} + {% do run_query(create_table_as(True, tmp_relation, compiled_code, language)) %} + {% else %} + {% call statement('save_table', language=language) -%} + {{ create_table_as(True, tmp_relation, compiled_code, language) }} + {%- endcall %} + {% endif %} {% do delete_overlapping_partitions(target_relation, tmp_relation, partitioned_by) %} {% set build_sql = incremental_insert(on_schema_change, tmp_relation, target_relation, existing_relation) %} {% do to_drop.append(tmp_relation) %} @@ -43,7 +57,13 @@ {% if tmp_relation is not none %} {% do drop_relation(tmp_relation) %} {% endif %} - {% do run_query(create_table_as(True, tmp_relation, compiled_code, language)) %} + {% if language == 'sql' %} + {% do run_query(create_table_as(True, tmp_relation, compiled_code, language)) %} + {% else %} + {% call statement('save_table', language=language) -%} + {{ create_table_as(True, tmp_relation, compiled_code, language) }} + {%- endcall %} + {% endif %} {% set build_sql = incremental_insert(on_schema_change, tmp_relation, target_relation, existing_relation) %} {% do to_drop.append(tmp_relation) %} {% elif strategy == 'merge' and table_type == 'iceberg' %} @@ -59,13 +79,26 @@ {% if tmp_relation is not none %} {% do drop_relation(tmp_relation) %} {% endif %} - {% do run_query(create_table_as(True, tmp_relation, compiled_code, language)) %} + {% if 
language == 'sql' %} + {% do run_query(create_table_as(True, tmp_relation, compiled_code, language)) %} + {% else %} + {% call statement('save_table', language=language) -%} + {{ create_table_as(True, tmp_relation, compiled_code, language) }} + {%- endcall %} + {% endif %} {% set build_sql = iceberg_merge(on_schema_change, tmp_relation, target_relation, unique_key, existing_relation) %} {% do to_drop.append(tmp_relation) %} {% endif %} {% call statement("main", language=language) %} - {{ build_sql }} + {% if language == 'sql' %} + {{ build_sql }} + {% else %} + {% if build_sql %} + {{ log(build_sql) }} + {% do athena__py_execute_query(query=build_sql) %} + {% endif %} + {% endif %} {% endcall %} -- set table properties diff --git a/dbt/include/athena/macros/materializations/models/table/create_table_as.sql b/dbt/include/athena/macros/materializations/models/table/create_table_as.sql index cf58d47c..5205a14b 100644 --- a/dbt/include/athena/macros/materializations/models/table/create_table_as.sql +++ b/dbt/include/athena/macros/materializations/models/table/create_table_as.sql @@ -1,50 +1,49 @@ {% macro athena__create_table_as(temporary, relation, compiled_code, language='sql') -%} - {%- if language == 'sql' -%} - {%- set materialized = config.get('materialized', default='table') -%} - {%- set external_location = config.get('external_location', default=none) -%} - {%- set partitioned_by = config.get('partitioned_by', default=none) -%} - {%- set bucketed_by = config.get('bucketed_by', default=none) -%} - {%- set bucket_count = config.get('bucket_count', default=none) -%} - {%- set field_delimiter = config.get('field_delimiter', default=none) -%} - {%- set table_type = config.get('table_type', default='hive') | lower -%} - {%- set format = config.get('format', default='parquet') -%} - {%- set write_compression = config.get('write_compression', default=none) -%} - {%- set s3_data_dir = config.get('s3_data_dir', default=target.s3_data_dir) -%} - {%- set s3_data_naming = config.get('s3_data_naming', default=target.s3_data_naming) -%} - {%- set extra_table_properties = config.get('table_properties', default=none) -%} - - {%- set location_property = 'external_location' -%} - {%- set partition_property = 'partitioned_by' -%} - {%- set work_group_output_location = adapter.get_work_group_output_location() -%} - {%- set location = adapter.s3_table_location(s3_data_dir, s3_data_naming, relation.schema, relation.identifier, external_location, temporary) -%} + {%- set materialized = config.get('materialized', default='table') -%} + {%- set external_location = config.get('external_location', default=none) -%} + {%- set partitioned_by = config.get('partitioned_by', default=none) -%} + {%- set bucketed_by = config.get('bucketed_by', default=none) -%} + {%- set bucket_count = config.get('bucket_count', default=none) -%} + {%- set field_delimiter = config.get('field_delimiter', default=none) -%} + {%- set table_type = config.get('table_type', default='hive') | lower -%} + {%- set format = config.get('format', default='parquet') -%} + {%- set write_compression = config.get('write_compression', default=none) -%} + {%- set s3_data_dir = config.get('s3_data_dir', default=target.s3_data_dir) -%} + {%- set s3_data_naming = config.get('s3_data_naming', default=target.s3_data_naming) -%} + {%- set extra_table_properties = config.get('table_properties', default=none) -%} - {%- if materialized == 'table_hive_ha' -%} - {%- set location = location.replace('__ha', '') -%} - {%- endif %} + {%- set location_property = 
'external_location' -%} + {%- set partition_property = 'partitioned_by' -%} + {%- set work_group_output_location = adapter.get_work_group_output_location() -%} + {%- set location = adapter.s3_table_location(s3_data_dir, s3_data_naming, relation.schema, relation.identifier, external_location, temporary) -%} - {%- if table_type == 'iceberg' -%} - {%- set location_property = 'location' -%} - {%- set partition_property = 'partitioning' -%} - {%- if bucketed_by is not none or bucket_count is not none -%} - {%- set ignored_bucket_iceberg -%} - bucketed_by or bucket_count cannot be used with Iceberg tables. You have to use the bucket function - when partitioning. Will be ignored - {%- endset -%} - {%- set bucketed_by = none -%} - {%- set bucket_count = none -%} - {% do log(ignored_bucket_iceberg) %} - {%- endif -%} - {%- if s3_data_naming in ['table', 'table_schema'] or external_location is not none -%} - {%- set error_unique_location_iceberg -%} - You need to have an unique table location when creating Iceberg table. Right now we are building tables in - a destructive way but in the near future we will be using the RENAME feature to provide near-zero downtime. - {%- endset -%} - {% do exceptions.raise_compiler_error(error_unique_location_iceberg) %} - {%- endif -%} - {%- endif %} + {%- if materialized == 'table_hive_ha' -%} + {%- set location = location.replace('__ha', '') -%} + {%- endif %} - {% do adapter.delete_from_s3(location) %} + {%- if table_type == 'iceberg' -%} + {%- set location_property = 'location' -%} + {%- set partition_property = 'partitioning' -%} + {%- if bucketed_by is not none or bucket_count is not none -%} + {%- set ignored_bucket_iceberg -%} + bucketed_by or bucket_count cannot be used with Iceberg tables. You have to use the bucket function + when partitioning. Will be ignored + {%- endset -%} + {%- set bucketed_by = none -%} + {%- set bucket_count = none -%} + {% do log(ignored_bucket_iceberg) %} + {%- endif -%} + {%- if s3_data_naming in ['table', 'table_schema'] or external_location is not none -%} + {%- set error_unique_location_iceberg -%} + You need to have an unique table location when creating Iceberg table. Right now we are building tables in + a destructive way but in the near future we will be using the RENAME feature to provide near-zero downtime. 
+ {%- endset -%} + {% do exceptions.raise_compiler_error(error_unique_location_iceberg) %} + {%- endif -%} + {%- endif %} + {% do adapter.delete_from_s3(location) %} + {%- if language == 'sql' -%} create table {{ relation }} with ( table_type='{{ table_type }}', @@ -76,29 +75,8 @@ as {{ compiled_code }} {%- elif language == 'python' -%} - {{ athena__py_create_table_as(compiled_code=compiled_code, target_relation=relation, temporary=temporary) }} + {{ athena__py_save_table_as(compiled_code=compiled_code, target_relation=relation, format=format, location=location, mode="overwrite") }} {%- else -%} {% do exceptions.raise_compiler_error("athena__create_table_as macro doesn't support the provided language, it got %s" % language) %} {%- endif -%} {%- endmacro -%} - -{%- macro athena__py_create_table_as(compiled_code, target_relation, temporary) -%} -{{ compiled_code }} -def materialize(spark_session, df, target_relation): - import pandas - try: - if isinstance(df, pandas.core.frame.DataFrame): - df = spark_session.createDataFrame(df) - df.write.saveAsTable( - name="{{ target_relation.schema}}.{{ target_relation.identifier }}", - format="parquet", - mode="overwrite", - ) - return "OK" - except Exception: - raise - -dbt = dbtObj(spark.table) -df = model(dbt, spark) -materialize(spark, df, dbt.this) -{%- endmacro -%} diff --git a/tests/conftest.py b/tests/conftest.py index 2233c410..fba8b34a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -26,6 +26,7 @@ def dbt_profile_target(): "num_retries": 0, "work_group": os.getenv("DBT_TEST_ATHENA_WORK_GROUND"), "aws_profile_name": os.getenv("DBT_TEST_ATHENA_AWS_PROFILE_NAME") or None, + "spark_work_group": os.getenv("DBT_TEST_ATHENA_SPARK_WORK_GROUP"), } diff --git a/tests/unit/constants.py b/tests/unit/constants.py index 4178f9be..7a7d01ac 100644 --- a/tests/unit/constants.py +++ b/tests/unit/constants.py @@ -5,3 +5,4 @@ AWS_REGION = "eu-west-1" S3_STAGING_DIR = "s3://my-bucket/test-dbt/" ATHENA_WORKGROUP = "dbt-athena-adapter" +SPARK_WORKGROUP = "spark" diff --git a/tests/unit/test_python_submissions.py b/tests/unit/test_python_submissions.py new file mode 100644 index 00000000..56069711 --- /dev/null +++ b/tests/unit/test_python_submissions.py @@ -0,0 +1,56 @@ +from unittest.mock import patch + +from dbt.adapters.athena.connections import AthenaCredentials +from dbt.adapters.athena.python_submissions import AthenaPythonJobHelper + +from .constants import ( + ATHENA_WORKGROUP, + AWS_REGION, + DATA_CATALOG_NAME, + DATABASE_NAME, + S3_STAGING_DIR, + SPARK_WORKGROUP, +) +from .utils import MockAWSService + + +class TestPythonSubmission: + mock_aws_service = MockAWSService() + parsed_model = {"alias": "test_model", "schema": DATABASE_NAME, "config": {"timeout": 10}} + _athena_job_helper = None + credentials = AthenaCredentials( + database=DATA_CATALOG_NAME, + schema=DATABASE_NAME, + s3_staging_dir=S3_STAGING_DIR, + region_name=AWS_REGION, + work_group=ATHENA_WORKGROUP, + spark_work_group=SPARK_WORKGROUP, + ) + + @property + def athena_job_helper(self): + if self._athena_job_helper is None: + self._athena_job_helper = AthenaPythonJobHelper(self.parsed_model, self.credentials) + return self._athena_job_helper + + def test_create_session(self): + obj = self.athena_job_helper + + # Define the mock return values for the _list_sessions and _start_session methods + mock_list_sessions = {"SessionId": "session123"} + mock_start_session = {"SessionId": "session456"} + + # Mock the _list_sessions and _start_session methods using the patch decorator + with 
patch.object(obj, "_list_sessions", return_value=mock_list_sessions), patch.object( + obj, "_start_session", return_value=mock_start_session + ): + # Call the session_id property and assert that it returns the expected value + assert obj.session_id == "session123" + + # Call the _list_sessions and _start_session methods to ensure they were called + obj._list_sessions.assert_called_once() + obj._start_session.assert_not_called() # since _list_sessions return value is not None + + # Call the session_id property again to ensure it still returns the expected value after + # the mock context is complete + assert obj.session_id == "session123" From 233ab98bd3543aacc01d1289fa7d6912653ec86a Mon Sep 17 00:00:00 2001 From: Avinash-1394 <43074786+Avinash-1394@users.noreply.github.com> Date: Tue, 4 Apr 2023 23:02:34 -0300 Subject: [PATCH 06/75] Added docs --- dbt/adapters/athena/python_submissions.py | 226 +++++++++++++++++----- 1 file changed, 181 insertions(+), 45 deletions(-) diff --git a/dbt/adapters/athena/python_submissions.py b/dbt/adapters/athena/python_submissions.py index 8ef7b066..d1ba83dc 100644 --- a/dbt/adapters/athena/python_submissions.py +++ b/dbt/adapters/athena/python_submissions.py @@ -18,6 +18,34 @@ class AthenaPythonJobHelper(PythonJobHelper): + """ + A helper class for executing Python jobs on AWS Athena. + + This class extends the base `PythonJobHelper` class and provides additional functionality + specific to executing jobs on Athena. It takes a parsed model and credentials as inputs + during initialization, and provides methods for executing Athena jobs, setting timeout, + polling interval, region name, AWS profile name, and Spark work group. + + Args: + parsed_model (Dict): A dictionary representing the parsed model of the Athena job. + It should contain keys such as 'alias' for job identifier and 'schema' for + job schema. + credentials (AthenaCredentials): An instance of the `AthenaCredentials` class + containing AWS credentials for accessing Athena. + + Attributes: + identifier (str): A string representing the alias or identifier of the Athena job. + schema (str): A string representing the schema of the Athena job. + parsed_model (Dict): A dictionary representing the parsed model of the Athena job. + timeout (int): An integer representing the timeout value in seconds for the Athena job. + polling_interval (int): An integer representing the polling interval in seconds for + checking the status of the Athena job. + region_name (str): A string representing the AWS region name for executing the Athena job. + profile_name (str): A string representing the AWS profile name for accessing Athena. + spark_work_group (str): A string representing the Spark work group for executing the Athena job. + + """ + def __init__(self, parsed_model: Dict, credentials: AthenaCredentials) -> None: self.identifier = parsed_model["alias"] self.schema = parsed_model["schema"] @@ -31,6 +59,16 @@ def __init__(self, parsed_model: Dict, credentials: AthenaCredentials) -> None: @property @lru_cache() def session_id(self) -> str: + """ + Get the session ID. + + This function retrieves the session ID from the stored session information. If session information + is not available, a new session is started and its session ID is returned. + + Returns: + str: The session ID. 
+ + """ session_info = self._list_sessions() if session_info is None: return self._start_session().get("SessionId") @@ -39,67 +77,150 @@ def session_id(self) -> str: @property @lru_cache() def athena_client(self) -> Any: - return boto3.client("athena") + """ + Get the AWS Athena client. + + This function returns an AWS Athena client object that can be used to interact with the Athena service. + The client is created using the region name and profile name provided during object instantiation. + + Returns: + Any: The Athena client object. + + """ + return boto3.session.Session(region_name=self.region_name, profile_name=self.profile_name).client("athena") def get_timeout(self) -> int: + """ + Get the timeout value. + + This function retrieves the timeout value from the parsed model's configuration. If the timeout value + is not defined, it falls back to the default timeout value. If the retrieved timeout value is less than or + equal to 0, a ValueError is raised as timeout must be a positive integer. + + Returns: + int: The timeout value in seconds. + + Raises: + ValueError: If the timeout value is not a positive integer. + + """ timeout = self.parsed_model["config"].get("timeout", DEFAULT_TIMEOUT) if timeout <= 0: raise ValueError("Timeout must be a positive integer") return timeout def _list_sessions(self) -> dict: - try: - response = self.athena_client.list_sessions( - WorkGroup=self.spark_work_group, MaxResults=1, StateFilter="IDLE" - ) - if len(response.get("Sessions")) == 0 or response.get("Sessions") is None: - return None - return response.get("Sessions")[0] - except Exception: - raise + """ + List Athena sessions. + + This function sends a request to the Athena service to list the sessions in the specified Spark workgroup. + It filters the sessions by state, only returning the first session that is in IDLE state. If no idle sessions + are found or if an error occurs, None is returned. + + Returns: + dict: The session information dictionary if an idle session is found, None otherwise. + + """ + response = self.athena_client.list_sessions(WorkGroup=self.spark_work_group, MaxResults=1, StateFilter="IDLE") + if len(response.get("Sessions")) == 0 or response.get("Sessions") is None: + return None + return response.get("Sessions")[0] def _start_session(self) -> dict: - try: - response = self.athena_client.start_session( - WorkGroup=self.spark_work_group, - EngineConfiguration={"CoordinatorDpuSize": 1, "MaxConcurrentDpus": 2, "DefaultExecutorDpuSize": 1}, - ) - if response["State"] != "IDLE": - self._poll_until_session_creation(response["SessionId"]) - return response - except Exception: - raise + """ + Start an Athena session. + + This function sends a request to the Athena service to start a session in the specified Spark workgroup. + It configures the session with specific engine configurations. If the session state is not IDLE, the function + polls until the session creation is complete. The response containing session information is returned. + + Returns: + dict: The session information dictionary. 
+ + """ + response = self.athena_client.start_session( + WorkGroup=self.spark_work_group, + EngineConfiguration={"CoordinatorDpuSize": 1, "MaxConcurrentDpus": 2, "DefaultExecutorDpuSize": 1}, + ) + if response["State"] != "IDLE": + self._poll_until_session_creation(response["SessionId"]) + return response def submit(self, compiled_code: str) -> dict: - try: - calculation_execution_id = self.athena_client.start_calculation_execution( - SessionId=self.session_id, CodeBlock=compiled_code.lstrip() - )["CalculationExecutionId"] - logger.debug(f"Submitted calculation execution id {calculation_execution_id}") - execution_status = self._poll_until_execution_completion(calculation_execution_id) - logger.debug(f"Received execution status {execution_status}") - if execution_status == "COMPLETED": - result_s3_uri = self.athena_client.get_calculation_execution( - CalculationExecutionId=calculation_execution_id - )["Result"]["ResultS3Uri"] - return result_s3_uri - else: - raise DbtRuntimeError(f"python model run ended in state {execution_status}") - except Exception: - raise + """ + Submit a calculation to Athena. + + This function submits a calculation to Athena for execution using the provided compiled code. + It starts a calculation execution with the current session ID and the compiled code as the code block. + The function then polls until the calculation execution is completed, and retrieves the result S3 URI. + If the execution is successful and completed, the result S3 URI is returned. Otherwise, a DbtRuntimeError + is raised with the execution status. + + Args: + compiled_code (str): The compiled code to submit for execution. + + Returns: + dict: The result S3 URI if the execution is successful and completed. + + Raises: + DbtRuntimeError: If the execution ends in a state other than "COMPLETED". + + """ + calculation_execution_id = self.athena_client.start_calculation_execution( + SessionId=self.session_id, CodeBlock=compiled_code.lstrip() + )["CalculationExecutionId"] + logger.debug(f"Submitted calculation execution id {calculation_execution_id}") + execution_status = self._poll_until_execution_completion(calculation_execution_id) + logger.debug(f"Received execution status {execution_status}") + if execution_status == "COMPLETED": + result_s3_uri = self.athena_client.get_calculation_execution( + CalculationExecutionId=calculation_execution_id + )["Result"]["ResultS3Uri"] + return result_s3_uri + else: + raise DbtRuntimeError(f"python model run ended in state {execution_status}") def _terminate_session(self) -> dict: - try: - session_status = self.athena_client.get_session_status(SessionId=self.session_id)["Status"] - if session_status["State"] in ["IDLE", "BUSY"] and ( - session_status["StartDateTime"] - datetime.now(tz=timezone.utc) > timedelta(seconds=self.timeout) - ): - logger.debug(f"Terminating session: {self.session_id}") - self.athena_client.terminate_session(SessionId=self.session_id) - except Exception: - raise + """ + Terminate the current Athena session. + + This function terminates the current Athena session if it is in IDLE or BUSY state and has exceeded the + configured timeout period. It retrieves the session status, and if the session state is IDLE or BUSY and the + duration since the session start time exceeds the timeout period, the session is terminated. The session ID is + used to terminate the session via the Athena client. + + Returns: + dict: The response from the Athena client after terminating the session. 
+ + """ + session_status = self.athena_client.get_session_status(SessionId=self.session_id)["Status"] + if session_status["State"] in ["IDLE", "BUSY"] and ( + session_status["StartDateTime"] - datetime.now(tz=timezone.utc) > timedelta(seconds=self.timeout) + ): + logger.debug(f"Terminating session: {self.session_id}") + self.athena_client.terminate_session(SessionId=self.session_id) def _poll_until_execution_completion(self, calculation_execution_id): + """ + Poll the status of a calculation execution until it is completed, failed, or cancelled. + + This function polls the status of a calculation execution identified by the given `calculation_execution_id` + until it is completed, failed, or cancelled. It uses the Athena client to retrieve the status of the execution + and checks if the state is one of "COMPLETED", "FAILED", or "CANCELLED". If the execution is not yet completed, + the function sleeps for a certain polling interval, which starts with the value of `self.polling_interval` and + doubles after each iteration until it reaches the `self.timeout` period. If the execution does not complete + within the timeout period, a `DbtRuntimeError` is raised. + + Args: + calculation_execution_id (str): The ID of the calculation execution to poll. + + Returns: + str: The final state of the calculation execution, which can be one of "COMPLETED", "FAILED" or "CANCELLED". + + Raises: + DbtRuntimeError: If the calculation execution does not complete within the timeout period. + + """ polling_interval = self.polling_interval while True: execution_status = self.athena_client.get_calculation_execution_status( @@ -115,6 +236,20 @@ def _poll_until_execution_completion(self, calculation_execution_id): ) def _poll_until_session_creation(self, session_id): + """ + Polls the status of an Athena session creation until it is completed or reaches the timeout. + + Args: + session_id (str): The ID of the session being created. + + Returns: + str: The final status of the session, which will be "IDLE" if the session creation is successful. + + Raises: + DbtRuntimeError: If the session creation fails, is terminated, or degrades during polling. + DbtRuntimeError: If the session does not become IDLE within the specified timeout. 
+ + """ polling_interval = self.polling_interval while True: creation_status = self.athena_client.get_session_status(SessionId=session_id)["Status"]["State"] @@ -128,4 +263,5 @@ def _poll_until_session_creation(self, session_id): raise DbtRuntimeError(f"Session {session_id} did not create within {self.timeout} seconds.") def __del__(self) -> None: + """Teardown for the class.""" self._terminate_session() From d2ad09e5ba35739a1076bfd659d1028590c855b3 Mon Sep 17 00:00:00 2001 From: Avinash-1394 <43074786+Avinash-1394@users.noreply.github.com> Date: Tue, 4 Apr 2023 23:14:12 -0300 Subject: [PATCH 07/75] Updated README --- README.md | 33 ++++++++++++++++++++++++++++++++- 1 file changed, 32 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 280ead04..c143fed0 100644 --- a/README.md +++ b/README.md @@ -20,7 +20,7 @@ * Support two incremental update strategies: `insert_overwrite` and `append` * Does **not** support the use of `unique_key` * Supports [snapshots][snapshots] -* Does not support [Python models][python-models] +* Supports [Python models][python-models] * Does not support [persist docs][persist-docs] for views [seeds]: https://docs.getdbt.com/docs/building-a-dbt-project/seeds @@ -73,6 +73,7 @@ A dbt profile can be configured to run against AWS Athena using the following co | work_group | Identifier of Athena workgroup | Optional | `my-custom-workgroup` | | num_retries | Number of times to retry a failing query | Optional | `3` | | lf_tags | Default lf tags to apply to any database created by dbt | Optional | `{"origin": "dbt", "team": "analytics"}` | +| spark_work_group | Identifier of athena spark workgroup | Optional | `my-spark-workgroup` | **Example profiles.yml entry:** ```yaml @@ -89,6 +90,7 @@ athena: database: awsdatacatalog aws_profile_name: my-profile work_group: my-workgroup + spark_work_group: my-spark-workgroup lf_tags: origin: dbt team: analytics @@ -373,6 +375,35 @@ The only way, from a dbt perspective, is to do a full-refresh of the incremental * Snapshot does not support dropping columns from the source table. If you drop a column make sure to drop the column from the snapshot as well. Another workaround is to NULL the column in the snapshot definition to preserve history +### Python Models + +The adapter supports python models using [`spark`](https://docs.aws.amazon.com/athena/latest/ug/notebooks-spark.html). + +#### Prerequisites + +* A spark enabled work group created in athena +* Spark execution role granted access to Athena, Glue and S3 +* The spark work group is added to the ~/.dbt/profiles.yml file and the profile is referenced in dbt_project.yml + +#### Example model + +```python +import pandas as pd + + +def model(dbt, session): + dbt.config(materialized="table") + + model_df = pd.DataFrame({"A": [1, 2, 3, 4]}) + + return model_df +``` + +#### Known issues in python models + +* Incremental models do not fully utilize spark capabilities. They depend on existing sql based logic. +* Snapshots materializations are not supported. + ### Contributing This connector works with Python from 3.7 to 3.11. 
From fed0dfacce8e816be7adfd0b4b03b66b33bfd8b7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=A9my=20Guiselin?= <9251353+Jrmyy@users.noreply.github.com> Date: Wed, 5 Apr 2023 17:51:58 +0200 Subject: [PATCH 08/75] fix: return empty if table does not exist in get_columns (#199) --- dbt/adapters/athena/impl.py | 10 +++++++++- tests/unit/test_adapter.py | 15 +++++++++++++++ 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/dbt/adapters/athena/impl.py b/dbt/adapters/athena/impl.py index 53b69cdd..e34ebb32 100755 --- a/dbt/adapters/athena/impl.py +++ b/dbt/adapters/athena/impl.py @@ -696,7 +696,15 @@ def get_columns_in_relation(self, relation: AthenaRelation) -> List[Column]: with boto3_client_lock: glue_client = client.session.client("glue", region_name=client.region_name, config=get_boto3_config()) - table = glue_client.get_table(DatabaseName=relation.schema, Name=relation.identifier)["Table"] + try: + table = glue_client.get_table(DatabaseName=relation.schema, Name=relation.identifier)["Table"] + except ClientError as e: + if e.response["Error"]["Code"] == "EntityNotFoundException": + logger.debug("table not exist, catching the error") + return [] + else: + logger.error(e) + raise e columns = [c for c in table["StorageDescriptor"]["Columns"] if self._is_current_column(c)] partition_keys = table.get("PartitionKeys", []) diff --git a/tests/unit/test_adapter.py b/tests/unit/test_adapter.py index 01685291..4a7d10d9 100644 --- a/tests/unit/test_adapter.py +++ b/tests/unit/test_adapter.py @@ -848,6 +848,21 @@ def test_get_columns_in_relation(self): Column("dt", "date"), ] + @mock_athena + @mock_glue + def test_get_columns_in_relation_not_found_table(self): + self.mock_aws_service.create_data_catalog() + self.mock_aws_service.create_database() + self.adapter.acquire_connection("dummy") + columns = self.adapter.get_columns_in_relation( + self.adapter.Relation.create( + database=DATA_CATALOG_NAME, + schema=DATABASE_NAME, + identifier="tbl_name", + ) + ) + assert columns == [] + @pytest.mark.parametrize( "response,database,table,columns,lf_tags,expected", [ From f5c816774e0588004bddcd82da604d9830e243ad Mon Sep 17 00:00:00 2001 From: nicor88 <6278547+nicor88@users.noreply.github.com> Date: Thu, 6 Apr 2023 11:40:16 +0200 Subject: [PATCH 09/75] fix: check for empty workgroup in profiles (#194) --- dbt/adapters/athena/impl.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/dbt/adapters/athena/impl.py b/dbt/adapters/athena/impl.py index e34ebb32..535378f3 100755 --- a/dbt/adapters/athena/impl.py +++ b/dbt/adapters/athena/impl.py @@ -149,13 +149,14 @@ def get_work_group_output_location(self) -> Optional[str]: with boto3_client_lock: athena_client = client.session.client("athena", region_name=client.region_name, config=get_boto3_config()) - work_group = athena_client.get_work_group(WorkGroup=creds.work_group) - return ( - work_group.get("WorkGroup", {}) - .get("Configuration", {}) - .get("ResultConfiguration", {}) - .get("OutputLocation") - ) + if creds.work_group: + work_group = athena_client.get_work_group(WorkGroup=creds.work_group) + return ( + work_group.get("WorkGroup", {}) + .get("Configuration", {}) + .get("ResultConfiguration", {}) + .get("OutputLocation") + ) @available def s3_table_prefix(self, s3_data_dir: Optional[str]) -> str: From 785ba5fd9eb880e695360efd3ff1a9b1a8f05bf1 Mon Sep 17 00:00:00 2001 From: Mattia <5013654+mattiamatrix@users.noreply.github.com> Date: Thu, 6 Apr 2023 12:03:23 +0100 Subject: [PATCH 10/75] chore: add credits section 
(#201) --- README.md | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index c143fed0..b2424498 100644 --- a/README.md +++ b/README.md @@ -446,8 +446,25 @@ make test [conventionalcommits](https://www.conventionalcommits.org). * Pull request body should describe _motivation_. -### Helpful Resources +## Credits +The following acknowledges the Maintainers for this repository, those who have Contributed to this repository (via bug reports, code, design, ideas, project management, translation, testing, etc.), and any other References utilized. +### Maintainers +The following individuals are responsible for curating the list of issues, responding to pull requests, and ensuring regular releases happen. + +* [nicor88](https://github.com/nicor88) +* [Jrmyy](https://github.com/Jrmyy) +* [jessedobbelaere](https://github.com/jessedobbelaere) +* [mattiamatrix](https://github.com/mattiamatrix) +* [thenaturalist](https://github.com/thenaturalist) + +### Contributors +Thank you to all the people who have already contributed to this repository via bug reports, code, design, ideas, project management, translation, testing, etc. + +* [Tomme](https://github.com/Tomme) - Wrote the initial version. +* [Lemiffe](https://github.com/lemiffe) - Logo design. + +## Resources * [Athena CREATE TABLE AS](https://docs.aws.amazon.com/athena/latest/ug/create-table-as.html) * [dbt-labs/dbt-core](https://github.com/dbt-labs/dbt-core) * [laughingman7743/PyAthena](https://github.com/laughingman7743/PyAthena) From ee8bf9581d76858ae2508fd545017808a18bf70a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=A9my=20Guiselin?= <9251353+Jrmyy@users.noreply.github.com> Date: Thu, 6 Apr 2023 15:25:49 +0200 Subject: [PATCH 11/75] fix: enable database in policy to support cross-account queries (#200) Co-authored-by: nicor88 <6278547+nicor88@users.noreply.github.com> --- dbt/adapters/athena/relation.py | 2 +- tests/unit/test_relation.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/dbt/adapters/athena/relation.py b/dbt/adapters/athena/relation.py index 887158c7..414d2dc9 100644 --- a/dbt/adapters/athena/relation.py +++ b/dbt/adapters/athena/relation.py @@ -6,7 +6,7 @@ @dataclass class AthenaIncludePolicy(Policy): - database: bool = False + database: bool = True schema: bool = True identifier: bool = True diff --git a/tests/unit/test_relation.py b/tests/unit/test_relation.py index 21f64973..81872224 100644 --- a/tests/unit/test_relation.py +++ b/tests/unit/test_relation.py @@ -12,7 +12,7 @@ def test_render_hive_uses_hive_style_quotation(self): database=DATA_CATALOG_NAME, schema=DATABASE_NAME, ) - assert relation.render_hive() == f"`{DATABASE_NAME}`.`{TABLE_NAME}`" + assert relation.render_hive() == f"`{DATA_CATALOG_NAME}`.`{DATABASE_NAME}`.`{TABLE_NAME}`" def test_render_hive_resets_quote_character_after_call(self): relation = AthenaRelation.create( @@ -21,7 +21,7 @@ def test_render_hive_resets_quote_character_after_call(self): schema=DATABASE_NAME, ) relation.render_hive() - assert relation.render() == f'"{DATABASE_NAME}"."{TABLE_NAME}"' + assert relation.render() == f'"{DATA_CATALOG_NAME}"."{DATABASE_NAME}"."{TABLE_NAME}"' def test_render_pure_resets_quote_character_after_call(self): relation = AthenaRelation.create( @@ -29,4 +29,4 @@ def test_render_pure_resets_quote_character_after_call(self): database=DATA_CATALOG_NAME, schema=DATABASE_NAME, ) - assert relation.render_pure() == f"{DATABASE_NAME}.{TABLE_NAME}" + assert relation.render_pure() == 
f"{DATA_CATALOG_NAME}.{DATABASE_NAME}.{TABLE_NAME}" From 414c739336db45b481881d6f8e885971c8071b49 Mon Sep 17 00:00:00 2001 From: Serhii Dimchenko <39801237+svdimchenko@users.noreply.github.com> Date: Fri, 7 Apr 2023 09:43:23 +0200 Subject: [PATCH 12/75] fix: glue column types (#196) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: nicor88 <6278547+nicor88@users.noreply.github.com> Co-authored-by: Jérémy Guiselin <9251353+Jrmyy@users.noreply.github.com> --- dbt/adapters/athena/column.py | 55 +++++++++++++++++++++++++++++++++ dbt/adapters/athena/impl.py | 44 +++++++++++++++++--------- dbt/adapters/athena/relation.py | 9 ++++++ tests/unit/test_adapter.py | 16 +++++----- 4 files changed, 101 insertions(+), 23 deletions(-) create mode 100644 dbt/adapters/athena/column.py diff --git a/dbt/adapters/athena/column.py b/dbt/adapters/athena/column.py new file mode 100644 index 00000000..e5dfdadd --- /dev/null +++ b/dbt/adapters/athena/column.py @@ -0,0 +1,55 @@ +from dataclasses import dataclass + +from dbt.adapters.athena.relation import TableType +from dbt.adapters.base.column import Column +from dbt.exceptions import DbtRuntimeError + + +@dataclass +class AthenaColumn(Column): + table_type: TableType = TableType.TABLE + + def is_iceberg(self) -> bool: + return self.table_type == TableType.ICEBERG + + def is_string(self) -> bool: + return self.dtype.lower() in {"varchar", "string"} + + def is_binary(self) -> bool: + return self.dtype.lower() in {"binary", "varbinary"} + + def is_timestamp(self) -> bool: + return self.dtype.lower() in {"timestamp"} + + @classmethod + def string_type(cls, size: int) -> str: + return f"varchar({size})" if size > 0 else "varchar" + + @classmethod + def binary_type(cls) -> str: + return "varbinary" + + def timestamp_type(self) -> str: + if self.is_iceberg(): + return "timestamp(6)" + return "timestamp" + + def string_size(self) -> int: + if not self.is_string(): + raise DbtRuntimeError("Called string_size() on non-string field!") + if not self.char_size: + # Handle error: '>' not supported between instances of 'NoneType' and 'NoneType' for union relations macro + return 0 + return self.char_size + + @property + def data_type(self) -> str: + if self.is_string(): + return self.string_type(self.string_size()) + elif self.is_numeric(): + return self.numeric_type(self.dtype, self.numeric_precision, self.numeric_scale) + elif self.is_binary(): + return self.binary_type() + elif self.is_timestamp(): + return self.timestamp_type() + return self.dtype diff --git a/dbt/adapters/athena/impl.py b/dbt/adapters/athena/impl.py index 535378f3..e42b6f16 100755 --- a/dbt/adapters/athena/impl.py +++ b/dbt/adapters/athena/impl.py @@ -12,11 +12,22 @@ from botocore.exceptions import ClientError from dbt.adapters.athena import AthenaConnectionManager +from dbt.adapters.athena.column import AthenaColumn from dbt.adapters.athena.config import get_boto3_config +<<<<<<< HEAD from dbt.adapters.athena.python_submissions import AthenaPythonJobHelper from dbt.adapters.athena.relation import AthenaRelation, AthenaSchemaSearchMap from dbt.adapters.athena.utils import clean_sql_comment from dbt.adapters.base import Column, PythonJobHelper, available +======= +from dbt.adapters.athena.relation import ( + AthenaRelation, + AthenaSchemaSearchMap, + TableType, +) +from dbt.adapters.athena.utils import clean_sql_comment +from dbt.adapters.base import available +>>>>>>> 6f41803 (fix: glue column types (#196)) from dbt.adapters.base.relation import 
BaseRelation, InformationSchema from dbt.adapters.sql import SQLAdapter from dbt.contracts.connection import AdapterResponse @@ -35,13 +46,13 @@ class AthenaAdapter(SQLAdapter): Relation = AthenaRelation relation_type_map = { - "EXTERNAL_TABLE": "table", - "MANAGED_TABLE": "table", - "VIRTUAL_VIEW": "view", - "table": "table", - "view": "view", - "cte": "cte", - "materializedview": "materializedview", + "EXTERNAL_TABLE": TableType.TABLE, + "MANAGED_TABLE": TableType.TABLE, + "VIRTUAL_VIEW": TableType.VIEW, + "table": TableType.TABLE, + "view": TableType.VIEW, + "cte": TableType.CTE, + "materializedview": TableType.MATERIALIZED_VIEW, } @classmethod @@ -366,7 +377,7 @@ def _get_one_table_for_catalog(self, table: dict, database: str) -> list: "table_database": database, "table_schema": table["DatabaseName"], "table_name": table["Name"], - "table_type": self.relation_type_map[table["TableType"]], + "table_type": self.relation_type_map[table["TableType"]].value, "table_comment": table.get("Parameters", {}).get("comment", table.get("Description", "")), } return [ @@ -486,7 +497,7 @@ def list_relations_without_caching( return relations @available - def get_table_type(self, db_name, table_name): + def get_table_type(self, db_name, table_name) -> TableType: conn = self.connections.get_thread_connection() client = conn.handle @@ -495,17 +506,17 @@ def get_table_type(self, db_name, table_name): try: response = glue_client.get_table(DatabaseName=db_name, Name=table_name) - _type = self.relation_type_map.get(response.get("Table", {}).get("TableType", "Table")) + _type = self.relation_type_map.get(response.get("Table", {}).get("TableType")) _specific_type = response.get("Table", {}).get("Parameters", {}).get("table_type", "") if _specific_type.lower() == "iceberg": - _type = "iceberg_table" + _type = TableType.ICEBERG if _type is None: raise ValueError("Table type cannot be None") - logger.debug("table_name : " + table_name) - logger.debug("table type : " + _type) + logger.debug(f"table_name : {table_name}") + logger.debug(f"table type : {_type}") return _type @@ -690,7 +701,7 @@ def _is_current_column(col: dict) -> bool: return True @available - def get_columns_in_relation(self, relation: AthenaRelation) -> List[Column]: + def get_columns_in_relation(self, relation: AthenaRelation) -> List[AthenaColumn]: conn = self.connections.get_thread_connection() client = conn.handle @@ -706,10 +717,13 @@ def get_columns_in_relation(self, relation: AthenaRelation) -> List[Column]: else: logger.error(e) raise e + table_type = self.get_table_type(relation.schema, relation.identifier) columns = [c for c in table["StorageDescriptor"]["Columns"] if self._is_current_column(c)] partition_keys = table.get("PartitionKeys", []) logger.debug(f"Columns in relation {relation.identifier}: {columns + partition_keys}") - return [Column(c["Name"], c["Type"]) for c in columns + partition_keys] + return [ + AthenaColumn(column=c["Name"], dtype=c["Type"], table_type=table_type) for c in columns + partition_keys + ] diff --git a/dbt/adapters/athena/relation.py b/dbt/adapters/athena/relation.py index 414d2dc9..0cf51947 100644 --- a/dbt/adapters/athena/relation.py +++ b/dbt/adapters/athena/relation.py @@ -1,9 +1,18 @@ from dataclasses import dataclass, field +from enum import Enum from typing import Dict, Optional, Set from dbt.adapters.base.relation import BaseRelation, InformationSchema, Policy +class TableType(Enum): + TABLE = "table" + VIEW = "view" + CTE = "cte" + MATERIALIZED_VIEW = "materializedview" + ICEBERG = "iceberg_table" + 
+ @dataclass class AthenaIncludePolicy(Policy): database: bool = True diff --git a/tests/unit/test_adapter.py b/tests/unit/test_adapter.py index 4a7d10d9..d69bdc69 100644 --- a/tests/unit/test_adapter.py +++ b/tests/unit/test_adapter.py @@ -10,9 +10,9 @@ from dbt.adapters.athena import AthenaAdapter from dbt.adapters.athena import Plugin as AthenaPlugin +from dbt.adapters.athena.column import AthenaColumn from dbt.adapters.athena.connections import AthenaCursor, AthenaParameterFormatter -from dbt.adapters.athena.relation import AthenaRelation -from dbt.adapters.base import Column +from dbt.adapters.athena.relation import AthenaRelation, TableType from dbt.clients import agate_helper from dbt.contracts.connection import ConnectionState from dbt.contracts.files import FileHash @@ -436,7 +436,7 @@ def test__get_relation_type_table(self, aws_credentials): self.mock_aws_service.create_table("test_table") self.adapter.acquire_connection("dummy") table_type = self.adapter.get_table_type(DATABASE_NAME, "test_table") - assert table_type == "table" + assert table_type == TableType.TABLE @mock_glue @mock_s3 @@ -459,7 +459,7 @@ def test__get_relation_type_view(self, aws_credentials): self.mock_aws_service.create_view("test_view") self.adapter.acquire_connection("dummy") table_type = self.adapter.get_table_type(DATABASE_NAME, "test_view") - assert table_type == "view" + assert table_type == TableType.VIEW @mock_glue @mock_s3 @@ -470,7 +470,7 @@ def test__get_relation_type_iceberg(self, aws_credentials): self.mock_aws_service.create_iceberg_table("test_iceberg") self.adapter.acquire_connection("dummy") table_type = self.adapter.get_table_type(DATABASE_NAME, "test_iceberg") - assert table_type == "iceberg_table" + assert table_type == TableType.ICEBERG def _test_list_relations_without_caching(self, schema_relation): self.adapter.acquire_connection("dummy") @@ -843,9 +843,9 @@ def test_get_columns_in_relation(self): ) ) assert columns == [ - Column("id", "string"), - Column("country", "string"), - Column("dt", "date"), + AthenaColumn(column="id", dtype="string", table_type=TableType.TABLE), + AthenaColumn(column="country", dtype="string", table_type=TableType.TABLE), + AthenaColumn(column="dt", dtype="date", table_type=TableType.TABLE), ] @mock_athena From 6ad67e7cbc23bf5cee76193a7db3cd2592b52b86 Mon Sep 17 00:00:00 2001 From: "allcontributors[bot]" <46447321+allcontributors[bot]@users.noreply.github.com> Date: Fri, 7 Apr 2023 18:01:13 +0100 Subject: [PATCH 13/75] docs: add nicor88 as a contributor for code (#205) Co-authored-by: allcontributors[bot] <46447321+allcontributors[bot]@users.noreply.github.com> --- .all-contributorsrc | 26 ++++++++++++++++++++++++++ README.md | 25 +++++++++++++++++++++++++ 2 files changed, 51 insertions(+) create mode 100644 .all-contributorsrc diff --git a/.all-contributorsrc b/.all-contributorsrc new file mode 100644 index 00000000..b499a8b1 --- /dev/null +++ b/.all-contributorsrc @@ -0,0 +1,26 @@ +{ + "files": [ + "README.md" + ], + "imageSize": 100, + "commit": false, + "commitConvention": "angular", + "contributors": [ + { + "login": "nicor88", + "name": "nicor88", + "avatar_url": "https://avatars.githubusercontent.com/u/6278547?v=4", + "profile": "https://github.com/nicor88", + "contributions": [ + "code", + "maintenance" + ] + } + ], + "contributorsPerLine": 7, + "skipCi": true, + "repoType": "github", + "repoHost": "https://github.com", + "projectName": "dbt-athena", + "projectOwner": "dbt-athena" +} diff --git a/README.md b/README.md index b2424498..185db8f4 100644 
--- a/README.md +++ b/README.md @@ -1,4 +1,7 @@ [![pypi](https://badge.fury.io/py/dbt-athena-community.svg)](https://pypi.org/project/dbt-athena-community/) + +[![All Contributors](https://img.shields.io/badge/all_contributors-1-orange.svg?style=flat-square)](#contributors-) + [![Imports: isort](https://img.shields.io/badge/%20imports-isort-%231674b1?style=flat&labelColor=ef8336)](https://pycqa.github.io/isort/) [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) [![Stats: pepy](https://pepy.tech/badge/dbt-athena-community/month)](https://pepy.tech/project/dbt-athena-community) @@ -468,3 +471,25 @@ Thank you to all the people who have already contributed to this repository via * [Athena CREATE TABLE AS](https://docs.aws.amazon.com/athena/latest/ug/create-table-as.html) * [dbt-labs/dbt-core](https://github.com/dbt-labs/dbt-core) * [laughingman7743/PyAthena](https://github.com/laughingman7743/PyAthena) + +## Contributors ✨ + +Thanks goes to these wonderful people ([emoji key](https://allcontributors.org/docs/en/emoji-key)): + + + + +
nicor88 💻 🚧 |
+