Skip to content

Commit

Permalink
fix the tests
Browse files Browse the repository at this point in the history
  • Loading branch information
buremba committed Feb 5, 2025
1 parent 0ff3627 commit 53a0b30
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 102 deletions.
39 changes: 2 additions & 37 deletions tests/integration/extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,34 +5,6 @@

from tests.integration.utils import execute_query, universql_connection, SIMPLE_QUERY, ALL_COLUMNS_QUERY

def generate_name_variants(name):
    """Return the four casing variants of *name* used by identifier tests.

    Variants, in order: lowercase, uppercase, capitalized, and the
    uppercase form wrapped in double quotes (a quoted SQL identifier).
    """
    return [
        name.lower(),
        name.upper(),
        name.capitalize(),
        f'"{name.upper()}"',
    ]

def generate_select_statement_combos(table, schema = None, database = None):
    """Build SELECT statements over every casing variant of an object name.

    With a database, emits fully qualified db.schema.table combinations;
    with only a schema, schema.table combinations; otherwise bare table
    names. Variant ordering follows itertools.product over the name lists.
    """
    table_variants = generate_name_variants(table)

    if database is not None:
        qualified = product(
            generate_name_variants(database),
            generate_name_variants(schema),
            table_variants,
        )
        return [
            f"SELECT * FROM {db_name}.{schema_name}.{table_name}"
            for db_name, schema_name, table_name in qualified
        ]

    if schema is not None:
        return [
            f"SELECT * FROM {schema_name}.{table_name}"
            for schema_name, table_name in product(
                generate_name_variants(schema), table_variants
            )
        ]

    return [f"SELECT * FROM {table_variant}" for table_variant in table_variants]

class TestConnectivity:
def test_invalid_auth(self):
with universql_connection(password="invalidPass") as conn:
Expand All @@ -54,12 +26,6 @@ def test_simple_select(self):
universql_result = execute_query(conn, SIMPLE_QUERY)
print(universql_result)

# Smoke test for reading from a named stage; skipped until stage support
# exists in universql. Prints the count instead of asserting on it.
@pytest.mark.skip(reason="Stages are not implemented yet")
def test_from_stage(self):
with universql_connection() as conn:
# count(*) over a stage path — result is only printed, not asserted.
universql_result = execute_query(conn, "select count(*) from @stage/iceberg_stage")
print(universql_result)

def test_complex_select(self):
with universql_connection() as conn:
universql_result = execute_query(conn, ALL_COLUMNS_QUERY)
Expand Down Expand Up @@ -99,8 +65,7 @@ def test_union(self):
def test_stage(self):
with universql_connection(warehouse=None, database="MY_ICEBERG_JINJAT", schema="TPCH_SF1") as conn:
result = execute_query(conn, """
create temp table if not exists clickhouse_public_data as select 1 as t;
copy into clickhouse_public_data FROM @clickhouse_public_data_stage/
select * FROM @clickhouse_public_data_stage/ limit 1
""")
# result = execute_query(conn, "select * from @iceberg_db.public.landing_stage/initial_objects/device_metadata.csv")
assert result.num_rows > 0
Expand Down Expand Up @@ -132,7 +97,7 @@ def test_copy_into_for_ryan(self):
""")
assert result.num_rows != 0

def test_copy_into(self):
def test_clickbench(self):
with universql_connection(warehouse=None) as conn:
result = execute_query(conn, """
CREATE TEMP TABLE hits2 AS SELECT
Expand Down
10 changes: 4 additions & 6 deletions tests/integration/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,14 @@

class TestCreate:
def test_create_iceberg_table(self):
EXTERNAL_VOLUME_NAME = os.getenv("EXTERNAL_VOLUME_NAME")
external_volume = os.getenv("PYTEST_EXTERNAL_VOLUME")
if external_volume is None:
pytest.skip("No external volume provided, set PYTEST_EXTERNAL_VOLUME")

