credentials passed to duckdb query executor
ryanwithawhy committed Jan 12, 2025
1 parent 8050295 commit d5e0e52
Showing 3 changed files with 132 additions and 6 deletions.
11 changes: 8 additions & 3 deletions universql/protocol/session.py
@@ -15,7 +15,7 @@

from universql.lake.cloud import CACHE_DIRECTORY_KEY, MAX_CACHE_SIZE
from universql.util import get_friendly_time_since, \
prepend_to_lines, parse_compute, QueryError, full_qualifier
prepend_to_lines, parse_compute, QueryError, full_qualifier, get_secrets_from_credentials_file
from universql.warehouse import Executor, Tables, ICatalog
from universql.warehouse.bigquery import BigQueryCatalog
from universql.warehouse.duckdb import DuckDBCatalog
@@ -31,8 +31,6 @@

class UniverSQLSession:
    def __init__(self, context, session_id, credentials: dict, session_parameters: dict):
        print("context INCOMING")
        pp(context)
        self.context = context
        self.credentials = credentials
        self.session_parameters = [{"name": item[0], "value": item[1]} for item in session_parameters.items()]
@@ -216,6 +214,13 @@ def perform_query(self, alternative_executor: Executor, raw_query, ast=None) ->
        if files_list is not None:
            with sentry_sdk.start_span(op=op_name, name="Get file info"):
                processed_file_data = self.catalog.get_file_info(files_list)
                for file_name, file_config in processed_file_data.items():
                    metadata = file_config["METADATA"]
                    if metadata["storage_provider"] != "Amazon S3":
                        raise Exception("Universql currently only supports Amazon S3 stages.")
                    aws_role = metadata["AWS_ROLE"]
                    secrets = get_secrets_from_credentials_file(aws_role)
                    metadata.update(secrets)

        with sentry_sdk.start_span(op=op_name, name="Get table paths"):
            table_locations = self.get_table_paths_from_catalog(alternative_executor.catalog, tables_list)
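
For reference, a minimal sketch of the structure this loop assumes get_file_info returns. The key names are inferred from how the loop above and the DuckDB executor below consume them; the stage file name, bucket, and role ARN are all placeholders:

processed_file_data = {
    "@my_stage/data.csv": {                  # hypothetical staged file
        "METADATA": {
            "storage_provider": "Amazon S3", # anything else raises above
            "AWS_ROLE": "arn:aws:iam::123456789012:role/universql-stage",
            "URL": ["s3://my-bucket/path/"], # consumed by get_region() in duckdb.py
        }
    }
}
# After the loop, each METADATA dict also carries the resolved 'profile',
# 'access_key' and 'secret_key' entries from the local credentials file.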
107 changes: 106 additions & 1 deletion universql/util.py
@@ -2,19 +2,24 @@
import gzip
import json
import os
import platform
import configparser
import re
import time
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import List, Tuple
from pprint import pp

import humanize
import psutil
import sqlglot
from starlette.exceptions import HTTPException
from starlette.requests import Request
from dotenv import load_dotenv

load_dotenv()

class Compute(Enum):
    LOCAL = "local"
@@ -185,6 +190,11 @@ class Catalog(Enum):
}
]

DEFAULT_CREDENTIALS_LOCATIONS = {
    'Darwin': "~/.aws/credentials",
    'Linux': "~/.aws/credentials",
    'Windows': r"%USERPROFILE%\.aws\credentials",
}

# parameters = [{"name": "TIMESTAMP_OUTPUT_FORMAT", "value": "YYYY-MM-DD HH24:MI:SS.FF3 TZHTZM"},
# {"name": "CLIENT_PREFETCH_THREADS", "value": 4}, {"name": "TIME_OUTPUT_FORMAT", "value": "HH24:MI:SS"},
@@ -457,4 +467,99 @@ def full_qualifier(table: sqlglot.exp.Table, credentials: dict):
    db = sqlglot.exp.Identifier(this=credentials.get('schema')) \
        if table.args.get('db') is None else table.args.get('db')
    new_table = sqlglot.exp.Table(catalog=catalog, db=db, this=table.this)
    return new_table

def get_secrets_from_credentials_file(aws_role_arn):

    # Get the credentials file location and parse the INI-format file
    creds_file = get_credentials_file_location()
    try:
        config = configparser.ConfigParser()
        config.read(creds_file)
    except Exception as e:
        raise Exception(f"Failed to read credentials file: {str(e)}")

    # Find which profile has our target role_arn
    target_profile = None
    for profile_name, profile_data in config.items():
        if profile_data.get('role_arn') == aws_role_arn:
            target_profile = profile_name
            break

    if not target_profile:
        # Fall back to the default profile if the target role is not found
        if config.has_section('default'):
            return _find_credentials_in_file('default', config, creds_file)
        raise Exception(
            f"We were unable to find credentials for {aws_role_arn} in {creds_file}. "
            "Please make sure you have this role_arn in your credentials file or have a default profile configured. "
            "You can set the environment variable AWS_SHARED_CREDENTIALS_FILE to a different credentials file location."
        )

    return _find_credentials_in_file(target_profile, config, creds_file)
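
As a usage sketch (the role ARN and returned values are hypothetical), the lookup scans every profile for a matching role_arn and resolves it through _find_credentials_in_file:

secrets = get_secrets_from_credentials_file("arn:aws:iam::123456789012:role/universql-stage")
# e.g. {'profile': 'stage-role', 'access_key': 'AKIA...', 'secret_key': '...'}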

def _find_credentials_in_file(profile_name, config, creds_file, visited=None):
    """
    Recursively find credentials, either directly in a profile
    or by following source_profile references.
    Args:
        profile_name: Name of the profile to check
        config: ConfigParser object loaded from the credentials file
        creds_file: Path to the credentials file
        visited: Set of profiles already checked (prevents infinite loops)
    """

    credentials_not_found_message = f"""The profile {profile_name} cannot be found in your credentials file located at {creds_file}.
Please update your credentials and try again."""

    # Initialize the visited-profiles set on the first call
    if visited is None:
        visited = set()

    # Check for circular source_profile references
    if profile_name in visited:
        raise Exception(
            "You have a circular dependency in your credentials file between the following profiles "
            f"that you need to correct: {', '.join(visited)}"
        )
    visited.add(profile_name)

    # Get profile data
    if not config.has_section(profile_name):
        raise Exception(credentials_not_found_message)

    # Case 1: Profile has credentials directly
    if config.has_option(profile_name, 'aws_access_key_id') and config.has_option(profile_name, 'aws_secret_access_key'):
        return {
            'profile': profile_name,
            'access_key': config.get(profile_name, 'aws_access_key_id'),
            'secret_key': config.get(profile_name, 'aws_secret_access_key')
        }

    # Case 2: Profile references another profile for credentials
    if config.has_option(profile_name, 'source_profile'):
        return _find_credentials_in_file(config.get(profile_name, 'source_profile'), config, creds_file, visited)

    # Case 3: No credentials found
    raise Exception(credentials_not_found_message)
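
A hypothetical ~/.aws/credentials file this resolver handles: the stage-role profile carries the role_arn that get_secrets_from_credentials_file matches, and its source_profile reference is followed to the profile holding the actual keys:

[default]
aws_access_key_id = AKIAEXAMPLEKEYID0000
aws_secret_access_key = exampleSecretAccessKeyExampleSecretAcce

[stage-role]
role_arn = arn:aws:iam::123456789012:role/universql-stage
source_profile = default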

def get_credentials_file_location():
    # An explicit environment variable takes precedence
    credentials_file_location = os.environ.get("AWS_SHARED_CREDENTIALS_FILE")
    if credentials_file_location is not None:
        return os.path.expandvars(os.path.expanduser(credentials_file_location))

    # Fall back to the platform default if it's not set
    operating_system = platform.system()
    credentials_file_location = DEFAULT_CREDENTIALS_LOCATIONS.get(operating_system)
    if credentials_file_location is not None:
        return os.path.expandvars(os.path.expanduser(credentials_file_location))

    raise Exception(
        "Universql is unable to determine your credentials file location. "
        "Please set the environment variable AWS_SHARED_CREDENTIALS_FILE to your credentials file location and try again."
    )
20 changes: 18 additions & 2 deletions universql/warehouse/duckdb.py
@@ -5,6 +5,7 @@
from enum import Enum
from string import Template
from typing import List
import boto3

import duckdb
import pyiceberg.table
@@ -377,10 +378,18 @@ def execute(self, ast: sqlglot.exp.Expression, tables: Tables, file_data = None)
        elif isinstance(ast, Copy):
            for file_name, file_config in file_data.items():
                urls = file_config["METADATA"]["URL"]
                try:
                    # Resolve the bucket's region up front; it is looked up here
                    # but not yet wired into the DuckDB S3 configuration.
                    region = get_region(urls[0], file_config["METADATA"]["storage_provider"])
                except Exception as e:
                    print(f"There was a problem accessing data for {file_name}:\n{e}")

            sql = self._sync_and_transform_query(ast, tables).sql(dialect="duckdb", pretty=True)
            self.execute_raw(sql)
        else:
@@ -416,3 +425,10 @@ def fix_snowflake_to_duckdb_types(
def get_iceberg_read(location: pyiceberg.table.Table) -> str:
    return sqlglot.exp.func("iceberg_scan",
                            sqlglot.exp.Literal.string(location.metadata_location)).sql()

def get_region(url, storage_provider):
    if storage_provider == 'Amazon S3':
        # Strip the "s3://" scheme and keep everything up to the first "/"
        bucket_name = url[5:].split("/")[0]
        s3 = boto3.client('s3')
        region_dict = s3.get_bucket_location(Bucket=bucket_name)
        # get_bucket_location reports a None LocationConstraint for us-east-1
        return region_dict.get('LocationConstraint') or 'us-east-1'
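
A usage sketch (bucket name hypothetical): because get_bucket_location reports None for buckets in us-east-1, the fallback above normalizes that case:

region = get_region("s3://my-bucket/path/file.csv", "Amazon S3")
# e.g. 'eu-west-1', or 'us-east-1' when LocationConstraint comes back as None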
