Skip to content

Commit

Permalink
restructure tests
Browse files Browse the repository at this point in the history
  • Loading branch information
buremba committed Feb 5, 2025
1 parent feae7ff commit 8e3af2f
Show file tree
Hide file tree
Showing 12 changed files with 292 additions and 194 deletions.
85 changes: 50 additions & 35 deletions tests/integration/extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,33 +5,19 @@

from tests.integration.utils import execute_query, universql_connection, SIMPLE_QUERY, ALL_COLUMNS_QUERY

def generate_name_variants(name):
    """Return the four casing variants of *name* used in identifier tests.

    Order: lowercase, UPPERCASE, Capitalized, then the upper-cased name
    wrapped in double quotes (Snowflake quoted identifier).
    """
    upper = name.upper()
    return [name.lower(), upper, name.capitalize(), f'"{upper}"']

def generate_select_statement_combos(table, schema = None, database = None):
    """Build SELECT statements for every casing combination of the identifier.

    The qualification level follows the arguments given: table only,
    schema.table, or database.schema.table.
    """
    table_variants = generate_name_variants(table)

    if database is not None:
        # Fully qualified: database.schema.table across all casing variants.
        qualified = product(generate_name_variants(database),
                            generate_name_variants(schema),
                            table_variants)
        return [f"SELECT * FROM {db}.{sch}.{tbl}" for db, sch, tbl in qualified]
    if schema is not None:
        # Schema qualified: schema.table across all casing variants.
        qualified = product(generate_name_variants(schema), table_variants)
        return [f"SELECT * FROM {sch}.{tbl}" for sch, tbl in qualified]
    # Unqualified table name only.
    return [f"SELECT * FROM {tbl}" for tbl in table_variants]
class TestConnectivity:
    def test_invalid_auth(self):
        """Every query shape must be rejected when the password is wrong.

        The original repeated the same connection/raises boilerplate three
        times; it is now a single loop over the query variants (DRY), with
        identical behavior: each query is tried on a fresh connection so one
        failed attempt cannot mask another.
        """
        queries = [
            "SHOW TABLES LIMIT 1",
            "SELECT 1",
            "CREATE TEMP TABLE test_table AS SELECT 1 as t; SELECT * FROM test_table;",
        ]
        for query in queries:
            with universql_connection(password="invalidPass") as conn:
                with pytest.raises(ProgrammingError, match="Incorrect username or password was specified"):
                    execute_query(conn, query)


class TestSelect:
Expand All @@ -40,12 +26,6 @@ def test_simple_select(self):
universql_result = execute_query(conn, SIMPLE_QUERY)
print(universql_result)

@pytest.mark.skip(reason="Stages are not implemented yet")
def test_from_stage(self):
    """Counting rows read directly from a named stage path."""
    with universql_connection() as conn:
        result = execute_query(conn, "select count(*) from @stage/iceberg_stage")
        print(result)

def test_complex_select(self):
with universql_connection() as conn:
universql_result = execute_query(conn, ALL_COLUMNS_QUERY)
Expand Down Expand Up @@ -81,7 +61,43 @@ def test_union(self):
result = execute_query(conn, "select 1 union all select 2")
assert result.num_rows == 2

def test_copy_into(self):

def test_stage(self):
    """A stage path should be directly queryable as a table source."""
    with universql_connection(warehouse=None, database="MY_ICEBERG_JINJAT", schema="TPCH_SF1") as conn:
        stage_result = execute_query(conn, """
        select * FROM @clickhouse_public_data_stage/ limit 1
        """)
        # result = execute_query(conn, "select * from @iceberg_db.public.landing_stage/initial_objects/device_metadata.csv")
        assert stage_result.num_rows > 0

