diff --git a/universql/protocol/session.py b/universql/protocol/session.py
index 3e7cace..067a708 100644
--- a/universql/protocol/session.py
+++ b/universql/protocol/session.py
@@ -15,7 +15,7 @@
 from universql.lake.cloud import CACHE_DIRECTORY_KEY, MAX_CACHE_SIZE
 from universql.util import get_friendly_time_since, \
-    prepend_to_lines, parse_compute, QueryError, full_qualifier
+    prepend_to_lines, parse_compute, QueryError, full_qualifier, get_secrets_from_credentials_file
 from universql.warehouse import Executor, Tables, ICatalog
 from universql.warehouse.bigquery import BigQueryCatalog
 from universql.warehouse.duckdb import DuckDBCatalog
@@ -31,8 +31,6 @@
 class UniverSQLSession:

     def __init__(self, context, session_id, credentials: dict, session_parameters: dict):
-        print("context INCOMING")
-        pp(context)
         self.context = context
         self.credentials = credentials
         self.session_parameters = [{"name": item[0], "value": item[1]} for item in session_parameters.items()]
@@ -216,6 +214,13 @@ def perform_query(self, alternative_executor: Executor, raw_query, ast=None) ->
         if files_list is not None:
             with sentry_sdk.start_span(op=op_name, name="Get file info"):
                 processed_file_data = self.catalog.get_file_info(files_list)
+                for file_name, file_config in processed_file_data.items():
+                    metadata = file_config["METADATA"]
+                    if metadata["storage_provider"] != "Amazon S3":
+                        raise Exception("Universql currently only supports Amazon S3 stages.")
+                    aws_role = metadata["AWS_ROLE"]
+                    secrets = get_secrets_from_credentials_file(aws_role)
+                    metadata.update(secrets)
         with sentry_sdk.start_span(op=op_name, name="Get table paths"):
             table_locations = self.get_table_paths_from_catalog(alternative_executor.catalog, tables_list)
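
For reviewers: a minimal sketch of the `processed_file_data` shape the new loop above assumes, inferred from the keys it reads. The stage file name, role ARN, and bucket URL below are invented examples, not values from this PR:

```python
# Hypothetical illustration only -- every literal below is made up.
processed_file_data = {
    "@my_stage/data.csv": {
        "METADATA": {
            "storage_provider": "Amazon S3",  # anything else raises above
            "AWS_ROLE": "arn:aws:iam::123456789012:role/universql-stage-access",
            "URL": ["s3://my-example-bucket/data/"],
        }
    }
}

# The loop mutates METADATA in place, merging in the dict returned by
# get_secrets_from_credentials_file: 'profile', 'access_key', 'secret_key'.
```
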
diff --git a/universql/util.py b/universql/util.py
index 39b92ee..f456cfc 100644
--- a/universql/util.py
+++ b/universql/util.py
@@ -2,19 +2,24 @@
 import gzip
 import json
 import os
+import platform
+import configparser
 import re
 import time
 from dataclasses import dataclass
 from enum import Enum
 from pathlib import Path
 from typing import List, Tuple
+from pprint import pp

 import humanize
 import psutil
 import sqlglot
 from starlette.exceptions import HTTPException
 from starlette.requests import Request
+from dotenv import load_dotenv
+
+load_dotenv()

 class Compute(Enum):
     LOCAL = "local"
@@ -185,6 +190,11 @@ class Catalog(Enum):
 }
 ]

+DEFAULT_CREDENTIALS_LOCATIONS = {
+    'Darwin': "~/.aws/credentials",
+    'Linux': "~/.aws/credentials",
+    'Windows': r"%USERPROFILE%\.aws\credentials",
+}

 # parameters = [{"name": "TIMESTAMP_OUTPUT_FORMAT", "value": "YYYY-MM-DD HH24:MI:SS.FF3 TZHTZM"},
 #               {"name": "CLIENT_PREFETCH_THREADS", "value": 4}, {"name": "TIME_OUTPUT_FORMAT", "value": "HH24:MI:SS"},
@@ -457,4 +467,98 @@ def full_qualifier(table: sqlglot.exp.Table, credentials: dict):
     db = sqlglot.exp.Identifier(this=credentials.get('schema')) \
         if table.args.get('db') is None else table.args.get('db')
     new_table = sqlglot.exp.Table(catalog=catalog, db=db, this=table.this)
-    return new_table
\ No newline at end of file
+    return new_table
+
+def get_secrets_from_credentials_file(aws_role_arn):
+
+    # Get the credentials file location and parse the INI-format file
+    creds_file = get_credentials_file_location()
+    try:
+        config = configparser.ConfigParser()
+        config.read(creds_file)
+    except Exception as e:
+        raise Exception(f"Failed to read credentials file: {str(e)}")
+
+    # Find which profile has our target role_arn
+    target_profile = None
+    for profile_name, profile_data in config.items():
+        if profile_data.get('role_arn') == aws_role_arn:
+            target_profile = profile_name
+            break
+
+    if not target_profile:
+        # Try the default profile if the target role was not found
+        if config.has_section('default'):
+            return _find_credentials_in_file('default', config, creds_file)
+        raise Exception(
+            f"We were unable to find credentials for {aws_role_arn} in {creds_file}. "
+            "Please make sure you have this role_arn in your credentials file or have a default profile configured. "
+            "You can set the environment variable AWS_SHARED_CREDENTIALS_FILE to a different credentials file location."
+        )
+
+    return _find_credentials_in_file(target_profile, config, creds_file)
+
+def _find_credentials_in_file(profile_name, config, creds_file, visited=None):
+    """
+    Recursively find credentials either directly in a profile
+    or by following source_profile references.
+
+    Args:
+        profile_name: Name of the profile to check
+        config: Parsed ConfigParser for the credentials file
+        creds_file: Path to the credentials file
+        visited: Set of profiles already checked (prevents infinite loops)
+    """
+
+    credentials_not_found_message = f"""The profile {profile_name} cannot be found in your credentials file located at {creds_file}.
+    Please update your credentials and try again."""
+
+    # Initialize the set of visited profiles on the first call
+    if visited is None:
+        visited = set()
+
+    # Check for circular source_profile references
+    if profile_name in visited:
+        raise Exception(
+            "You have a circular dependency in your credentials file between the following profiles that you need to correct: "
+            f"{', '.join(visited)}"
+        )
+    visited.add(profile_name)
+
+    # Get profile data
+    if not config.has_section(profile_name):
+        raise Exception(credentials_not_found_message)
+
+    # Case 1: Profile has credentials directly
+    if config.has_option(profile_name, 'aws_access_key_id') and config.has_option(profile_name, 'aws_secret_access_key'):
+        return {
+            'profile': profile_name,
+            'access_key': config.get(profile_name, 'aws_access_key_id'),
+            'secret_key': config.get(profile_name, 'aws_secret_access_key')
+        }
+
+    # Case 2: Profile references another profile for credentials
+    if config.has_option(profile_name, 'source_profile'):
+        return _find_credentials_in_file(config.get(profile_name, 'source_profile'), config, creds_file, visited)
+
+    # Case 3: No credentials found
+    raise Exception(credentials_not_found_message)
+
+def get_credentials_file_location():
+    # Check for the environment variable first
+    credentials_file_location = os.environ.get("AWS_SHARED_CREDENTIALS_FILE")
+    if credentials_file_location is not None:
+        return os.path.expandvars(os.path.expanduser(credentials_file_location))
+
+    # Fall back to the platform default if it is not set
+    operating_system = platform.system()
+    credentials_file_location = DEFAULT_CREDENTIALS_LOCATIONS.get(operating_system)
+    if credentials_file_location is not None:
+        print("credentials_file_location INCOMING")
+        pp(os.path.expandvars(os.path.expanduser(credentials_file_location)))
+        return os.path.expandvars(os.path.expanduser(credentials_file_location))
+
+    raise Exception(
+        "Universql is unable to determine your credentials file location. "
+        "Please set the environment variable AWS_SHARED_CREDENTIALS_FILE to your credentials file location and try again."
+    )
\ No newline at end of file
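
To make the profile resolution concrete, here is a hedged usage sketch of `get_secrets_from_credentials_file` following a `source_profile` chain. The credentials file contents and the ARN are invented for illustration:

```python
# Assumed ~/.aws/credentials contents (hypothetical):
#
#   [default]
#   aws_access_key_id = AKIAEXAMPLEKEY
#   aws_secret_access_key = exampleSecret
#
#   [stage-role]
#   role_arn = arn:aws:iam::123456789012:role/universql-stage-access
#   source_profile = default
#
from universql.util import get_secrets_from_credentials_file

secrets = get_secrets_from_credentials_file(
    "arn:aws:iam::123456789012:role/universql-stage-access"
)
# [stage-role] matches the role_arn but holds no keys of its own, so
# _find_credentials_in_file follows source_profile to [default] and returns:
#   {'profile': 'default', 'access_key': 'AKIAEXAMPLEKEY',
#    'secret_key': 'exampleSecret'}
```
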
diff --git a/universql/warehouse/duckdb.py b/universql/warehouse/duckdb.py
index 13af874..80cb642 100644
--- a/universql/warehouse/duckdb.py
+++ b/universql/warehouse/duckdb.py
@@ -5,6 +5,7 @@
 from enum import Enum
 from string import Template
 from typing import List

+import boto3
 import duckdb
 import pyiceberg.table
@@ -377,10 +378,18 @@ def execute(self, ast: sqlglot.exp.Expression, tables: Tables, file_data = None)
         elif isinstance(ast, Copy):
             print("ast INCOMING")
             pp(ast)
-            # aws_role = file_data[0]
-            print("COPY INCOMING")
             print("file_data INCOMING")
             pp(file_data)
+            # aws_role = file_data[0]
+            for file_name, file_config in file_data.items():
+                urls = file_config["METADATA"]["URL"]
+                try:
+                    region = get_region(urls[0], file_config["METADATA"]["storage_provider"])
+                    print("region INCOMING")
+                    print(region)
+                except Exception as e:
+                    print(f"There was a problem accessing data for {file_name}:\n{e}")
+
             sql = self._sync_and_transform_query(ast, tables).sql(dialect="duckdb", pretty=True)
             self.execute_raw(sql)
         else:
@@ -416,3 +425,12 @@ def fix_snowflake_to_duckdb_types(
 def get_iceberg_read(location: pyiceberg.table.Table) -> str:
     return sqlglot.exp.func("iceberg_scan", sqlglot.exp.Literal.string(location.metadata_location)).sql()
+
+def get_region(url, storage_provider):
+    # Only Amazon S3 stages are supported; strip the "s3://" scheme to get the bucket name
+    if storage_provider == 'Amazon S3':
+        bucket_name = url[5:].split("/")[0]
+        s3 = boto3.client('s3')
+        region_dict = s3.get_bucket_location(Bucket=bucket_name)
+        # get_bucket_location returns a null LocationConstraint for us-east-1
+        return region_dict.get('LocationConstraint') or 'us-east-1'
\ No newline at end of file
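
Finally, a small sketch of the new `get_region` helper in isolation. The bucket name is invented, and the call needs AWS credentials with `s3:GetBucketLocation` permission on the bucket:

```python
from universql.warehouse.duckdb import get_region

# GetBucketLocation reports a null LocationConstraint for buckets in
# us-east-1, which get_region normalizes to the literal region name.
region = get_region("s3://my-example-bucket/data/file.parquet", "Amazon S3")
print(region)  # e.g. 'us-east-1' or 'eu-west-1'
```
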