with universql_connection(warehouse=None) as conn:
execute_query(conn, f"""
CREATE OR REPLACE ICEBERG TABLE test_iceberg_table
external_volume = {EXTERNAL_VOLUME_NAME}
external_volume = {external_volume}
catalog = 'SNOWFLAKE'
BASE_LOCATION = 'test_iceberg_table'
AS {SIMPLE_QUERY}
Expand All @@ -33,7 +35,3 @@ def test_create_native_table(self):
with universql_connection(warehouse=None) as conn:
with pytest.raises(ProgrammingError, match="DuckDB can't create native Snowflake tables"):
execute_query(conn, f"CREATE TABLE test_native_table AS {SIMPLE_QUERY}")

# Placeholder for CREATE STAGE coverage: the feature is not implemented
# yet, so the body is empty and the test is skipped.
@pytest.mark.skip(reason="not implemented")
def test_create_stage(self):
pass
50 changes: 44 additions & 6 deletions tests/integration/object_identifiers.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,51 @@
from itertools import product

import pytest

from tests.integration.utils import execute_query, universql_connection, snowflake_connection, \
generate_select_statement_combos
from dotenv import load_dotenv
from tests.integration.utils import execute_query, universql_connection
import os


def generate_name_variants(name):
    """Return casing variants of *name*: lower, upper, capitalized,
    and the uppercase form as a double-quoted SQL identifier."""
    upper = name.upper()
    variants = [name.lower(), upper, name.capitalize()]
    variants.append('"' + upper + '"')
    return variants

def generate_select_statement_combos(sets_of_identifiers, connected_db=None, connected_schema=None):
    """Expand identifier sets into SELECT statements over all casing variants.

    Args:
        sets_of_identifiers: iterable of dicts with optional "database" and
            "schema" keys and a mandatory "table" key.
        connected_db: database of the active connection; bare and
            schema-qualified names are only emitted relative to it.
        connected_schema: schema of the active connection; bare table names
            are only emitted when both db and schema match the connection.

    Returns:
        A flat list of "SELECT * FROM ..." statements.

    Raises:
        Exception: if a set lacks a table name, or names a database without
            a schema.
    """
    select_statements = []
    # NOTE: the loop variable was previously named `set`, shadowing the builtin.
    for identifier_set in sets_of_identifiers:
        set_of_select_statements = []
        database = identifier_set.get("database")
        schema = identifier_set.get("schema")
        table = identifier_set.get("table")

        # Guard clause: a table name is always required.
        if table is None:
            raise Exception("No table name provided for a select statement combo.")
        table_variants = generate_name_variants(table)
        # Unqualified table names only resolve in the connected db/schema.
        if database == connected_db and schema == connected_schema:
            for table_variant in table_variants:
                set_of_select_statements.append(f"SELECT * FROM {table_variant}")

        if schema is not None:
            schema_variants = generate_name_variants(schema)
            # schema.table names only resolve inside the connected database.
            if database == connected_db:
                for schema_name, table_name in product(schema_variants, table_variants):
                    set_of_select_statements.append(f"SELECT * FROM {schema_name}.{table_name}")
        elif database is not None:
            raise Exception("You must provide a schema name if you provide a database name.")

        if database is not None:
            # Fully qualified names are valid regardless of connection context.
            database_variants = generate_name_variants(database)
            for db_name, schema_name, table_name in product(database_variants, schema_variants, table_variants):
                set_of_select_statements.append(f"SELECT * FROM {db_name}.{schema_name}.{table_name}")

        select_statements.extend(set_of_select_statements)

    return select_statements

class TestObjectIdentifiers:
def test_querying_in_connected_db_and_schema(self):
external_volume = os.getenv("PYTEST_EXTERNAL_VOLUME")
Expand Down Expand Up @@ -36,8 +76,6 @@ def test_querying_in_connected_db_and_schema(self):
select_statements = generate_select_statement_combos(combos, connected_db, connected_schema)
successful_queries = []
failed_queries = []

# create toml file
with universql_connection(database=connected_db, schema=connected_schema) as conn:
execute_query(conn, f"""
CREATE DATABASE IF NOT EXISTS universql1;
Expand All @@ -61,7 +99,7 @@ def test_querying_in_connected_db_and_schema(self):
CREATE ICEBERG TABLE IF NOT EXISTS universql2.another_schema.another_dim_devices("1" int)
external_volume = {external_volume}
catalog = 'SNOWFLAKE'
BASE_LOCATION = ' universql2.another_schema.another_dim_devices'
BASE_LOCATION = 'universql2.another_schema.another_dim_devices'
AS select 1;
""")

Expand Down
48 changes: 0 additions & 48 deletions tests/integration/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,54 +187,6 @@ def compare_results(snowflake_result: pyarrow.Table, universql_result: pyarrow.T
print("Results match perfectly!")


def generate_name_variants(name, include_blank=False):
    """Return casing variants of *name* used by identifier tests.

    Variants, in order: lowercase, uppercase, capitalized, and the
    uppercase form wrapped in double quotes (a quoted SQL identifier).

    Args:
        name: identifier to vary.
        include_blank: currently unused; kept for backward compatibility.
            TODO(review): implement or remove.
    """
    # Debug print removed — callers that need the variants logged should
    # log them explicitly.
    upper = name.upper()
    return [name.lower(), upper, name.capitalize(), f'"{upper}"']


def generate_select_statement_combos(sets_of_identifiers, connected_db=None, connected_schema=None):
    """Expand identifier sets into SELECT statements over all casing variants.

    Args:
        sets_of_identifiers: iterable of dicts with optional "database" and
            "schema" keys and a mandatory "table" key.
        connected_db: database of the active connection; bare and
            schema-qualified names are only emitted relative to it.
        connected_schema: schema of the active connection; bare table names
            are only emitted when both db and schema match the connection.

    Returns:
        A flat list of "SELECT * FROM ..." statements.

    Raises:
        Exception: if a set lacks a table name, or names a database without
            a schema.
    """
    select_statements = []
    # NOTE: the loop variable was previously named `set`, shadowing the builtin.
    for identifier_set in sets_of_identifiers:
        set_of_select_statements = []
        database = identifier_set.get("database")
        schema = identifier_set.get("schema")
        table = identifier_set.get("table")

        # Guard clause: a table name is always required.
        if table is None:
            raise Exception("No table name provided for a select statement combo.")
        table_variants = generate_name_variants(table)
        # Unqualified table names only resolve in the connected db/schema.
        if database == connected_db and schema == connected_schema:
            for table_variant in table_variants:
                set_of_select_statements.append(f"SELECT * FROM {table_variant}")

        if schema is not None:
            schema_variants = generate_name_variants(schema)
            # schema.table names only resolve inside the connected database.
            if database == connected_db:
                for schema_name, table_name in product(schema_variants, table_variants):
                    set_of_select_statements.append(f"SELECT * FROM {schema_name}.{table_name}")
        elif database is not None:
            raise Exception("You must provide a schema name if you provide a database name.")

        if database is not None:
            # Fully qualified names are valid regardless of connection context.
            database_variants = generate_name_variants(database)
            for db_name, schema_name, table_name in product(database_variants, schema_variants, table_variants):
                set_of_select_statements.append(f"SELECT * FROM {db_name}.{schema_name}.{table_name}")

        select_statements.extend(set_of_select_statements)
        logger.info(f"database: {database}, schema: {schema}, table: {table}")
        for statement in set_of_select_statements:
            logger.info(statement)

    return select_statements


def _set_connection_name(connection_dict={}):
snowflake_connection_name = connection_dict.get("snowflake_connection_name", SNOWFLAKE_CONNECTION_NAME)
logger.info(f"Using the {snowflake_connection_name} connection")
Expand Down
12 changes: 7 additions & 5 deletions universql/warehouse/duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def get_identifier(is_quoted):
if table_exists[0] is None:
return TableType.LOCAL

match = re.search(r'CREATE VIEW (?:["\w]+\.)?[\w.]+ AS SELECT \* FROM iceberg_scan\(\'(s3://[^\']+)\'\);',
match = re.search(r"CREATE VIEW (?:[\"\w]+\.)?[\w.]+ AS SELECT \* FROM iceberg_scan\('([a-zA-Z]+://[^\']+)'\);",
table_exists[0])
if match is not None:
return TableType.ICEBERG
Expand Down Expand Up @@ -317,10 +317,12 @@ def execute(self, ast: sqlglot.exp.Expression, catalog_executor: Executor, locat
this=sqlglot.exp.parse_identifier(column.name),
kind=DataType.build(str(column.field_type)))
for column in create_iceberg_table.metadata.schema().columns]
schema = Schema()
schema.set('this', ast.this)
schema.set('expressions', column_definitions)
ast.set('this', schema)
if not isinstance(ast.this, Schema):
# c
schema = Schema()
schema.set('this', ast.this)
schema.set('expressions', column_definitions)
ast.set('this', schema)
properties.expressions.append(
Property(this=Var(this='METADATA_FILE_PATH'), value=Literal.string(metadata_file_path)))
return {destination_table: ast}
Expand Down

0 comments on commit 53a0b30

Please sign in to comment.