def test_copy_into_for_ryan(self):
    # End-to-end COPY INTO: creates a temp target table, then bulk-loads a
    # CSV from a fully qualified stage path, skipping the header row.
    # Uses a dedicated connection profile ('ryan_snowflake') —
    # NOTE(review): presumably that account hosts the referenced stage and
    # CSV; the test depends on that external state — confirm fixture setup.
    with universql_connection(snowflake_connection_name='ryan_snowflake', warehouse=None, database="ICEBERG_DB") as conn:

        result = execute_query(conn, """
        CREATE OR REPLACE TEMPORARY TABLE DEVICE_METADATA_REF (
            device_id VARCHAR,
            device_name VARCHAR,
            device_type VARCHAR,
            manufacturer VARCHAR,
            model_number VARCHAR,
            firmware_version VARCHAR,
            installation_date DATE,
            location_id VARCHAR,
            location_name VARCHAR,
            facility_zone VARCHAR,
            is_active BOOLEAN,
            expected_lifetime_months INT,
            maintenance_interval_days INT,
            last_maintenance_date DATE
        );
        COPY INTO DEVICE_METADATA_REF
        FROM @iceberg_db.public.landing_stage/initial_objects/device_metadata.csv
        FILE_FORMAT = (SKIP_HEADER = 1);
        """)
        # A successful load must produce at least one result row.
        assert result.num_rows != 0

def test_clickbench(self):
with universql_connection(warehouse=None) as conn:
result = execute_query(conn, """
CREATE TEMP TABLE hits2 AS SELECT
Expand Down Expand Up @@ -196,5 +212,4 @@ def test_copy_into(self):
""")

result = execute_query(conn, "select count(*) from hits2")

assert result.num_rows == 10
12 changes: 5 additions & 7 deletions tests/integration/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,14 @@

class TestCreate:
def test_create_iceberg_table(self):
EXTERNAL_VOLUME_NAME = os.getenv("EXTERNAL_VOLUME_NAME")
external_volume = os.getenv("PYTEST_EXTERNAL_VOLUME")
if external_volume is None:
pytest.skip("No external volume provided, set PYTEST_EXTERNAL_VOLUME")

