Commit

read from json works if the table column names and json file keys have the same case
Ryan Waldorf committed Jan 21, 2025
1 parent ed48b0c commit 1eee9f4
Showing 4 changed files with 107 additions and 21 deletions.
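
For context, a hedged sketch of what the commit message describes, not taken from this diff: with DuckDB's json extension installed and loaded (the same INSTALL json; / LOAD json; statements this commit wires up), JSON keys are matched to table columns by name, and per the commit message the load only works when the key case matches the column case. The table, file, and schema below are hypothetical.

import duckdb

con = duckdb.connect()
# Same statements FILE_FORMAT_LOAD_QUERIES issues for JSON in this commit.
con.execute("INSTALL json;")
con.execute("LOAD json;")
con.execute("CREATE TABLE users (id INTEGER, name VARCHAR)")
# data.json holding {"id": 1, "name": "a"} lines up with these columns;
# per the commit message, differently cased keys are not handled yet.
con.execute("INSERT INTO users BY NAME SELECT * FROM read_json_auto('data.json')")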
4 changes: 4 additions & 0 deletions universql/protocol/session.py
@@ -222,6 +222,10 @@ def perform_query(self, alternative_executor: Executor, raw_query, ast=None) ->
else:
tables = self._find_tables(ast)
files_list = self._find_files(ast)
print("ast INCOMING")
pp(ast)
print("files_list")
pp(files_list)
tables_list = [table[0] for table in tables]
must_run_on_catalog = must_run_on_catalog or self._must_run_on_catalog(tables_list, ast)
if not must_run_on_catalog:
6 changes: 0 additions & 6 deletions universql/protocol/utils.py
@@ -138,12 +138,6 @@ def get_field_from_duckdb(column: list[str], arrow_table: Table, idx: int) -> ty
(field_name, field_type) = column[0], column[1]
pa_type = arrow_table.schema[idx].type

print(f"\nProcessing column: {field_name}")
print(f"Full column info: {column}")
print(f"DuckDB field type: {field_type}")
print(f"Current PyArrow type: {pa_type}")
print(f"Arrow table column data: {arrow_table[idx]}")

metadata = {}
value = arrow_table[idx]

33 changes: 25 additions & 8 deletions universql/warehouse/duckdb.py
@@ -23,7 +23,7 @@
Var, Literal, IcebergProperty, Copy, Delete, Merge, Use, DataType, ColumnDef

from universql.warehouse import ICatalog, Executor, Locations, Tables
from universql.warehouse.utils import get_stage_name, transform_copy
from universql.warehouse.utils import transform_copy, get_file_format, get_load_file_format_queries
from universql.lake.cloud import s3, gcs, in_lambda
from universql.util import prepend_to_lines, QueryError, calculate_script_cost, parse_snowflake_account, full_qualifier, get_role_credentials
from universql.protocol.utils import DuckDBFunctions, get_field_from_duckdb
@@ -170,6 +170,9 @@ def execute_raw(self, raw_query: str, no_emulator = False) -> None:
self.catalog.current_connection = self.catalog.duckdb
else:
self.catalog.current_connection = self.catalog.emulator
logger.info(f"Debug - Catalog connection type: {type(self.catalog.current_connection)}")
print("type(self.catalog.current_connection) INCOMING")
pp(type(self.catalog.current_connection))
self.catalog.current_connection.execute(raw_query)
except duckdb.Error as e:
raise QueryError(f"Unable to run the query locally on DuckDB. {e.args}")
@@ -248,7 +251,18 @@ def _get_property(self, ast: sqlglot.exp.Create, name: str):
Var) and expression.this.this.casefold() == name.casefold()),
None)

def prep_duckdb_for_files(self, file_data):
file_format = get_file_format(file_data[next(iter(file_data.keys()))])
load_file_format_queries = get_load_file_format_queries(file_format)
print(f"Found {file_format} files to read")
for query in load_file_format_queries:
print(f"Executing {query}")
self.execute_raw(query)

def execute(self, ast: sqlglot.exp.Expression, tables: Tables, file_data = None) -> typing.Optional[Locations]:
# if file_data is not None:
# self.prep_duckdb_for_files(file_data)

if isinstance(ast, Create) or isinstance(ast, Insert):
if isinstance(ast.this, Schema):
destination_table = ast.this.this
@@ -385,27 +399,30 @@ def execute(self, ast: sqlglot.exp.Expression, tables: Tables, file_data = None)
self.catalog.emulator.execute(ast.sql(dialect="snowflake"))
self.catalog.base_catalog.clear_cache()
elif isinstance(ast, Copy):
target_table = full_qualifier(ast.this, self.catalog.credentials)
print("file_data INCOMING")
pp(file_data)
# aws_role = file_data[0]
for file_name, file_config in file_data.items():
print("hola")
urls = file_config["METADATA"]["URL"]
profile = file_config["METADATA"]["URL"]
# profile = file_config["METADATA"]["URL"]
try:
region = get_region(urls[0], file_config["METADATA"]["storage_provider"])
except Exception as e:
print(f"There was a problem accessing data for {file_name}:\n{e}")

sql = self._sync_and_transform_query(ast, tables, file_data).sql(dialect="duckdb", pretty=True)
logger.info(f"Debug - Final SQL: {sql}")
logger.info(f"Debug - File data: {file_data}")
self.execute_raw(sql, True)
else:
sql = self._sync_and_transform_query(ast, tables).sql(dialect="duckdb", pretty=True)
self.execute_raw(sql)

return None

def _load_file_format(file_format):
file_format_queries = {
"JSON": ["INSTALL json;", "LOAD json;"],
"AVRO": ["INSTALL avro FROM community;", "LOAD avro;"]
}

def get_as_table(self) -> pyarrow.Table:
arrow_table = self.catalog.arrow_table()
if arrow_table is None:
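
A hedged sketch of the new prep_duckdb_for_files path above (currently commented out in execute): the dictionary shape mirrors the metadata this diff reads, while the stage name and URL are made up.

# Hypothetical file_data for a JSON stage, shaped like the metadata used in this diff
# ("format"/"duckdb_property_value" read by get_file_format; "METADATA"/"URL"/"storage_provider"
# read in the Copy branch).
file_data = {
    "my_stage": {
        "format": {"duckdb_property_value": "JSON"},
        "METADATA": {"URL": ["s3://my-bucket/data/"], "storage_provider": "S3"},
    }
}
# prep_duckdb_for_files(file_data) would then resolve:
#   get_file_format(file_data["my_stage"])    -> "JSON"
#   get_load_file_format_queries("JSON")      -> ["INSTALL json;", "LOAD json;"]
# and pass each query to execute_raw(), loading the json extension before the
# transformed COPY statement runs.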
85 changes: 78 additions & 7 deletions universql/warehouse/utils.py
@@ -5,6 +5,13 @@

DUCKDB_SUPPORTED_FILE_TYPES = ['CSV', 'JSON', 'AVRO', 'Parquet']

FILE_FORMAT_LOAD_QUERIES = {
"JSON": ["INSTALL json;", "LOAD json;"],
"AVRO": ["INSTALL avro FROM community;", "LOAD avro;"]
}

def get_load_file_format_queries(file_format):
return FILE_FORMAT_LOAD_QUERIES.get(file_format, [])

def transform_copy(expression, file_data):
if not expression.args.get('files'):
@@ -18,7 +25,7 @@ def transform_copy(expression, file_data):
stage_name = get_stage_name(table)
stage_file_data = file_data.get(stage_name)
metadata = stage_file_data["METADATA"]
file_type = stage_file_data["format"]["snowflake_property_value"]
file_type = get_file_format(stage_file_data)
if file_type not in DUCKDB_SUPPORTED_FILE_TYPES:
raise Exception(f"DuckDB currently does not support reading from {file_type} files.")
url = metadata["URL"][0]
@@ -56,16 +63,22 @@ def convert_copy_params(params):
)
return copy_params

def get_file_format(params):
return params["format"]["duckdb_property_value"]

def apply_param_post_processing(params):
format = get_file_format(params)
params = _remove_problematic_params(params, format)
params = _add_required_params(params, format)
params = _add_empty_field_as_null_to_nullstr(params)
return params

def _add_empty_field_as_null_to_nullstr(params):
empty_field_as_null = params.get("EMPTY_FIELD_AS_NULL")
del params["EMPTY_FIELD_AS_NULL"]
if empty_field_as_null is None:
return params


del params["EMPTY_FIELD_AS_NULL"]
snowflake_value = empty_field_as_null.get("snowflake_property_value")
if snowflake_value.lower() == 'true':
nullstr = params.get("nullstr")
@@ -78,6 +91,24 @@ def _add_empty_field_as_null_to_nullstr(params):

return params

def _remove_problematic_params(params, format):
disallowed_params = DISALLOWED_PARAMS_BY_FORMAT.get(format, {})
for disallowed_param, disallowed_values in disallowed_params.items():
if params.get(disallowed_param) is not None:
if disallowed_values[0] in ("ALWAYS_REMOVE"):
del params[disallowed_param]
continue
param_current_value = params[disallowed_param]["duckdb_property_value"]
if param_current_value in disallowed_values:
del params[disallowed_param]
return params

def _add_required_params(params, format):
# required_params = REQUIRED_PARAMS_BY_FORMAT.get(format, {})
# for required_param, required_values in required_params.items():
# params[required_param] = required_values
return params

def get_stage_info(file, file_format_params, cursor):
if file.get("type") != "STAGE" and file.get("source_catalog") != "SNOWFLAKE":
raise Exception("There was an issue processing your file data.")
@@ -104,11 +135,12 @@ def get_stage_info(file, file_format_params, cursor):
return duckdb_data

def convert_to_duckdb_properties(copy_properties):
file_format = copy_properties['TYPE']['snowflake_property_value']
all_converted_properties = {}
metadata = {}

for snowflake_property_name, snowflake_property_info in copy_properties.items():
converted_properties = convert_properties(snowflake_property_name, snowflake_property_info)
converted_properties = convert_properties(file_format, snowflake_property_name, snowflake_property_info)
duckdb_property_name, property_values = next(iter(converted_properties.items()))
if property_values["duckdb_property_type"] == 'METADATA':
metadata[duckdb_property_name] = property_values["duckdb_property_value"]
@@ -124,7 +156,7 @@ def convert_to_duckdb_properties(copy_properties):
all_converted_properties["METADATA"] = metadata
return all_converted_properties

def convert_properties(snowflake_property_name, snowflake_property_info):
def convert_properties(file_format, snowflake_property_name, snowflake_property_info):
no_match = {
"duckdb_property_name": None,
"duckdb_property_type": None
@@ -136,13 +168,13 @@ def convert_properties(snowflake_property_name, snowflake_property_info):
"duckdb_property_type": duckdb_property_type
} | snowflake_property_info | {"snowflake_property_name": snowflake_property_name}
if duckdb_property_name is not None:
value = _format_value_for_duckdb(snowflake_property_name, properties)
value = _format_value_for_duckdb(file_format, snowflake_property_name, properties)
properties["duckdb_property_value"] = value
else:
properties["duckdb_property_value"] = None
return {duckdb_property_name: properties}

def _format_value_for_duckdb(snowflake_property_name, data):
def _format_value_for_duckdb(file_format, snowflake_property_name, data):
snowflake_type = data["snowflake_property_type"]
duckdb_type = data["duckdb_property_type"]
snowflake_value = data["snowflake_property_value"]
@@ -152,6 +184,8 @@ def _format_value_for_duckdb(snowflake_property_name, data):
duckdb_value.replace(snowflake_datetime_component, duckdb_datetime_component)
return duckdb_value
elif snowflake_type == 'String' and duckdb_type == 'VARCHAR':
if file_format == 'JSON' and snowflake_property_name.lower() == 'compression' and snowflake_value.lower() == 'auto':
return "auto_detect"
return _format_string_for_duckdb(snowflake_value)
elif snowflake_type == "Boolean" and duckdb_type == 'BOOL':
return snowflake_value.lower()
@@ -205,6 +239,23 @@ def get_file_path(file: sqlglot.exp.Table):
return full_string[i + 1:]
return ""

DISALLOWED_PARAMS_BY_FORMAT = {
"JSON": {
"ignore_errors": ["ALWAYS_REMOVE"],
"nullstr": ["ALWAYS_REMOVE"],
"timestampformat": ["AUTO"]
}
}

REQUIRED_PARAMS_BY_FORMAT = {
"JSON": {
"auto_detect": {
"duckdb_property_type": "BOOL",
"duckdb_property_value": "TRUE"
}
}
}

SNOWFLAKE_TO_DUCKDB_DATETIME_MAPPINGS = {
'YYYY': '%Y',
'YY': '%y',
@@ -382,6 +433,26 @@ def get_file_path(file: sqlglot.exp.Table):
"AUTO_REFRESH": {
"duckdb_property_name": None,
"duckdb_property_type": None
},
"ALLOW_DUPLICATE": { # duckdb only takes the last value
"duckdb_property_name": None,
"duckdb_property_type": None
},
"ENABLE_OCTAL": {
"duckdb_property_name": None,
"duckdb_property_type": None
},
"IGNORE_UTF8_ERRORS": {
"duckdb_property_name": None,
"duckdb_property_type": None
},
"STRIP_NULL_VALUES": { # would need to be handled after a successful copy
"duckdb_property_name": None,
"duckdb_property_type": None
},
"STRIP_OUTER_ARRAY": { # needs to use json_array_elements() after loading
"duckdb_property_name": None,
"duckdb_property_type": None
}
}

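
A hedged summary of the JSON-specific conversions these utils changes add; the property dict below is abbreviated, and real entries come from convert_to_duckdb_properties.

# Snowflake's COMPRESSION = AUTO is special-cased for JSON and becomes DuckDB's
# auto_detect instead of a quoted string.
props = {
    "snowflake_property_type": "String",
    "duckdb_property_type": "VARCHAR",
    "snowflake_property_value": "AUTO",
}
# _format_value_for_duckdb("JSON", "compression", props)  -> "auto_detect"
#
# apply_param_post_processing then prunes parameters per DISALLOWED_PARAMS_BY_FORMAT["JSON"]:
# ignore_errors and nullstr are always removed, timestampformat is removed when its DuckDB
# value is "AUTO", and _add_required_params is currently a no-op stub.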
