#!/usr/bin/env python3
import argparse
import datetime
import json
import logging
import logging.config
import pprint
import time

import boto3
import botocore
import sqlalchemy

logger = logging.getLogger(__name__)


class DateTimeEncoder(json.JSONEncoder):
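    """JSON encoder that serializes date/datetime values as ISO-8601 strings (used when dumping run metadata)."""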
    def default(self, o):
        if isinstance(o, (datetime.date, datetime.datetime)):
            return o.isoformat()
        return super().default(o)


REDSHIFT_TO_GLUE_TYPE_MAP = {
    'integer': 'int',
    'smallint': 'smallint',
    'bigint': 'bigint',
    'numeric': 'decimal',
    'real': 'float',
    'double precision': 'double',
    'character': 'string',
    'character varying': 'string',
    'timestamp without time zone': 'timestamp',
    'timestamp with time zone': 'timestamp',
    'boolean': 'boolean',
    'date': 'timestamp',  # SAD! - the Glue DATE type is not supported on non-partition cols
}


def configure_logging():
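    """Set up console logging for this module with a simple timestamped format."""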
    logging.config.dictConfig({
        'version': 1,
        'disable_existing_loggers': False,
        'loggers': {
            __name__: {
                'handlers': ['console'],
                'level': 'INFO',
                'propagate': True,
            },
        },
        'handlers': {
            'console': {
                'class': 'logging.StreamHandler',
                'level': 'INFO',
                'formatter': 'simple',
            },
        },
        'formatters': {
            'simple': {
                'datefmt': '%Y-%m-%d %H:%M:%S',
                'format': '(%(threadName)s) %(asctime)s %(levelname)-8s [%(name)s:%(lineno)s] %(message)s',
            },
        },
    })


def delete_source_rows(engine, unloaded_row_count, schema, table, row_filter_clause):
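    """Delete the unloaded rows from the source Redshift table.

    Re-counts the rows matching the filter first and refuses to delete if the count no longer matches what was
    unloaded, since that suggests rows changed between the UNLOAD and DELETE transactions.
    """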
    with engine.begin() as conn:
        count = conn.scalar('SELECT COUNT(*) FROM "{schema}"."{table}" {row_filter_clause}'.format(**locals()))
        if count != unloaded_row_count:
            raise Exception(
                'Count of rows to be deleted ({}) did not match count of rows that were unloaded ({}), suggesting '
                'that rows were altered between the UNLOAD and DELETE transactions. Are you sure you\'ve selected a '
                'proper audit column?'.format(count, unloaded_row_count)
            )
        conn.execute('DELETE FROM "{schema}"."{table}" {row_filter_clause}'.format(**locals()))


def get_opts():
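    """Parse command-line options and log them."""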
    p = argparse.ArgumentParser()

    # Resource specification
    p.add_argument('--region', required=True,
                   help='AWS region for the Redshift cluster and S3 bucket')
    p.add_argument('--s3-bucket', required=True,
                   help='S3 bucket in which unloads will be stored')
    p.add_argument('--s3-path-prefix',
                   help='Optional S3 path prefix for specifying where unloads are stored')
    p.add_argument('--cluster', required=True,
                   help='Name of the Redshift cluster being unloaded from')
    p.add_argument('--db', required=True,
                   help='Name of the database within the cluster that contains the table to be unloaded')
    p.add_argument('--schema', required=True,
                   help='Name of the schema that contains the table to be unloaded')
    p.add_argument('--table', required=True,
                   help='Name of the table to be unloaded')

    # Auth
    p.add_argument('--db-user', required=True,
                   help='Redshift user name to log in as when running UNLOADs and other queries.')
    p.add_argument('--iam-role-arn', required=True,
                   help='ARN of the AWS IAM role used in the Redshift UNLOAD command')

    # Data selection
    p.add_argument('--audit-col',
                   help='Optional name of the column used to select a subset of the table\'s records for unloading '
                        '(e.g., a last_updated timestamp or similar)')
    p.add_argument('--audit-min-val',
                   help='Optional, inclusive minimum value of the `audit_col` used to select records for unloading. '
                        'If absent, no minimum bound is applied. This will be used exactly as provided in the WHERE clause '
                        'of SQL queries, so the source DB will be responsible for interpreting e.g. its time zone.')
    p.add_argument('--audit-max-val',
                   help='Optional, exclusive maximum value of the `audit_col` used to select records for unloading. '
                        'If absent, no maximum bound is applied. This will be used exactly as provided in the WHERE clause '
                        'of SQL queries, so the source DB will be responsible for interpreting e.g. its time zone.')
    p.add_argument('--delete-on-success', action='store_true', default=False,
                   help='If specified, unloaded rows will be deleted from the SOURCE Redshift tables if the unload '
                        'finishes successfully')
    p.add_argument('--keep-on-failure', action='store_true', default=False,
                   help='If specified, unloaded rows will be kept in the TARGET S3 (i.e. Spectrum) table even if the '
                        'unload validation step fails.')

    opts = p.parse_args()
    logger.info('Running with options: \n%s', pprint.pformat(vars(opts)))
    return opts


def get_row_filter_clause(audit_col=None, audit_min_val=None, audit_max_val=None):
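    """Build the WHERE clause that selects rows by the audit column; returns '' when no bounds are given.

    Example (hypothetical column and values): passing ('last_updated', '2019-01-01', '2019-02-01') yields
        WHERE 1=1 AND "last_updated" >= '2019-01-01' AND "last_updated" < '2019-02-01'
    """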
    filter_clause = ''
    if audit_col and (audit_min_val or audit_max_val):
        filter_clause += 'WHERE 1=1 '
        if audit_min_val:
            filter_clause += 'AND "{audit_col}" >= \'{audit_min_val}\' '
        if audit_max_val:
            filter_clause += 'AND "{audit_col}" < \'{audit_max_val}\' '
    return filter_clause.format(**locals())


def get_source_schema(engine, schema, table):
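    """Return the table's column names, base types, type parameters and nullability, read from pg_table_def."""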
    with engine.begin() as conn:
        conn.execute('SET search_path TO {}'.format(schema))
        res = conn.execute(
            'SELECT "column", "type", "notnull" FROM pg_table_def WHERE "schemaname" = \'{}\' AND '
            '"tablename" = \'{}\''.format(schema, table)
        ).fetchall()
    schema = []
    for row in res:
        type_parts = row['type'].split('(')
        schema.append({
            'name': row['column'],
            'base_type': type_parts[0],
            'type_params': '(' + type_parts[1] if len(type_parts) > 1 else None,
            'nullable': not row['notnull']
        })
    return schema


def get_sqlalchemy_engine(region, cluster, db, db_user):
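    """Build a SQLAlchemy engine for the cluster using temporary credentials from GetClusterCredentials."""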
    client = boto3.client('redshift', region_name=region)
    desc_response = client.describe_clusters(ClusterIdentifier=cluster)
    creds_response = client.get_cluster_credentials(
        DbUser=db_user,
        ClusterIdentifier=cluster,
        DurationSeconds=3600
    )
    url = sqlalchemy.engine.url.URL(
        drivername='postgresql+pg8000',
        username=creds_response['DbUser'],
        password=creds_response['DbPassword'],
        host=desc_response['Clusters'][0]['Endpoint']['Address'],
        port=desc_response['Clusters'][0]['Endpoint']['Port'],
        database=db
    )
    return sqlalchemy.create_engine(url, connect_args={'ssl': True})


def main(opts):
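    """Run one unload end to end: UNLOAD to S3, register the Glue/Spectrum table, validate it, optionally delete
    the source rows, and always write the run metadata to S3 (even on failure).
    """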
    invocation_epoch = int(round(time.time() * 1000))
    unload_path = 's3://{}/'.format('/'.join(filter(None, [  # Filter to omit the s3_path_prefix if not specified
        opts.s3_bucket, opts.s3_path_prefix, opts.cluster, opts.db, opts.schema, opts.table, str(invocation_epoch)
    ])))
    row_filter_clause = get_row_filter_clause(opts.audit_col, opts.audit_min_val, opts.audit_max_val)
    sqlalchemy_engine = get_sqlalchemy_engine(opts.region, opts.cluster, opts.db, opts.db_user)
    source_schema = get_source_schema(sqlalchemy_engine, opts.schema, opts.table)
    run_metadata = {
        'invocation_epoch': invocation_epoch,
        'unload_path': unload_path,
        'source_schema': source_schema,
        'audit_col': opts.audit_col
    }
    try:
        unload_metadata = unload_data(
            sqlalchemy_engine, opts.schema, opts.table, unload_path, opts.iam_role_arn, source_schema,
            opts.audit_col, row_filter_clause
        )
        run_metadata.update(unload_metadata)
        row_count = unload_metadata['unloaded_row_count']
        register_glue_table(
            opts.region, opts.cluster, opts.db, opts.schema, opts.table, source_schema, unload_path,
            opts.iam_role_arn, sqlalchemy_engine
        )
        validate_unload(
            sqlalchemy_engine, opts.schema, opts.table, row_filter_clause, row_count, unload_path, run_metadata,
            opts.keep_on_failure
        )
        if opts.delete_on_success:
            delete_source_rows(
                sqlalchemy_engine, row_count, opts.schema, opts.table, row_filter_clause
            )
            run_metadata['source_rows_deleted'] = True
        else:
            run_metadata['source_rows_deleted'] = False
        run_metadata['success'] = True
    except Exception as e:
        run_metadata['success'] = False
        run_metadata['exception'] = str(e)
        logger.error('Exception encountered, dumping metadata:\n%s', pprint.pformat(run_metadata))
        raise
    finally:
        save_metadata(opts.region, unload_path, run_metadata)


def register_glue_table(region, cluster, db, schema, table, source_schema, unload_path, iam_role_arn, engine):
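    """Ensure the external (Spectrum) schema exists, then create or update the Glue table covering the
    unload location (the parent of the per-invocation path).
    """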
    glue_db_name = '{cluster}_{db}_{schema}_unloads'.format(**locals())
    ensure_external_db_query = r"""
        CREATE EXTERNAL SCHEMA IF NOT EXISTS {schema}_external
        FROM DATA CATALOG
        DATABASE '{glue_db_name}'
        IAM_ROLE '{iam_role_arn}'
        CREATE EXTERNAL DATABASE IF NOT EXISTS
    """.format(**locals())
    with engine.begin() as conn:
        conn.execute(ensure_external_db_query)
    glue_table_spec = {
        "DatabaseName": glue_db_name,
        "TableInput": {
            "Name": table,
            "Description": "Data unloaded from Redshift table {}.{}".format(schema, table),
            "StorageDescriptor": {
                "OutputFormat": "org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat",
                "InputFormat": "org.apache.hadoop.mapred.TextInputFormat",
                "SerdeInfo": {
                    "SerializationLibrary": "org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe",
                    "Parameters": {
                        "escape.delim": "\\",
                        "serialization.format": "|",
                        "field.delim": "|",
                        "serialization.null.format": "\\N"
                    }
                },
                "BucketColumns": [],
                "Location": unload_path.rsplit('/', 2)[0] + '/',
                "NumberOfBuckets": -1,
                "Parameters": {},
                "Columns": [
                    {
                        "Type": REDSHIFT_TO_GLUE_TYPE_MAP[col['base_type']] + (
                            col['type_params'] if col['base_type'] == 'numeric' else ''
                        ),
                        "Name": col['name']
                    }
                    for col in source_schema
                ]
            },
            "TableType": "EXTERNAL_TABLE",
            "PartitionKeys": [],
            "Parameters": {
                "compressionType": "gzip",
                "EXTERNAL": "TRUE",
            },
            "Owner": "owner",
        }
    }
    client = boto3.client('glue', region_name=region)
    try:
        client.get_table(DatabaseName=glue_db_name, Name=table)
        client.update_table(**glue_table_spec)
    except botocore.exceptions.ClientError as e:
        if e.response['Error']['Code'] == 'EntityNotFoundException':
            client.create_table(**glue_table_spec)
        else:
            raise
    logger.info(
        'Registered table "%s" in AWS Glue database "%s". Table can now be queried in Redshift as "%s_external.%s"',
        table, glue_db_name, schema, table
    )


def save_metadata(region, unload_path, run_metadata):
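    """Write the run metadata JSON next to the unloaded data in S3 (underscore-prefixed so Spectrum ignores it)."""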
    client = boto3.client('s3', region_name=region)
    _, _, bucket, key = unload_path.split('/', 3)
    # Need to start it w/an underscore to keep Spectrum from treating it as real table data:
    key += '_unload_metadata.json'
    body = json.dumps(run_metadata, indent=2, cls=DateTimeEncoder)
    client.put_object(Body=body, Bucket=bucket, Key=key)


def unload_data(engine, schema, table, unload_path, iam_role_arn, source_schema, audit_col=None,
                row_filter_clause=''):
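    """UNLOAD the selected rows to S3, rename the manifest so Spectrum ignores it, and return unload metadata
    (row count, audit min/max values, the executed query, and the parsed manifest).
    """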
    with engine.begin() as conn:
        # Capture audit stats
        audit_clause = (', MIN("{0}") AS audit_min_val'
                        ', MAX("{0}") AS audit_max_val'.format(audit_col) if audit_col else '')
        stats_query = r"""
            SELECT COUNT(*) AS unloaded_row_count {audit_clause} FROM "{schema}"."{table}" {row_filter_clause}
        """.format(**locals())
        metadata = dict(conn.execute(stats_query).fetchone())
        logger.info('Metadata for rows to be unloaded:\n%s', pprint.pformat(metadata))

        # Perform unload. We are, sadly, compelled to do our own escaping of character columns because the
        # ESCAPE option to the Redshift UNLOAD command will also escape the value given for the NULL AS option,
        # making it impossible to specify a NULL AS value that is guaranteed to always roundtrip properly. (Also
        # tools like Spectrum seem to not recognize other null values.)
        # LEFT is to get around adding chars to a varchar(max), which results in an error.
        cols = []
        escape_character_col_spec = r"""
            REPLACE(REPLACE(REPLACE(REPLACE(LEFT("{}", 50000), '\\', '\\\\'), '|', '\\|'), '\n', '\\\n'), '\r', '\\\r')
        """
        for col in source_schema:
            if col['base_type'].startswith('character'):
                cols.append(escape_character_col_spec.format(col['name']))
            else:
                cols.append(r'"{}"'.format(col['name']))
        cols = r','.join(cols)
        inner_unload_query = r'SELECT {cols} FROM "{schema}"."{table}" {row_filter_clause}'.format(**locals())
        inner_unload_query = inner_unload_query.replace('\\', '\\\\').replace("'", r"\'")
        unload_query = r"""
            UNLOAD ('{inner_unload_query}')
            TO '{unload_path}'
            IAM_ROLE '{iam_role_arn}'
            MAXFILESIZE 1 gb
            DELIMITER '|' NULL AS '\\N' MANIFEST GZIP
        """.format(**locals())
        logger.debug('Unloading with query:%s', unload_query)
        conn.execute(unload_query)
        metadata['unload_query_executed'] = unload_query

    s3 = boto3.resource('s3')
    _, _, bucket, key = unload_path.split('/', 3)
    old_key = key + 'manifest'
    new_key = key + '_manifest.json'  # underscore prevents Spectrum thinking it's table data
    manifest_source = s3.Object(bucket, old_key).get()['Body'].read()
    s3.Object(bucket, new_key).put(Body=manifest_source)
    s3.Object(bucket, old_key).delete()
    metadata['manifest'] = json.loads(manifest_source.decode('utf-8'))

    logger.info(
        '%s rows unloaded from %s.%s to %s*', metadata['unloaded_row_count'],
        schema, table, unload_path
    )
    return metadata


def validate_unload(engine, schema, table, row_filter_clause, row_count, unload_path, run_metadata,
                    keep_on_failure=False):
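    """Check that the Spectrum table matches the source: row counts and column names must match exactly, while
    column type mismatches only produce a warning. On failure, the unloaded S3 files are deleted unless
    keep_on_failure is set, and the exception is re-raised.
    """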
    try:
        with engine.begin() as conn:
            # Count check:
            external_count = conn.scalar(
                'SELECT COUNT(*) FROM "{schema}_external"."{table}" {row_filter_clause}'.format(**locals())
            )
            if external_count != row_count:
                raise Exception(
                    'Count of rows selected from the Spectrum table ({}) did not match the count of rows that were '
                    'unloaded from Redshift ({}).'.format(external_count, row_count)
                )
            # Schema check: (NB: the SQLAlchemy inspector doesn't work on Spectrum tables, hence this hackiness)
            source_schema = {col[0]: col[1] for col in conn.execute(
                'SELECT * FROM "{schema}"."{table}" {row_filter_clause} LIMIT 1'.format(**locals())
            ).cursor.description}
            external_schema = {col[0]: col[1] for col in conn.execute(
                'SELECT * FROM "{schema}_external"."{table}" {row_filter_clause} LIMIT 1'.format(**locals())
            ).cursor.description}
            if external_schema.keys() != source_schema.keys():
                raise Exception(
                    'There is a column name and/or count mismatch between the constructed Spectrum table and the '
                    'source Redshift table. The lists are:\n Spectrum:\n{}\n Source:\n{}\n'.format(
                        external_schema.keys(), source_schema.keys()
                    )
                )
            if external_schema != source_schema:
                logger.warning(
                    'There are one or more column type mismatches between the Spectrum table and the source Redshift '
                    'table. This is expected for some types that don\'t map well from Redshift to Glue, such as '
                    '"timestamp with time zone", "date" and "character", but look closely to make sure your '
                    'expectations are not violated here! The schemas are:\n Spectrum:\n%s\n Source:\n%s\n',
                    external_schema, source_schema
                )
        run_metadata['validation_succeeded'] = True
        run_metadata['unload_was_deleted'] = False
    except Exception:
        run_metadata['validation_succeeded'] = False
        if keep_on_failure:
            logger.warning(
                'Validation failed but the files unloaded to S3 are being kept since --keep-on-failure was specified. '
                'Once you have investigated, you may wish to delete unloaded data from S3 path %s', unload_path
            )
            run_metadata['unload_was_deleted'] = False
        else:
            logger.info(
                'Rolling back unload due to validation failure. Deleting data files from S3 path %s', unload_path
            )
            s3 = boto3.resource('s3')
            for entry in run_metadata['manifest']['entries']:
                path = entry['url']
                _, _, bucket, key = path.split('/', 3)
                s3.Object(bucket, key).delete()
            run_metadata['unload_was_deleted'] = True
        raise


if __name__ == '__main__':
    configure_logging()
    opts = get_opts()
    main(opts)
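
# Example invocation (hypothetical resource names and values; see get_opts() above for the full flag list):
#
#   ./unload_rs \
#       --region us-east-1 \
#       --cluster my-cluster \
#       --db analytics \
#       --schema public \
#       --table events \
#       --db-user etl_user \
#       --iam-role-arn arn:aws:iam::123456789012:role/redshift-unload \
#       --s3-bucket my-unload-bucket \
#       --s3-path-prefix redshift-unloads \
#       --audit-col last_updated \
#       --audit-max-val '2019-01-01 00:00:00' \
#       --delete-on-success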