# =============================================================================
# File: tests/integration/object_identifiers.py  (new file in this change)
# =============================================================================
import pytest

from tests.integration.utils import execute_query, dynamic_universql_connection
from dotenv import load_dotenv
import os
import logging
from itertools import product
from _pytest.reports import TestReport

logger = logging.getLogger(__name__)

# Error-text prefix that run_test_queries files under "tables_not_found";
# every other exception is filed under "other_errors".
_TABLE_NOT_FOUND_PREFIX = "Unable to find location of Iceberg tables."


class TestObjectIdentifiers:
    """Check that a table can be selected under many identifier spellings.

    For each (table location, connection) pair, SELECT statements are
    generated for every case/quoting variant of the database, schema and
    table names and executed through ``dynamic_universql_connection``.
    Credentials come from the ``TEST_ACCOUNT_IDENTIFIER``, ``TEST_USER``,
    ``TEST_PASSWORD`` and ``TEST_ROLE`` environment variables.
    """

    # Runs at class-definition time so the os.getenv calls below can see
    # values supplied through a .env file.
    load_dotenv()

    account = os.getenv("TEST_ACCOUNT_IDENTIFIER")
    user = os.getenv("TEST_USER")
    password = os.getenv("TEST_PASSWORD")
    role = os.getenv("TEST_ROLE")

    def test_run_all_tests(self):
        """Run every table-location/connection combination.

        All combinations are executed before reporting, so a single
        ``pytest.fail`` carries the combined report of every failing query.
        """
        table_locations = [
            ("universql1", "same_schema"),
            ("universql1", "different_schema"),
            ("universql2", "another_schema"),
        ]
        connection_spellings = [
            ("universql1", "same_schema"),
            ("UNIVERSQL1", "same_schema"),
            ("UNIVERSQL1", "SAME_SCHEMA"),
            ("universql1", "SAME_SCHEMA"),
        ]
        # product() varies the connection fastest, i.e. for each table
        # location all four connection spellings are tried in order.
        tests = [
            _format_test_params(table_db, table_schema, "dim_devices", connected_db, connected_schema)
            for (table_db, table_schema), (connected_db, connected_schema)
            in product(table_locations, connection_spellings)
        ]

        failed_tests = []
        for test in tests:
            # The dict keys match run_test_queries' parameter names.
            failures = self.run_test_queries(**test)
            if failures is not None:
                failed_tests.append(failures)

        if failed_tests:
            formatted_error_messages = "\n\n".join(
                _format_failure_report(failure) for failure in failed_tests
            )
            logger.info(formatted_error_messages)
            pytest.fail(formatted_error_messages)

    def run_test_queries(self, table_db, table_schema, table_name, connected_db, connected_schema):
        """Execute all identifier variants for one table over one connection.

        Args:
            table_db: Database that holds the table.
            table_schema: Schema that holds the table.
            table_name: Name of the table to select from.
            connected_db: Database passed to the connection.
            connected_schema: Schema passed to the connection.

        Returns:
            ``None`` when every query succeeded; otherwise a dict with the
            keys ``connected_db``, ``connected_schema``, ``tables_not_found``
            (list of query strings) and ``other_errors`` (list of
            single-entry ``{query: exception}`` dicts).
        """
        all_queries = _generate_select_statement_combos(table_name, table_schema, table_db)
        # NOTE(review): these comparisons are case-sensitive, so a connection
        # to "UNIVERSQL1" is not treated as the database "universql1" and the
        # partially-qualified variants are skipped for it - confirm that this
        # is intended.
        if connected_db == table_db:
            all_queries = all_queries + _generate_select_statement_combos(table_name, table_schema)
            # A bare table name is only attempted when both the database and
            # the schema of the connection match the table's location.
            if connected_schema == table_schema:
                all_queries = all_queries + _generate_select_statement_combos(table_name)

        connection_params = _generate_usql_connection_params(
            self.account, self.user, self.password, self.role, connected_db, connected_schema
        )

        tables_not_found = []
        other_errors = []
        with dynamic_universql_connection(**connection_params) as conn:
            # Name variants can collide (e.g. an already lower-case name), so
            # de-duplicate; sorting keeps the execution order deterministic.
            for query in sorted(set(all_queries)):
                try:
                    execute_query(conn, query)
                except Exception as e:
                    # Deliberately broad: every failure is collected so that
                    # the remaining queries still run.
                    if str(e).startswith(_TABLE_NOT_FOUND_PREFIX):
                        tables_not_found.append(query)
                    else:
                        other_errors.append({query: e})

        if tables_not_found or other_errors:
            return {
                "connected_db": connected_db,
                "connected_schema": connected_schema,
                "tables_not_found": tables_not_found,
                "other_errors": other_errors,
            }

        logger.info(f"Connection to database='{connected_db}', schema='{connected_schema}': All queries PASSED")
        return None


def _format_failure_report(failure):
    """Render one failure dict returned by ``run_test_queries`` as text."""
    connected_db = failure["connected_db"]
    connected_schema = failure["connected_schema"]
    tables_not_found = failure["tables_not_found"]
    other_errors = failure["other_errors"]

    lines = [f"Connection to database='{connected_db}', schema='{connected_schema}':"]
    if tables_not_found:
        lines.append(f"-Tables not found ({len(tables_not_found)} queries):")
        lines.extend(f" * {query}" for query in tables_not_found)
    if other_errors:
        lines.append(f"-Other errors ({len(other_errors)} queries):")
        for entry in other_errors:
            # Each entry is a single-item {query: exception} dict.
            query, error_message = next(iter(entry.items()))
            lines.append(f" * Query: {query}")
            lines.append(f" Error: {error_message}")
    return "\n".join(lines)