with universql_connection(warehouse=None) as conn:
execute_query(conn, f"""
CREATE OR REPLACE ICEBERG TABLE test_iceberg_table
external_volume = {EXTERNAL_VOLUME_NAME}
external_volume = {external_volume}
catalog = 'SNOWFLAKE'
BASE_LOCATION = 'test_iceberg_table'
AS {SIMPLE_QUERY}
Expand All @@ -32,8 +34,4 @@ def test_create_temp_table(self):
def test_create_native_table(self):
    """Native (non-iceberg) CREATE TABLE must be rejected by the DuckDB backend."""
    with universql_connection(warehouse=None) as conn, \
            pytest.raises(ProgrammingError, match="DuckDB can't create native Snowflake tables"):
        execute_query(conn, f"CREATE TABLE test_native_table AS {SIMPLE_QUERY}")

@pytest.mark.skip(reason="not implemented")
def test_create_stage(self):
pass
execute_query(conn, f"CREATE TABLE test_native_table AS {SIMPLE_QUERY}")
120 changes: 75 additions & 45 deletions tests/integration/object_identifiers.py
Original file line number Diff line number Diff line change
@@ -1,51 +1,57 @@
from itertools import product

import pytest

from tests.integration.utils import execute_query, universql_connection, snowflake_connection, generate_select_statement_combos
from dotenv import load_dotenv
from tests.integration.utils import execute_query, universql_connection
import os


def generate_name_variants(name):
    """Return the casing variants of *name*: lowercase, UPPERCASE,
    Capitalized, and the upper-cased name wrapped in double quotes."""
    lowercase = name.lower()
    uppercase = name.upper()
    mixed_case = name.capitalize()
    in_quotes = '"' + name.upper() + '"'
    return [lowercase, uppercase, mixed_case, in_quotes]

def generate_select_statement_combos(sets_of_identifiers, connected_db=None, connected_schema=None):
    """Expand each identifier set into SELECT statements covering every
    casing combination at each legal qualification level.

    Args:
        sets_of_identifiers: iterable of dicts with optional keys "database"
            and "schema" and a required key "table".
        connected_db: database the session is connected to, or None.
        connected_schema: schema the session is connected to, or None.

    Returns:
        A flat list of SELECT statements. Unqualified names are generated
        only when the set matches the connected db+schema; schema-qualified
        names only when the set's database matches the connected db.

    Raises:
        Exception: if a set has no "table", or has a "database" without a
            "schema".

    Fixed here: the loop variable was named `set`, shadowing the builtin;
    the table check is now a guard clause instead of a trailing else.
    """
    select_statements = []
    for identifiers in sets_of_identifiers:
        set_of_select_statements = []
        database = identifiers.get("database")
        schema = identifiers.get("schema")
        table = identifiers.get("table")
        if table is None:
            raise Exception("No table name provided for a select statement combo.")
        table_variants = generate_name_variants(table)
        # Unqualified names resolve only against the connected db + schema.
        if database == connected_db and schema == connected_schema:
            for table_variant in table_variants:
                set_of_select_statements.append(f"SELECT * FROM {table_variant}")

        if schema is not None:
            schema_variants = generate_name_variants(schema)
            # schema.table resolves only when connected to the same database.
            if database == connected_db:
                for schema_name, table_name in product(schema_variants, table_variants):
                    set_of_select_statements.append(f"SELECT * FROM {schema_name}.{table_name}")
        elif database is not None:
            raise Exception("You must provide a schema name if you provide a database name.")

        if database is not None:
            database_variants = generate_name_variants(database)
            for db_name, schema_name, table_name in product(database_variants, schema_variants, table_variants):
                set_of_select_statements.append(f"SELECT * FROM {db_name}.{schema_name}.{table_name}")
        select_statements = select_statements + set_of_select_statements

    return select_statements

class TestObjectIdentifiers:
def test_setup(self):
    """Provision the databases/schemas/iceberg tables the identifier tests query.

    Fixed here: the original read `cursor.execute(, f\"\"\"...\"\"\",` — a stray
    comma before the argument and a dangling trailing comma, which is a
    syntax error. Also removed the stray leading space in the third
    BASE_LOCATION literal, inconsistent with its siblings.
    """
    EXTERNAL_VOLUME_NAME = os.getenv("EXTERNAL_VOLUME_NAME")

    with snowflake_connection() as conn:
        cursor = conn.cursor()
        # One round trip: `execute immediate` runs the whole DDL batch.
        cursor.execute(f"""
        execute immediate $$
        begin
            CREATE DATABASE IF NOT EXISTS universql1;
            CREATE DATABASE IF NOT EXISTS universql2;
            CREATE SCHEMA IF NOT EXISTS universql1.same_schema;
            CREATE SCHEMA IF NOT EXISTS universql1.different_schema;
            CREATE SCHEMA IF NOT EXISTS universql2.another_schema;
            CREATE ICEBERG TABLE IF NOT EXISTS universql1.same_schema.dim_devices("1" int)
                external_volume = {EXTERNAL_VOLUME_NAME}
                catalog = 'SNOWFLAKE'
                BASE_LOCATION = 'universql1.same_schema.dim_devices'
                AS select 1;
            CREATE ICEBERG TABLE IF NOT EXISTS universql1.different_schema.different_dim_devices("1" int)
                external_volume = {EXTERNAL_VOLUME_NAME}
                catalog = 'SNOWFLAKE'
                BASE_LOCATION = 'universql1.different_schema.different_dim_devices'
                AS select 1;
            CREATE ICEBERG TABLE IF NOT EXISTS universql2.another_schema.another_dim_devices("1" int)
                external_volume = {EXTERNAL_VOLUME_NAME}
                catalog = 'SNOWFLAKE'
                BASE_LOCATION = 'universql2.another_schema.another_dim_devices'
                AS select 1;
        end;
        $$
        """)


# requires the following:
# a connection's file ~/.snowflake/connections.toml
# a connection in that file called "integration_test_universql" specifying that the warehouse is none
# the connected user must be the same as for test_setup
def test_querying_in_connected_db_and_schema(self):
external_volume = os.getenv("PYTEST_EXTERNAL_VOLUME")
if external_volume is None:
pytest.skip("No external volume provided, set PYTEST_EXTERNAL_VOLUME")

connected_db = "universql1"
connected_schema = "same_schema"

Expand All @@ -70,9 +76,33 @@ def test_querying_in_connected_db_and_schema(self):
select_statements = generate_select_statement_combos(combos, connected_db, connected_schema)
successful_queries = []
failed_queries = []

# create toml file
with universql_connection(database=connected_db, schema=connected_schema) as conn:
execute_query(conn, f"""
CREATE DATABASE IF NOT EXISTS universql1;
CREATE DATABASE IF NOT EXISTS universql2;
CREATE SCHEMA IF NOT EXISTS universql1.same_schema;
CREATE SCHEMA IF NOT EXISTS universql1.different_schema;
CREATE SCHEMA IF NOT EXISTS universql2.another_schema;
CREATE ICEBERG TABLE IF NOT EXISTS universql1.same_schema.dim_devices("1" int)
external_volume = {external_volume}
catalog = 'SNOWFLAKE'
BASE_LOCATION = 'universql1.same_schema.dim_devices'
AS select 1;
CREATE ICEBERG TABLE IF NOT EXISTS universql1.different_schema.different_dim_devices("1" int)
external_volume = {external_volume}
catalog = 'SNOWFLAKE'
BASE_LOCATION = 'universql1.different_schema.different_dim_devices'
AS select 1;
CREATE ICEBERG TABLE IF NOT EXISTS universql2.another_schema.another_dim_devices("1" int)
external_volume = {external_volume}
catalog = 'SNOWFLAKE'
BASE_LOCATION = 'universql2.another_schema.another_dim_devices'
AS select 1;
""")

for query in select_statements:
try:
execute_query(conn, query)
Expand All @@ -84,4 +114,4 @@ def test_querying_in_connected_db_and_schema(self):
error_message = f"The following {len(failed_queries)} queries failed:"
for query in failed_queries:
error_message = f"{error_message}\n{query}"
pytest.fail(error_message)
pytest.fail(error_message)
52 changes: 2 additions & 50 deletions tests/integration/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,11 +107,11 @@ def universql_connection(**properties) -> SnowflakeConnection:
account = connection.get('account')
if account in server_cache:
uni_string = {"host": LOCALHOSTCOMPUTING_COM, "port": server_cache[account]} | properties
print(f"Reusing existing server running on port {server_cache[account]} for account {account}")
else:
from universql.main import snowflake
with socketserver.TCPServer(("localhost", 0), None) as s:
free_port = s.server_address[1]
print(f"Reusing existing server running on port {free_port} for account {account}")

def start_universql():
runner = CliRunner()
Expand All @@ -135,6 +135,7 @@ def start_universql():

connect = None
try:
print(snowflake_connection_name, uni_string)
connect = snowflake_connect(connection_name=snowflake_connection_name, **uni_string)
yield connect
finally:
Expand Down Expand Up @@ -186,56 +187,7 @@ def compare_results(snowflake_result: pyarrow.Table, universql_result: pyarrow.T
print("Results match perfectly!")


def generate_name_variants(name, include_blank=False):
    """Return the casing variants of *name*: lowercase, UPPERCASE,
    Capitalized, and the upper-cased name wrapped in double quotes.

    Args:
        name: the identifier to vary.
        include_blank: currently unused — TODO(review): implement or drop
            once callers are confirmed not to rely on it.

    Fixed here: removed the leftover debug print of the variant list.
    """
    uppercase = name.upper()
    return [name.lower(), uppercase, name.capitalize(), '"' + uppercase + '"']


def generate_select_statement_combos(sets_of_identifiers, connected_db=None, connected_schema=None):
    """Expand each identifier set into SELECT statements covering every
    casing combination at each legal qualification level.

    Args:
        sets_of_identifiers: iterable of dicts with optional keys "database"
            and "schema" and a required key "table".
        connected_db: database the session is connected to, or None.
        connected_schema: schema the session is connected to, or None.

    Returns:
        A flat list of SELECT statements; each generated statement is logged.

    Raises:
        Exception: if a set has no "table", or has a "database" without a
            "schema".

    Fixed here: the loop variable was named `set`, shadowing the builtin;
    the table check is a guard clause; dead commented-out log line removed.
    """
    select_statements = []
    for identifiers in sets_of_identifiers:
        set_of_select_statements = []
        database = identifiers.get("database")
        schema = identifiers.get("schema")
        table = identifiers.get("table")
        if table is None:
            raise Exception("No table name provided for a select statement combo.")
        table_variants = generate_name_variants(table)
        # Unqualified names resolve only against the connected db + schema.
        if database == connected_db and schema == connected_schema:
            for table_variant in table_variants:
                set_of_select_statements.append(f"SELECT * FROM {table_variant}")

        if schema is not None:
            schema_variants = generate_name_variants(schema)
            # schema.table resolves only when connected to the same database.
            if database == connected_db:
                for schema_name, table_name in product(schema_variants, table_variants):
                    set_of_select_statements.append(f"SELECT * FROM {schema_name}.{table_name}")
        elif database is not None:
            raise Exception("You must provide a schema name if you provide a database name.")

        if database is not None:
            database_variants = generate_name_variants(database)
            for db_name, schema_name, table_name in product(database_variants, schema_variants, table_variants):
                set_of_select_statements.append(f"SELECT * FROM {db_name}.{schema_name}.{table_name}")
        select_statements = select_statements + set_of_select_statements
        logger.info(f"database: {database}, schema: {schema}, table: {table}")
        for statement in set_of_select_statements:
            logger.info(statement)

    return select_statements


def _set_connection_name(connection_dict=None):
    """Resolve the Snowflake connection name from *connection_dict*, falling
    back to the module default SNOWFLAKE_CONNECTION_NAME.

    Fixed here: replaced the mutable `{}` default argument with the
    None-sentinel idiom (backward compatible — the dict was only read, but
    mutable defaults are shared across calls and a standing hazard).
    """
    connection_dict = {} if connection_dict is None else connection_dict
    snowflake_connection_name = connection_dict.get("snowflake_connection_name", SNOWFLAKE_CONNECTION_NAME)
    logger.info(f"Using the {snowflake_connection_name} connection")
    return snowflake_connection_name

18 changes: 9 additions & 9 deletions tests/scratch/sqlglot_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@

# SELECT ascii(t.$1), ascii(t.$2) FROM 's3://fullpath' (file_format_for_duckdb => myformat) t;

SnowflakeStageTransformer(SnowflakeCatalog())
one = sqlglot.parse_one("SELECT ascii(t.$1), ascii(t.$2) FROM @mystage1 (file_format => myformat) t;", read="snowflake")
two = sqlglot.parse_one("""COPY INTO stg_device_metadata
FROM @iceberg_db.public.landing_stage/initial_objects/
Expand All @@ -28,14 +27,15 @@

class FixTimestampTypes(UniversqlPlugin):

def transform_sql(self, expression, target_executor: Executor):
if isinstance(target_executor, DuckDBExecutor) and isinstance(expression, sqlglot.exp.DataType):
if expression.this.value in ["TIMESTAMPLTZ", "TIMESTAMPTZ"]:
return sqlglot.exp.DataType.build("TIMESTAMPTZ")
if expression.this.value in ["VARIANT"]:
return sqlglot.exp.DataType.build("JSON")
def transform_sql(self, ast, target_executor: Executor):
def fix_timestamp_types(expression):
if isinstance(target_executor, DuckDBExecutor) and isinstance(expression, sqlglot.exp.DataType):
if expression.this.value in ["TIMESTAMPLTZ", "TIMESTAMPTZ"]:
return sqlglot.exp.DataType.build("TIMESTAMPTZ")
if expression.this.value in ["VARIANT"]:
return sqlglot.exp.DataType.build("JSON")

return expression
return ast.transform(fix_timestamp_types)


class RewriteCreateAsIceberg(UniversqlPlugin):
Expand Down Expand Up @@ -111,4 +111,4 @@ def transform_sql(self, expression: Expression, target_executor: Executor) -> Ex
one = sqlglot.parse_one("select * from @test", read="snowflake")
# one = sqlglot.parse_one("select * from 's3://test'", read="snowflake")
# one = sqlglot.parse_one("select to_variant(test) as test from (select 1)", read="snowflake")
# one = sqlglot.parse_one("create table test as select to_variant(test) as test from (select 1)", read="snowflake")
# one = sqlglot.parse_one("create table test as select to_variant(test) as test from (select 1)", read="snowflake")
Loading

0 comments on commit 8e3af2f

Please sign in to comment.