def _generate_name_variants(name, include_blank=False):
    """Return the spellings of ``name`` that the tests try.

    The result is, in order: lower case, upper case, capitalized, and the
    upper-case name wrapped in double quotes.

    Args:
        name: Identifier to derive the variants from.
        include_blank: Accepted for compatibility; currently unused.
    """
    upper = name.upper()
    return [name.lower(), upper, name.capitalize(), '"' + upper + '"']


def _generate_select_statement_combos(table, schema=None, database=None):
    """Build ``SELECT * FROM ...`` statements for every name-variant combo.

    Args:
        table: Table name; always part of the identifier.
        schema: Optional schema qualifier.
        database: Optional database qualifier; requires ``schema``.

    Returns:
        A list of statements, one per element of the cartesian product of
        the variants of each supplied part (database, schema, table order).

    Raises:
        ValueError: If ``database`` is given without ``schema``.
    """
    if database is not None and schema is None:
        raise ValueError("schema is required when database is given")

    name_parts = [part for part in (database, schema) if part is not None]
    name_parts.append(table)
    variant_lists = [_generate_name_variants(part) for part in name_parts]
    return [
        f"SELECT * FROM {'.'.join(combo)}"
        for combo in product(*variant_lists)
    ]


def _generate_usql_connection_params(account, user, password, role, database=None, schema=None):
    """Assemble connection keyword arguments.

    ``database`` and ``schema`` are only included when they are not ``None``;
    ``warehouse`` is always ``"local()"``.
    """
    params = {
        "account": account,
        "user": user,
        "password": password,
        "role": role,
        "warehouse": "local()",
    }
    if database is not None:
        params["database"] = database
    if schema is not None:
        params["schema"] = schema

    return params


def _format_test_params(table_db, table_schema, table_name, connected_db, connected_schema):
    """Bundle one test case into a dict keyed by ``run_test_queries``' parameter names."""
    return {
        "table_db": table_db,
        "table_schema": table_schema,
        "table_name": table_name,
        "connected_db": connected_db,
        "connected_schema": connected_schema,
    }


# =============================================================================
# File: tests/integration/universql.metadata.sqlite  (new binary file; its
# contents are not part of the textual diff)
# =============================================================================

# =============================================================================
# File: tests/integration/utils.py  (modified; only the changed regions and
# their immediate context are visible)
# =============================================================================
import threading
from contextlib import contextmanager
from typing import Generator
import logging

import pyarrow
import pytest

# ... (unchanged lines not visible in the diff) ...

from universql.util import LOCALHOSTCOMPUTING_COM

logger = logging.getLogger(__name__)

# Configuration using separate connection strings for direct and proxy connections
# export SNOWFLAKE_CONNECTION_STRING="account=xxx;user=xxx;password=xxx;warehouse=xxx;database=xxx;schema=xxx"
# export UNIVERSQL_CONNECTION_STRING="warehouse=xxx"

# ... (unchanged lines not visible in the diff; the preceding definition ends
# with a `finally:` clause that calls `connect.close()`) ...


@contextmanager
def dynamic_universql_connection(**properties) -> Generator[SnowflakeConnection, None, None]:
    """Create a connection through UniversQL proxy.

    Starts the ``snowflake`` CLI command in a daemon thread on a free local
    port, then opens a connection to ``LOCALHOSTCOMPUTING_COM`` on that port.

    Args:
        **properties: Connection keyword arguments. ``account`` is also
            passed to the CLI; all of them are forwarded to
            ``snowflake_connect`` and override the default host/port.

    Yields:
        The open connection; it is closed when the context exits.
    """
    from universql.main import snowflake

    # Binding to port 0 lets the OS choose a free port; the socket is released
    # again as soon as the `with` block ends.
    with socketserver.TCPServer(("localhost", 0), None) as s:
        free_port = s.server_address[1]

    def start_universql():
        runner = CliRunner()
        try:
            invoke = runner.invoke(
                snowflake,
                [
                    '--account',
                    properties.get('account'),
                    '--port',
                    str(free_port),  # CLI arguments are strings
                    '--catalog',
                    'snowflake',
                ],
            )
        except Exception as e:
            # pytest.fail only accepts a string reason.
            pytest.fail(str(e))

        if invoke.exit_code != 0:
            pytest.fail("Unable to start Universql")

    thread = threading.Thread(target=start_universql)
    thread.daemon = True
    thread.start()

    uni_string = {"host": LOCALHOSTCOMPUTING_COM, "port": free_port} | properties

    # Connect before entering the try block: if connecting fails there is
    # nothing to close, and the original error must not be masked.
    connect = snowflake_connect(**uni_string)
    try:
        yield connect
    finally:
        # Only the client connection is closed here; the server thread is a
        # daemon thread and is not joined.
        connect.close()


def execute_query(conn, query: str) -> pyarrow.Table:
    cur = conn.cursor()
    # ... (the remainder of this function lies outside the visible diff context)