From f42f18e7be655fb45c8fc0cd619eeca6c58aa74b Mon Sep 17 00:00:00 2001 From: Marty Woodlee Date: Fri, 2 Jun 2023 09:03:00 -0500 Subject: [PATCH] CDC-to-Kafka 3.3.0 (#25) * CDC-to-Kafka 3.3.0 * Get row counts without need of VIEW DATABASE STATE permissions * Fixes from testing, before upgrading schema registry impl * Add comments and fix things flagged by grammar inspection --- cdc_kafka/avro_from_sql.py | 12 +- cdc_kafka/build_startup_state.py | 391 ++++++++++++++++++ cdc_kafka/constants.py | 14 +- cdc_kafka/helpers.py | 17 + cdc_kafka/kafka.py | 15 +- cdc_kafka/main.py | 260 +----------- cdc_kafka/metric_reporting/accumulator.py | 2 +- .../metric_reporting/http_post_reporter.py | 2 +- cdc_kafka/options.py | 12 + cdc_kafka/progress_tracking.py | 6 +- cdc_kafka/replayer.py | 302 ++++++++++++++ cdc_kafka/sql_queries.py | 84 +++- cdc_kafka/sql_query_subprocess.py | 6 +- cdc_kafka/tracked_tables.py | 46 ++- requirements.txt | 12 +- 15 files changed, 864 insertions(+), 317 deletions(-) create mode 100644 cdc_kafka/build_startup_state.py create mode 100755 cdc_kafka/replayer.py diff --git a/cdc_kafka/avro_from_sql.py b/cdc_kafka/avro_from_sql.py index c1ea029..dc5929d 100644 --- a/cdc_kafka/avro_from_sql.py +++ b/cdc_kafka/avro_from_sql.py @@ -51,7 +51,7 @@ def get_cdc_metadata_fields_avro_schemas(table_fq_name: str, source_field_names: # need not be updated. We align with that by making the Avro value schema for all captured fields nullable (which also # helps with maintaining future Avro schema compatibility). def avro_schema_from_sql_type(source_field_name: str, sql_type_name: str, decimal_precision: int, - decimal_scale: int, make_nullable: bool) -> Dict[str, Any]: + decimal_scale: int, make_nullable: bool, force_avro_long: bool) -> Dict[str, Any]: if sql_type_name in ('decimal', 'numeric', 'money', 'smallmoney'): if (not decimal_precision) or decimal_scale is None: raise Exception(f"Field '{source_field_name}': For SQL decimal, money, or numeric types, the scale and " @@ -70,14 +70,16 @@ def avro_schema_from_sql_type(source_field_name: str, sql_type_name: str, decima avro_type = "double" elif sql_type_name == 'real': avro_type = "float" + elif sql_type_name in ('int', 'smallint', 'tinyint'): + avro_type = "long" if force_avro_long else "int" + # For date and time we don't respect force_avro_long since the underlying type being `int` for these logical + # types is spelled out in the Avro spec: elif sql_type_name == 'date': avro_type = {"type": "int", "logicalType": "date"} - elif sql_type_name in ('int', 'smallint', 'tinyint'): - avro_type = "int" - elif sql_type_name in ('datetime', 'datetime2', 'datetimeoffset', 'smalldatetime', 'xml') + SQL_STRING_TYPES: - avro_type = "string" elif sql_type_name == 'time': avro_type = {"type": "int", "logicalType": "time-millis"} + elif sql_type_name in ('datetime', 'datetime2', 'datetimeoffset', 'smalldatetime', 'xml') + SQL_STRING_TYPES: + avro_type = "string" elif sql_type_name == 'uniqueidentifier': avro_type = {"type": "string", "logicalType": "uuid"} elif sql_type_name in ('binary', 'image', 'varbinary', 'rowversion'): diff --git a/cdc_kafka/build_startup_state.py b/cdc_kafka/build_startup_state.py new file mode 100644 index 0000000..20bbd76 --- /dev/null +++ b/cdc_kafka/build_startup_state.py @@ -0,0 +1,391 @@ +import collections +import json +import logging +import re +import time +from typing import Dict, List, Tuple, Iterable, Union, Optional, Any, Set + +import pyodbc +from tabulate import tabulate + +from . 
import sql_query_subprocess, tracked_tables, sql_queries, kafka, progress_tracking, change_index, \ + constants, helpers, options, avro_from_sql +from .metric_reporting import accumulator + +logger = logging.getLogger(__name__) + + +def build_tracked_tables_from_cdc_metadata( + db_conn: pyodbc.Connection, metrics_accumulator: 'accumulator.Accumulator', topic_name_template: str, + snapshot_table_whitelist_regex: str, snapshot_table_blacklist_regex: str, truncate_fields: Dict[str, int], + capture_instance_names: List[str], db_row_batch_size: int, force_avro_long: bool, + sql_query_processor: 'sql_query_subprocess.SQLQueryProcessor' +) -> List[tracked_tables.TrackedTable]: + result: List[tracked_tables.TrackedTable] = [] + + truncate_fields = {k.lower(): v for k, v in truncate_fields.items()} + + snapshot_table_whitelist_regex = snapshot_table_whitelist_regex and re.compile( + snapshot_table_whitelist_regex, re.IGNORECASE) + snapshot_table_blacklist_regex = snapshot_table_blacklist_regex and re.compile( + snapshot_table_blacklist_regex, re.IGNORECASE) + + name_to_meta_fields: Dict[Tuple, List[Tuple]] = collections.defaultdict(list) + + with db_conn.cursor() as cursor: + q, p = sql_queries.get_cdc_tracked_tables_metadata(capture_instance_names) + cursor.execute(q) + for row in cursor.fetchall(): + # 0:4 gets schema name, table name, capture instance name, min captured LSN: + name_to_meta_fields[tuple(row[0:4])].append(row[4:]) + + for (schema_name, table_name, capture_instance_name, min_lsn), fields in name_to_meta_fields.items(): + fq_table_name = f'{schema_name}.{table_name}' + + can_snapshot = False + + if snapshot_table_whitelist_regex and snapshot_table_whitelist_regex.match(fq_table_name): + logger.debug('Table %s matched snapshotting whitelist', fq_table_name) + can_snapshot = True + + if snapshot_table_blacklist_regex and snapshot_table_blacklist_regex.match(fq_table_name): + logger.debug('Table %s matched snapshotting blacklist and will NOT be snapshotted', fq_table_name) + can_snapshot = False + + topic_name = topic_name_template.format( + schema_name=schema_name, table_name=table_name, capture_instance_name=capture_instance_name) + + tracked_table = tracked_tables.TrackedTable( + db_conn, metrics_accumulator, sql_query_processor, schema_name, table_name, capture_instance_name, + topic_name, min_lsn, can_snapshot, db_row_batch_size) + + for (change_table_ordinal, column_name, sql_type_name, is_computed, primary_key_ordinal, decimal_precision, + decimal_scale, is_nullable) in fields: + truncate_after = truncate_fields.get(f'{schema_name}.{table_name}.{column_name}'.lower()) + tracked_table.append_field(tracked_tables.TrackedField( + column_name, sql_type_name, change_table_ordinal, primary_key_ordinal, decimal_precision, + decimal_scale, force_avro_long, truncate_after)) + + result.append(tracked_table) + + return result + + +def determine_start_points_and_finalize_tables( + kafka_client: kafka.KafkaClient, db_conn: pyodbc.Connection, tables: Iterable[tracked_tables.TrackedTable], + progress_tracker: progress_tracking.ProgressTracker, lsn_gap_handling: str, + partition_count: int, replication_factor: int, extra_topic_config: Dict[str, Union[str, int]], + force_avro_long: bool, validation_mode: bool = False, redo_snapshot_for_new_instance: bool = False, + publish_duplicate_changes_from_new_instance: bool = False, report_progress_only: bool = False +) -> None: + topic_names: List[str] = [t.topic_name for t in tables] + + if validation_mode: + for table in tables: + 
table.snapshot_allowed = False + table.finalize_table(change_index.LOWEST_CHANGE_INDEX, None, {}, lsn_gap_handling) + return + + if report_progress_only: + watermarks_by_topic = [] + else: + watermarks_by_topic = kafka_client.get_topic_watermarks(topic_names) + first_check_watermarks_json = json.dumps(watermarks_by_topic) + + logger.info(f'Pausing for {constants.WATERMARK_STABILITY_CHECK_DELAY_SECS} seconds to ensure target topics are ' + f'not receiving new messages from elsewhere...') + time.sleep(constants.WATERMARK_STABILITY_CHECK_DELAY_SECS) + + watermarks_by_topic = kafka_client.get_topic_watermarks(topic_names) + second_check_watermarks_json = json.dumps(watermarks_by_topic) + + if first_check_watermarks_json != second_check_watermarks_json: + raise Exception(f'Watermarks for one or more target topics changed between successive checks. ' + f'Another process may be producing to the topic(s). Bailing.\nFirst check: ' + f'{first_check_watermarks_json}\nSecond check: {second_check_watermarks_json}') + logger.debug('Topic watermarks: %s', second_check_watermarks_json) + + prior_progress_log_table_data = [] + prior_progress = progress_tracker.get_prior_progress_or_create_progress_topic() + + for table in tables: + snapshot_progress, changes_progress = None, None + prior_change_table_max_index: Optional[change_index.ChangeIndex] = None + + if not report_progress_only and table.topic_name not in watermarks_by_topic: # new topic; create it + if partition_count: + this_topic_partition_count = partition_count + else: + per_second = table.get_change_rows_per_second() + # one partition for each 10 rows/sec on average in the change table: + this_topic_partition_count = max(1, int(per_second / 10)) + if this_topic_partition_count > 100: + raise Exception( + f'Automatic topic creation would create %{this_topic_partition_count} partitions for topic ' + f'{table.topic_name} based on a change table rows per second rate of {per_second}. This ' + f'seems excessive, so the program is exiting to prevent overwhelming your Kafka cluster. 
' + f'Look at setting PARTITION_COUNT to take manual control of this.') + logger.info('Creating topic %s with %s partition(s)', table.topic_name, this_topic_partition_count) + kafka_client.create_topic(table.topic_name, this_topic_partition_count, replication_factor, + extra_topic_config) + else: + snapshot_progress: Union[None, progress_tracking.ProgressEntry] = prior_progress.get( + (table.topic_name, constants.SNAPSHOT_ROWS_KIND)) + changes_progress: Union[None, progress_tracking.ProgressEntry] = prior_progress.get( + (table.topic_name, constants.CHANGE_ROWS_KIND)) + + fq_change_table_name = helpers.get_fq_change_table_name(table.capture_instance_name) + if snapshot_progress and (snapshot_progress.change_table_name != fq_change_table_name): + logger.info('Found prior snapshot progress into topic %s, but from an older capture instance ' + '(prior progress instance: %s; current instance: %s)', table.topic_name, + snapshot_progress.change_table_name, fq_change_table_name) + if redo_snapshot_for_new_instance: + old_capture_instance_name = helpers.get_capture_instance_name(snapshot_progress.change_table_name) + new_capture_instance_name = helpers.get_capture_instance_name(fq_change_table_name) + if ddl_change_requires_new_snapshot(db_conn, old_capture_instance_name, new_capture_instance_name, + table.fq_name, force_avro_long): + logger.info('Will start new snapshot.') + snapshot_progress = None + else: + progress_tracker.record_snapshot_completion(table.topic_name) + else: + progress_tracker.record_snapshot_completion(table.topic_name) + logger.info('Will NOT start new snapshot.') + + if changes_progress and (changes_progress.change_table_name != fq_change_table_name): + logger.info('Found prior change data progress into topic %s, but from an older capture instance ' + '(prior progress instance: %s; current instance: %s)', table.topic_name, + changes_progress.change_table_name, fq_change_table_name) + with db_conn.cursor() as cursor: + cursor.execute("SELECT 1 FROM sys.tables WHERE object_id = OBJECT_ID(?)", + changes_progress.change_table_name) + if cursor.fetchval() is not None: + q, p = sql_queries.get_max_lsn_for_change_table(changes_progress.change_table_name) + cursor.execute(q) + res = cursor.fetchone() + if res: + (lsn, _, seqval, operation) = res + prior_change_table_max_index = change_index.ChangeIndex(lsn, seqval, operation) + + if publish_duplicate_changes_from_new_instance: + logger.info('Will republish any change rows duplicated by the new capture instance.') + changes_progress = None + else: + logger.info('Will NOT republish any change rows duplicated by the new capture instance.') + + starting_change_index = (changes_progress and changes_progress.change_index) \ + or change_index.LOWEST_CHANGE_INDEX + starting_snapshot_index = snapshot_progress and snapshot_progress.snapshot_index + + if report_progress_only: # elide schema registration + table.finalize_table(starting_change_index, prior_change_table_max_index, starting_snapshot_index, + options.LSN_GAP_HANDLING_IGNORE) + else: + table.finalize_table(starting_change_index, prior_change_table_max_index, starting_snapshot_index, + lsn_gap_handling, kafka_client.register_schemas, progress_tracker.reset_progress, + progress_tracker.record_snapshot_completion) + + if not table.snapshot_allowed: + snapshot_state = '' + elif table.snapshot_complete: + snapshot_state = '' + elif table.last_read_key_for_snapshot_display is None: + snapshot_state = '' + else: + snapshot_state = f'From {table.last_read_key_for_snapshot_display}' + + 
prior_progress_log_table_data.append((table.capture_instance_name, table.fq_name, table.topic_name, + starting_change_index or '', snapshot_state)) + + headers = ('Capture instance name', 'Source table name', 'Topic name', 'From change table index', 'Snapshots') + table = tabulate(sorted(prior_progress_log_table_data), headers, tablefmt='fancy_grid') + + logger.info('Processing will proceed from the following positions based on the last message from each topic ' + 'and/or the snapshot progress committed in Kafka (NB: snapshot reads occur BACKWARDS from high to ' + 'low key column values):\n%s\n%s tables total.', table, len(prior_progress_log_table_data)) + + +def ddl_change_requires_new_snapshot(db_conn: pyodbc.Connection, old_capture_instance_name: str, + new_capture_instance_name: str, source_table_fq_name: str, force_avro_long: bool, + resnapshot_for_column_drops: bool = True) -> bool: + with db_conn.cursor() as cursor: + cursor.execute(f'SELECT TOP 1 1 FROM [{constants.CDC_DB_SCHEMA_NAME}].[change_tables] ' + f'WHERE capture_instance = ?', old_capture_instance_name) + if not cursor.fetchval(): + logger.info('Requiring re-snapshot for %s because prior capture instance %s is no longer available as a ' + 'basis for evaluating schema changes.', source_table_fq_name, old_capture_instance_name) + return True + + q, p = sql_queries.get_cdc_tracked_tables_metadata([old_capture_instance_name, new_capture_instance_name]) + cursor.execute(q) + old_cols: Dict[str, Dict[str, Any]] = {} + new_cols: Dict[str, Dict[str, Any]] = {} + for row in cursor.fetchall(): + (_, _, capture_instance_name, _, _, column_name, sql_type_name, is_computed, _, decimal_precision, + decimal_scale, is_nullable) = row + col_info = {'sql_type_name': sql_type_name, + 'decimal_precision': decimal_precision, + 'decimal_scale': decimal_scale, + 'is_computed': is_computed, + 'is_nullable': is_nullable} + if capture_instance_name == old_capture_instance_name: + old_cols[column_name] = col_info + elif capture_instance_name == new_capture_instance_name: + new_cols[column_name] = col_info + + added_col_names = new_cols.keys() - old_cols.keys() + removed_col_names = old_cols.keys() - new_cols.keys() + changed_col_names = {k for k in new_cols.keys() + if k in old_cols + and old_cols[k] != new_cols[k]} + logger.info('Evaluating need for new snapshot in change from capture instance %s to %s. Added cols: %s Removed ' + 'cols: %s Cols with type changes: %s ...', old_capture_instance_name, new_capture_instance_name, + added_col_names, removed_col_names, changed_col_names) + + if removed_col_names and resnapshot_for_column_drops: + logger.info('Requiring re-snapshot for %s because the new capture instance removes column(s) %s.', + source_table_fq_name, removed_col_names) + return True + + for changed_col_name in changed_col_names: + old_col = old_cols[changed_col_name] + new_col = new_cols[changed_col_name] + # Even if the DB col type changed, a resnapshot is really only needed if the corresponding Avro type + # changes. 
An example would be a column "upgrading" from SMALLINT to INT: + old_avro_type = avro_from_sql.avro_schema_from_sql_type(changed_col_name, old_col['sql_type_name'], + old_col['decimal_precision'], + old_col['decimal_scale'], True, force_avro_long) + new_avro_type = avro_from_sql.avro_schema_from_sql_type(changed_col_name, new_col['sql_type_name'], + new_col['decimal_precision'], + new_col['decimal_scale'], True, force_avro_long) + if old_col['is_computed'] != new_col['is_computed'] or old_avro_type != new_avro_type: + logger.info('Requiring re-snapshot for %s due to a data type change for column %s (type: %s, ' + 'is_computed: %s --> type: %s, is_computed: %s).', source_table_fq_name, changed_col_name, + old_col['sql_type_name'], old_col['is_computed'], new_col['sql_type_name'], + new_col['is_computed']) + return True + + for added_col_name in added_col_names: + col_info = new_cols[added_col_name] + if not col_info['is_nullable']: + logger.info('Requiring re-snapshot for %s because newly-captured column %s is marked NOT NULL', + source_table_fq_name, added_col_name) + return True + + q, p = sql_queries.get_table_rowcount_bounded(source_table_fq_name, constants.SMALL_TABLE_THRESHOLD) + cursor.execute(q) + bounded_row_count = cursor.fetchval() + logger.debug('Bounded row count for %s was: %s', source_table_fq_name, bounded_row_count) + table_is_small = bounded_row_count < constants.SMALL_TABLE_THRESHOLD + + # Gets the names of columns that appear in the first position of one or more unfiltered, non-disabled indexes: + q, p = sql_queries.get_indexed_cols() + cursor.setinputsizes(p) + cursor.execute(q, source_table_fq_name) + indexed_cols: Set[str] = {row[0] for row in cursor.fetchall()} + recently_added_cols: Optional[Set[str]] = None + + for added_col_name in added_col_names: + if table_is_small or added_col_name in indexed_cols: + cursor.execute(f"SELECT TOP 1 1 FROM {source_table_fq_name} WITH (NOLOCK) " + f"WHERE [{added_col_name}] IS NOT NULL") + if cursor.fetchval() is not None: + logger.info('Requiring re-snapshot for %s because a direct scan of newly-tracked column %s ' + 'detected non-null values.', source_table_fq_name, added_col_name) + return True + else: + logger.info('New col %s on table %s contains only NULL values per direct check.', + added_col_name, source_table_fq_name) + else: + # if we get here it means the table is large, the new column does not lead in an index, but + # the new column is nullable. 
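+                # In that case we fall back to the cdc.ddl_history entries for the old capture instance: if the
+                # column was added recently enough, we presume it still holds only NULLs and skip the re-snapshot.
+                # As a purely illustrative (hypothetical) example, the regex below is meant to pull the column
+                # name out of ddl_history commands shaped roughly like:
+                #     ALTER TABLE [dbo].[Orders] ADD [OrderGuid] UNIQUEIDENTIFIER NULL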
+ if recently_added_cols is None: + cols_with_too_old_changes: Set[str] = set() + cols_with_new_enough_changes: Set[str] = set() + q, p = sql_queries.get_ddl_history_for_capture_table() + cursor.setinputsizes(p) + cursor.execute(q, helpers.get_fq_change_table_name(old_capture_instance_name)) + alter_re = re.compile( + r'\W*alter\s+table\s+(?P[\w\.\[\]]+)\s+add\s+(?P[\w\.\[\]]+)\s+(?P.*)', + re.IGNORECASE) + for (ddl_command, age_seconds) in cursor.fetchall(): + match = alter_re.match(ddl_command) + if match and match.groupdict().get('column'): + col_name_lower = match.groupdict()['column'].lower() + if age_seconds > constants.MAX_AGE_TO_PRESUME_ADDED_COL_IS_NULL_SECONDS: + cols_with_too_old_changes.add(col_name_lower) + else: + cols_with_new_enough_changes.add(col_name_lower) + recently_added_cols = cols_with_new_enough_changes - cols_with_too_old_changes + + if added_col_name.lower() not in recently_added_cols: + logger.info('Requiring re-snapshot for %s because newly-tracked column %s appears to have been ' + 'added more than %s seconds ago.', source_table_fq_name, added_col_name, + constants.MAX_AGE_TO_PRESUME_ADDED_COL_IS_NULL_SECONDS) + return True + else: + logger.info('New col %s on table %s is ASSUMED to contain only NULL values because of the recency ' + 'of its addition.', added_col_name, source_table_fq_name) + + logger.info('Not requiring re-snapshot for table %s.', source_table_fq_name) + return False + + +# This pulls the "greatest" capture instance running for each source table, in the event there is more than one. +def get_latest_capture_instances_by_fq_name( + db_conn: pyodbc.Connection, capture_instance_version_strategy: str, capture_instance_version_regex: str, + table_whitelist_regex: str, table_blacklist_regex: str +) -> Dict[str, Dict[str, Any]]: + if capture_instance_version_strategy == options.CAPTURE_INSTANCE_VERSION_STRATEGY_REGEX \ + and not capture_instance_version_regex: + raise Exception('Please provide a capture_instance_version_regex when specifying the `regex` ' + 'capture_instance_version_strategy.') + result: Dict[str, Dict[str, Any]] = {} + fq_name_to_capture_instances: Dict[str, List[Dict[str, Any]]] = collections.defaultdict(list) + capture_instance_version_regex = capture_instance_version_regex and re.compile(capture_instance_version_regex) + table_whitelist_regex = table_whitelist_regex and re.compile(table_whitelist_regex, re.IGNORECASE) + table_blacklist_regex = table_blacklist_regex and re.compile(table_blacklist_regex, re.IGNORECASE) + + with db_conn.cursor() as cursor: + q, p = sql_queries.get_cdc_capture_instances_metadata() + cursor.execute(q) + for row in cursor.fetchall(): + fq_table_name = f'{row[0]}.{row[1]}' + + if table_whitelist_regex and not table_whitelist_regex.match(fq_table_name): + logger.debug('Table %s excluded by whitelist', fq_table_name) + continue + + if table_blacklist_regex and table_blacklist_regex.match(fq_table_name): + logger.debug('Table %s excluded by blacklist', fq_table_name) + continue + + if row[3] is None or row[4] is None: + logger.debug('Capture instance for %s appears to be brand-new; will evaluate again on ' + 'next pass', fq_table_name) + continue + + as_dict = { + 'fq_name': fq_table_name, + 'capture_instance_name': row[2], + 'start_lsn': row[3], + 'create_date': row[4], + } + if capture_instance_version_regex: + match = capture_instance_version_regex.match(row[1]) + as_dict['regex_matched_group'] = match and match.group(1) or '' + fq_name_to_capture_instances[as_dict['fq_name']].append(as_dict) + + for 
fq_name, capture_instances in fq_name_to_capture_instances.items(): + if capture_instance_version_strategy == options.CAPTURE_INSTANCE_VERSION_STRATEGY_CREATE_DATE: + latest_instance = sorted(capture_instances, key=lambda x: x['create_date'])[-1] + elif capture_instance_version_strategy == options.CAPTURE_INSTANCE_VERSION_STRATEGY_REGEX: + latest_instance = sorted(capture_instances, key=lambda x: x['regex_matched_group'])[-1] + else: + raise Exception(f'Capture instance version strategy "{capture_instance_version_strategy}" not recognized.') + result[fq_name] = latest_instance + + logger.debug('Latest capture instance names determined by "%s" strategy: %s', capture_instance_version_strategy, + sorted([v['capture_instance_name'] for v in result.values()])) + + return result diff --git a/cdc_kafka/constants.py b/cdc_kafka/constants.py index 1a3a187..a08ac42 100644 --- a/cdc_kafka/constants.py +++ b/cdc_kafka/constants.py @@ -10,6 +10,9 @@ SLOW_TABLE_PROGRESS_HEARTBEAT_INTERVAL = datetime.timedelta(minutes=3) DB_CLOCK_SYNC_INTERVAL = datetime.timedelta(minutes=5) +SMALL_TABLE_THRESHOLD = 5_000_000 +MAX_AGE_TO_PRESUME_ADDED_COL_IS_NULL_SECONDS = 3600 + SQL_QUERY_TIMEOUT_SECONDS = 30 SQL_QUERY_INTER_RETRY_INTERVAL_SECONDS = 1 SQL_QUERY_RETRIES = 2 @@ -17,7 +20,7 @@ WATERMARK_STABILITY_CHECK_DELAY_SECS = 10 KAFKA_REQUEST_TIMEOUT_SECS = 15 KAFKA_FULL_FLUSH_TIMEOUT_SECS = 30 -KAFKA_CONFIG_RELOAD_DELAY_SECS = 5 +KAFKA_CONFIG_RELOAD_DELAY_SECS = 3 # General @@ -69,19 +72,10 @@ # Metadata column names and positions -OPERATION_POS = 0 OPERATION_NAME = '__operation' - -EVENT_TIME_POS = 1 EVENT_TIME_NAME = '__event_time' - -LSN_POS = 2 LSN_NAME = '__log_lsn' - -SEQVAL_POS = 3 SEQVAL_NAME = '__log_seqval' - -UPDATED_FIELDS_POS = 4 UPDATED_FIELDS_NAME = '__updated_fields' DB_LSN_COL_NAME = '__$start_lsn' diff --git a/cdc_kafka/helpers.py b/cdc_kafka/helpers.py index c8c2afe..d531f26 100644 --- a/cdc_kafka/helpers.py +++ b/cdc_kafka/helpers.py @@ -2,8 +2,25 @@ import confluent_kafka +from . import constants + # Helper function for loggers working with Kafka messages def format_coordinates(msg: confluent_kafka.Message) -> str: return f'{msg.topic()}:{msg.partition()}@{msg.offset()}, ' \ f'time {datetime.datetime.fromtimestamp(msg.timestamp()[1] / 1000)}' + + +def get_fq_change_table_name(capture_instance_name: str) -> str: + assert '.' 
not in capture_instance_name + capture_instance_name = capture_instance_name.strip(' []') + return f'{constants.CDC_DB_SCHEMA_NAME}.{capture_instance_name}_CT' + + +def get_capture_instance_name(change_table_name: str) -> str: + change_table_name = change_table_name.replace('[', '') + change_table_name = change_table_name.replace(']', '') + if change_table_name.startswith(constants.CDC_DB_SCHEMA_NAME + '.'): + change_table_name = change_table_name.replace(constants.CDC_DB_SCHEMA_NAME + '.', '') + assert change_table_name.endswith('_CT') + return change_table_name[:-3] diff --git a/cdc_kafka/kafka.py b/cdc_kafka/kafka.py index adc1a8a..745d0c4 100644 --- a/cdc_kafka/kafka.py +++ b/cdc_kafka/kafka.py @@ -78,7 +78,6 @@ def __init__(self, metrics_accumulator: 'accumulator.AccumulatorAbstract', boots self._avro_decoders: Dict[int, Callable] = dict() self._schema_ids_to_names: Dict[int, str] = dict() self._delivery_callbacks: Dict[str, List[Callable]] = collections.defaultdict(list) - self._delivery_callbacks_finalized: bool = False self._global_produce_sequence_nbr: int = 0 self._cluster_metadata: Optional[confluent_kafka.admin.ClusterMetadata] = None self._last_full_flush_time: datetime.datetime = datetime.datetime.utcnow() @@ -92,7 +91,7 @@ def __init__(self, metrics_accumulator: 'accumulator.AccumulatorAbstract', boots def __enter__(self) -> 'KafkaClient': return self - def __exit__(self, exc_type, value, traceback) -> None: + def __exit__(self, *args) -> None: logger.info("Cleaning up Kafka resources...") self._consumer.close() self.flush(final=True) @@ -349,6 +348,16 @@ def register_schemas(self, topic_name: str, key_schema: Dict[str, Any], value_sc key_schema_compatibility_level: str = constants.DEFAULT_KEY_SCHEMA_COMPATIBILITY_LEVEL, value_schema_compatibility_level: str = constants.DEFAULT_VALUE_SCHEMA_COMPATIBILITY_LEVEL) \ -> Tuple[int, int]: + # TODO: it turns out that if you try to re-register a schema that was previously registered but later superseded + # (e.g. in the case of adding and then later deleting a column), the schema registry will accept that and return + # you the previously-registered schema ID without updating the `latest` version associated with the registry + # subject, or verifying that the change is Avro-compatible. It seems like the way to handle this, per + # https://github.com/confluentinc/schema-registry/issues/1685, would be to detect the condition and delete the + # subject-version-number of that schema before re-registering it. Since subject-version deletion is not + # available in the `CachedSchemaRegistryClient` we use here--and since this is a rare case--I'm explicitly + # choosing to punt on it for the moment. The Confluent lib does now have a newer `SchemaRegistryClient` class + # which supports subject-version deletion, but changing this code to use it appears to be a non-trivial task. 
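+        # A rough sketch of how that detection might look with the newer client (illustration only; `sr_client`
+        # here stands in for a confluent_kafka.schema_registry.SchemaRegistryClient instance, and the exact
+        # calls would need to be verified against that API before use):
+        #     registered = sr_client.lookup_schema(subject, schema)      # raises if never registered for subject
+        #     latest = sr_client.get_latest_version(subject)
+        #     if registered.version != latest.version:
+        #         sr_client.delete_version(subject, registered.version)  # then re-register as a new latest version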
+ key_schema = confluent_kafka.avro.loads(json.dumps(key_schema)) value_schema = confluent_kafka.avro.loads(json.dumps(value_schema)) @@ -359,6 +368,7 @@ def register_schemas(self, topic_name: str, key_schema: Dict[str, Any], value_sc if (current_key_schema is None or current_key_schema != key_schema) and not self._disable_writing: logger.info('Key schema for subject %s does not exist or is outdated; registering now.', key_subject) key_schema_id = self._schema_registry.register(key_subject, key_schema) + logger.debug('Schema registered for subject %s: %s', key_subject, key_schema) if current_key_schema is None: time.sleep(constants.KAFKA_CONFIG_RELOAD_DELAY_SECS) self._schema_registry.update_compatibility(key_schema_compatibility_level, key_subject) @@ -368,6 +378,7 @@ def register_schemas(self, topic_name: str, key_schema: Dict[str, Any], value_sc if (current_value_schema is None or current_value_schema != value_schema) and not self._disable_writing: logger.info('Value schema for subject %s does not exist or is outdated; registering now.', value_subject) value_schema_id = self._schema_registry.register(value_subject, value_schema) + logger.debug('Schema registered for subject %s: %s', value_subject, value_schema) if current_value_schema is None: time.sleep(constants.KAFKA_CONFIG_RELOAD_DELAY_SECS) self._schema_registry.update_compatibility(value_schema_compatibility_level, value_subject) diff --git a/cdc_kafka/main.py b/cdc_kafka/main.py index 9deb892..6c72a7a 100644 --- a/cdc_kafka/main.py +++ b/cdc_kafka/main.py @@ -6,13 +6,14 @@ import logging import re import time -from typing import Dict, Optional, List, Any, Iterable, Union, Tuple +from typing import Dict, Optional, List, Any, Tuple import pyodbc -from tabulate import tabulate from . import clock_sync, kafka, tracked_tables, constants, options, validation, change_index, progress_tracking, \ - sql_query_subprocess, sql_queries + sql_query_subprocess, sql_queries, helpers +from .build_startup_state import build_tracked_tables_from_cdc_metadata, determine_start_points_and_finalize_tables, \ + get_latest_capture_instances_by_fq_name from .metric_reporting import accumulator from typing import TYPE_CHECKING @@ -59,14 +60,14 @@ def run() -> None: for ci in capture_instances_by_fq_name.values()] tables: List[tracked_tables.TrackedTable] = build_tracked_tables_from_cdc_metadata( - db_conn, clock_syncer, metrics_accumulator, opts.topic_name_template, opts.snapshot_table_whitelist_regex, + db_conn, metrics_accumulator, opts.topic_name_template, opts.snapshot_table_whitelist_regex, opts.snapshot_table_blacklist_regex, opts.truncate_fields, capture_instance_names, opts.db_row_batch_size, - sql_query_processor) + opts.always_use_avro_longs, sql_query_processor) topic_to_source_table_map: Dict[str, str] = { t.topic_name: t.fq_name for t in tables} topic_to_change_table_map: Dict[str, str] = { - t.topic_name: f'{constants.CDC_DB_SCHEMA_NAME}.{t.change_table_name}' for t in tables} + t.topic_name: helpers.get_fq_change_table_name(t.capture_instance_name) for t in tables} capture_instance_to_topic_map: Dict[str, str] = { t.capture_instance_name: t.topic_name for t in tables} @@ -87,9 +88,9 @@ def run() -> None: ), metrics_accumulator.kafka_delivery_callback) determine_start_points_and_finalize_tables( - kafka_client, tables, progress_tracker, opts.lsn_gap_handling, opts.partition_count, - opts.replication_factor, opts.extra_topic_config, opts.run_validations, redo_snapshot_for_new_instance, - publish_duplicate_changes_from_new_instance, 
opts.report_progress_only) + kafka_client, db_conn, tables, progress_tracker, opts.lsn_gap_handling, opts.partition_count, + opts.replication_factor, opts.extra_topic_config, opts.always_use_avro_longs, opts.run_validations, + redo_snapshot_for_new_instance, publish_duplicate_changes_from_new_instance, opts.report_progress_only) if opts.report_progress_only: exit(0) @@ -292,243 +293,6 @@ def poll_periodic_tasks() -> bool: logger.info('Exiting due to external interrupt.') -# This pulls the "greatest" capture instance running for each source table, in the event there is more than one. -def get_latest_capture_instances_by_fq_name( - db_conn: pyodbc.Connection, capture_instance_version_strategy: str, capture_instance_version_regex: str, - table_whitelist_regex: str, table_blacklist_regex: str -) -> Dict[str, Dict[str, Any]]: - if capture_instance_version_strategy == options.CAPTURE_INSTANCE_VERSION_STRATEGY_REGEX \ - and not capture_instance_version_regex: - raise Exception('Please provide a capture_instance_version_regex when specifying the `regex` ' - 'capture_instance_version_strategy.') - result: Dict[str, Dict[str, Any]] = {} - fq_name_to_capture_instances: Dict[str, List[Dict[str, Any]]] = collections.defaultdict(list) - capture_instance_version_regex = capture_instance_version_regex and re.compile(capture_instance_version_regex) - table_whitelist_regex = table_whitelist_regex and re.compile(table_whitelist_regex, re.IGNORECASE) - table_blacklist_regex = table_blacklist_regex and re.compile(table_blacklist_regex, re.IGNORECASE) - - with db_conn.cursor() as cursor: - q, p = sql_queries.get_cdc_capture_instances_metadata() - cursor.execute(q) - for row in cursor.fetchall(): - fq_table_name = f'{row[0]}.{row[1]}' - - if table_whitelist_regex and not table_whitelist_regex.match(fq_table_name): - logger.debug('Table %s excluded by whitelist', fq_table_name) - continue - - if table_blacklist_regex and table_blacklist_regex.match(fq_table_name): - logger.debug('Table %s excluded by blacklist', fq_table_name) - continue - - if row[3] is None or row[4] is None: - logger.debug('Capture instance for %s appears to be brand-new; will evaluate again on ' - 'next pass', fq_table_name) - continue - - as_dict = { - 'fq_name': fq_table_name, - 'capture_instance_name': row[2], - 'start_lsn': row[3], - 'create_date': row[4], - } - if capture_instance_version_regex: - match = capture_instance_version_regex.match(row[1]) - as_dict['regex_matched_group'] = match and match.group(1) or '' - fq_name_to_capture_instances[as_dict['fq_name']].append(as_dict) - - for fq_name, capture_instances in fq_name_to_capture_instances.items(): - if capture_instance_version_strategy == options.CAPTURE_INSTANCE_VERSION_STRATEGY_CREATE_DATE: - latest_instance = sorted(capture_instances, key=lambda x: x['create_date'])[-1] - elif capture_instance_version_strategy == options.CAPTURE_INSTANCE_VERSION_STRATEGY_REGEX: - latest_instance = sorted(capture_instances, key=lambda x: x['regex_matched_group'])[-1] - else: - raise Exception(f'Capture instance version strategy "{capture_instance_version_strategy}" not recognized.') - result[fq_name] = latest_instance - - logger.debug('Latest capture instance names determined by "%s" strategy: %s', capture_instance_version_strategy, - sorted([v['capture_instance_name'] for v in result.values()])) - - return result - - -def build_tracked_tables_from_cdc_metadata( - db_conn: pyodbc.Connection, clock_syncer: 'clock_sync.ClockSync', metrics_accumulator: 'accumulator.Accumulator', - 
topic_name_template: str, snapshot_table_whitelist_regex: str, snapshot_table_blacklist_regex: str, - truncate_fields: Dict[str, int], capture_instance_names: List[str], db_row_batch_size: int, - sql_query_processor: 'sql_query_subprocess.SQLQueryProcessor' -) -> List[tracked_tables.TrackedTable]: - result: List[tracked_tables.TrackedTable] = [] - - truncate_fields = {k.lower(): v for k, v in truncate_fields.items()} - - snapshot_table_whitelist_regex = snapshot_table_whitelist_regex and re.compile( - snapshot_table_whitelist_regex, re.IGNORECASE) - snapshot_table_blacklist_regex = snapshot_table_blacklist_regex and re.compile( - snapshot_table_blacklist_regex, re.IGNORECASE) - - name_to_meta_fields: Dict[Tuple, List[Tuple]] = collections.defaultdict(list) - - with db_conn.cursor() as cursor: - q, p = sql_queries.get_cdc_tracked_tables_metadata(capture_instance_names) - cursor.execute(q) - for row in cursor.fetchall(): - # 0:4 gets schema name, table name, capture instance name, min captured LSN: - name_to_meta_fields[tuple(row[0:4])].append(row[4:]) - - for (schema_name, table_name, capture_instance_name, min_lsn), fields in name_to_meta_fields.items(): - fq_table_name = f'{schema_name}.{table_name}' - - can_snapshot = False - - if snapshot_table_whitelist_regex and snapshot_table_whitelist_regex.match(fq_table_name): - logger.debug('Table %s matched snapshotting whitelist', fq_table_name) - can_snapshot = True - - if snapshot_table_blacklist_regex and snapshot_table_blacklist_regex.match(fq_table_name): - logger.debug('Table %s matched snapshotting blacklist and will NOT be snapshotted', fq_table_name) - can_snapshot = False - - topic_name = topic_name_template.format( - schema_name=schema_name, table_name=table_name, capture_instance_name=capture_instance_name) - - tracked_table = tracked_tables.TrackedTable( - db_conn, clock_syncer, metrics_accumulator, sql_query_processor, schema_name, table_name, - capture_instance_name, topic_name, min_lsn, can_snapshot, db_row_batch_size) - - for (change_table_ordinal, column_name, sql_type_name, primary_key_ordinal, decimal_precision, - decimal_scale, is_identity) in fields: - truncate_after = truncate_fields.get(f'{schema_name}.{table_name}.{column_name}'.lower()) - tracked_table.append_field(tracked_tables.TrackedField( - column_name, sql_type_name, change_table_ordinal, primary_key_ordinal, decimal_precision, - decimal_scale, is_identity, truncate_after)) - - result.append(tracked_table) - - return result - - -def determine_start_points_and_finalize_tables( - kafka_client: kafka.KafkaClient, tables: Iterable[tracked_tables.TrackedTable], - progress_tracker: progress_tracking.ProgressTracker, lsn_gap_handling: str, - partition_count: int, replication_factor: int, extra_topic_config: Dict[str, Union[str, int]], - validation_mode: bool = False, redo_snapshot_for_new_instance: bool = False, - publish_duplicate_changes_from_new_instance: bool = False, report_progress_only: bool = False -) -> None: - topic_names: List[str] = [t.topic_name for t in tables] - - if validation_mode: - for table in tables: - table.snapshot_allowed = False - table.finalize_table(change_index.LOWEST_CHANGE_INDEX, {}, lsn_gap_handling) - return - - if report_progress_only: - watermarks_by_topic = [] - else: - watermarks_by_topic = kafka_client.get_topic_watermarks(topic_names) - first_check_watermarks_json = json.dumps(watermarks_by_topic) - - logger.info('Pausing briefly to ensure target topics are not receiving new messages from elsewhere...') - 
time.sleep(constants.WATERMARK_STABILITY_CHECK_DELAY_SECS) - - watermarks_by_topic = kafka_client.get_topic_watermarks(topic_names) - second_check_watermarks_json = json.dumps(watermarks_by_topic) - - if first_check_watermarks_json != second_check_watermarks_json: - raise Exception(f'Watermarks for one or more target topics changed between successive checks. ' - f'Another process may be producing to the topic(s). Bailing.\nFirst check: ' - f'{first_check_watermarks_json}\nSecond check: {second_check_watermarks_json}') - logger.debug('Topic watermarks: %s', second_check_watermarks_json) - - prior_progress_log_table_data = [] - prior_progress = progress_tracker.get_prior_progress_or_create_progress_topic() - - for table in tables: - snapshot_progress, changes_progress = None, None - - if not report_progress_only and table.topic_name not in watermarks_by_topic: # new topic; create it - if partition_count: - this_topic_partition_count = partition_count - else: - per_second = table.get_change_rows_per_second() - # one partition for each 10 rows/sec on average in the change table: - this_topic_partition_count = max(1, int(per_second / 10)) - if this_topic_partition_count > 100: - raise Exception( - f'Automatic topic creation would create %{this_topic_partition_count} partitions for topic ' - f'{table.topic_name} based on a change table rows per second rate of {per_second}. This ' - f'seems excessive, so the program is exiting to prevent overwhelming your Kafka cluster. ' - f'Look at setting PARTITION_COUNT to take manual control of this.') - logger.info('Creating topic %s with %s partition(s)', table.topic_name, this_topic_partition_count) - kafka_client.create_topic(table.topic_name, this_topic_partition_count, replication_factor, - extra_topic_config) - else: - snapshot_progress: Union[None, progress_tracking.ProgressEntry] = prior_progress.get( - (table.topic_name, constants.SNAPSHOT_ROWS_KIND)) - changes_progress: Union[None, progress_tracking.ProgressEntry] = prior_progress.get( - (table.topic_name, constants.CHANGE_ROWS_KIND)) - - fq_change_table_name = f'{constants.CDC_DB_SCHEMA_NAME}.{table.change_table_name}' - if snapshot_progress and (snapshot_progress.change_table_name != fq_change_table_name): - logger.info('Found prior snapshot progress into topic %s, but from an older capture instance ' - '(prior progress instance: %s; current instance: %s)', table.topic_name, - snapshot_progress.change_table_name, fq_change_table_name) - if redo_snapshot_for_new_instance: - if ddl_change_requires_new_snapshot(snapshot_progress.change_table_name, fq_change_table_name): - logger.info('Will start new snapshot.') - else: - logger.info('New snapshot does not appear to be required.') - snapshot_progress = None - else: - logger.info('Will NOT start new snapshot.') - - if changes_progress and (changes_progress.change_table_name != fq_change_table_name): - logger.info('Found prior change data progress into topic %s, but from an older capture instance ' - '(prior progress instance: %s; current instance: %s)', table.topic_name, - changes_progress.change_table_name, fq_change_table_name) - if publish_duplicate_changes_from_new_instance: - logger.info('Will republish any change rows duplicated by the new capture instance.') - changes_progress = None - else: - logger.info('Will NOT republish any change rows duplicated by the new capture instance.') - - starting_change_index = (changes_progress and changes_progress.change_index) \ - or change_index.LOWEST_CHANGE_INDEX - starting_snapshot_index = 
snapshot_progress and snapshot_progress.snapshot_index - - if report_progress_only: # elide schema registration - table.finalize_table(starting_change_index, starting_snapshot_index, options.LSN_GAP_HANDLING_IGNORE) - else: - table.finalize_table(starting_change_index, starting_snapshot_index, lsn_gap_handling, - kafka_client.register_schemas, progress_tracker.reset_progress, - progress_tracker.record_snapshot_completion) - - if not table.snapshot_allowed: - snapshot_state = '' - elif table.snapshot_complete: - snapshot_state = '' - elif table.last_read_key_for_snapshot_display is None: - snapshot_state = '' - else: - snapshot_state = f'From {table.last_read_key_for_snapshot_display}' - - prior_progress_log_table_data.append((table.capture_instance_name, table.fq_name, table.topic_name, - starting_change_index or '', snapshot_state)) - - headers = ('Capture instance name', 'Source table name', 'Topic name', 'From change table index', 'Snapshots') - table = tabulate(sorted(prior_progress_log_table_data), headers, tablefmt='fancy_grid') - - logger.info('Processing will proceed from the following positions based on the last message from each topic ' - 'and/or the snapshot progress committed in Kafka (NB: snapshot reads occur BACKWARDS from high to ' - 'low key column values):\n%s\n%s tables total.', table, len(prior_progress_log_table_data)) - - -# noinspection PyUnusedLocal -def ddl_change_requires_new_snapshot(old_ci_name: str, new_ci_name: str) -> bool: - return True # TODO: this is a stub for a next-up planned feature - - def should_terminate_due_to_capture_instance_change( db_conn: pyodbc.Connection, progress_tracker: progress_tracking.ProgressTracker, capture_instance_version_strategy: str, capture_instance_version_regex: str, @@ -571,8 +335,8 @@ def better_json_serialize(obj): new_ci_min_index = change_index.ChangeIndex(new_ci['start_lsn'], b'\x00' * 10, 0) if current_idx < new_ci_min_index: with db_conn.cursor() as cursor: - ci_table_name = f"[{constants.CDC_DB_SCHEMA_NAME}].[{current_ci['capture_instance_name']}_CT]" - cursor.execute(f"SELECT TOP 1 1 FROM {ci_table_name} WITH (NOLOCK)") + change_table_name = helpers.get_fq_change_table_name(current_ci['capture_instance_name']) + cursor.execute(f"SELECT TOP 1 1 FROM {change_table_name} WITH (NOLOCK)") has_rows = cursor.fetchval() is not None if has_rows: logger.info('Progress against existing capture instance ("%s") for table "%s" has reached index ' diff --git a/cdc_kafka/metric_reporting/accumulator.py b/cdc_kafka/metric_reporting/accumulator.py index 46c4ef2..c69c7f3 100644 --- a/cdc_kafka/metric_reporting/accumulator.py +++ b/cdc_kafka/metric_reporting/accumulator.py @@ -18,7 +18,7 @@ def reset_and_start(self) -> None: pass def end_and_get_values(self) -> metrics.Metrics: pass def register_sleep(self, sleep_time_seconds: float) -> None: pass def register_db_query(self, seconds_elapsed: float, db_query_kind: str, retrieved_row_count: int) -> None: pass - def register_kafka_produce(self, secs_elapsed: float, orig_value: Dict[str, Any], message_type: str) -> None: pass + def register_kafka_produce(self, seconds_elapsed: float, original_value: Dict[str, Any], message_type: str) -> None: pass def register_kafka_commit(self, seconds_elapsed: float) -> None: pass def kafka_delivery_callback(self, message_type: str, original_value: Dict[str, Any], produce_datetime: datetime.datetime, **_) -> None: pass diff --git a/cdc_kafka/metric_reporting/http_post_reporter.py b/cdc_kafka/metric_reporting/http_post_reporter.py index 9550a91..de47430 
100644 --- a/cdc_kafka/metric_reporting/http_post_reporter.py +++ b/cdc_kafka/metric_reporting/http_post_reporter.py @@ -34,8 +34,8 @@ def _post(self, metrics_dict: Dict[str, Any]) -> None: else: body = json.dumps(metrics_dict, default=HttpPostReporter.json_serialize_datetimes) - resp = requests.post(self._url, data=body, headers=self._headers, timeout=10.0) try: + resp = requests.post(self._url, data=body, headers=self._headers, timeout=10.0) resp.raise_for_status() logger.debug('Posted metrics to %s with code %s and response: %s', self._url, resp.status_code, resp.text) except requests.exceptions.RequestException as e: diff --git a/cdc_kafka/options.py b/cdc_kafka/options.py index 9407d9e..74d691a 100644 --- a/cdc_kafka/options.py +++ b/cdc_kafka/options.py @@ -238,6 +238,18 @@ def get_options_and_metrics_reporters() -> Tuple[argparse.Namespace, List]: "then exits without changing any state. Can be handy for validating other configuration such " "as the regexes used to control which tables are followed and/or snapshotted.") + p.add_argument('--always-use-avro-longs', + type=str2bool, nargs='?', const=True, + default=str2bool(os.environ.get('ALWAYS_USE_AVRO_LONGS', False)), + help="Defaults to False. If set to True, Avro schemas produced/registered by this process will " + "use the Avro `long` type instead of the `int` type for fields corresponding to SQL Server " + "INT, SMALLINT, or TINYINT columns. This can be used to future-proof in cases where the column " + "size may need to be upgraded in the future, at the potential cost of increased storage or " + "memory space needs in consuming processes. Note that if this change is made for existing " + "topics, the schema registration attempt will violate Avro FORWARD compatibility checks (the " + "default used by this process), meaning that you may need to manually override the schema " + "registry compatibility level for any such topics first.") + p.add_argument('--db-row-batch-size', type=int, default=os.environ.get('DB_ROW_BATCH_SIZE', 2000), diff --git a/cdc_kafka/progress_tracking.py b/cdc_kafka/progress_tracking.py index b6bee96..80acc6b 100644 --- a/cdc_kafka/progress_tracking.py +++ b/cdc_kafka/progress_tracking.py @@ -206,7 +206,7 @@ def __init__(self, kafka_client: 'KafkaClient', progress_topic_name: str, def __enter__(self) -> 'ProgressTracker': return self - def __exit__(self, exc_type, value, traceback) -> None: + def __exit__(self, *args) -> None: logger.info("Committing final progress...") self.commit_progress(final=True) logger.info("Done.") @@ -242,10 +242,11 @@ def emit_changes_progress_heartbeat(self, topic_name: str, index: ChangeIndex) - def record_snapshot_completion(self, topic_name: str) -> None: self.commit_progress(final=True) + source_table_name = self._topic_to_source_table_map[topic_name] progress_entry = ProgressEntry( progress_kind=constants.SNAPSHOT_ROWS_KIND, topic_name=topic_name, - source_table_name=self._topic_to_source_table_map[topic_name], + source_table_name=source_table_name, change_table_name=self._topic_to_change_table_map[topic_name], last_ack_partition=None, last_ack_offset=None, @@ -253,6 +254,7 @@ def record_snapshot_completion(self, topic_name: str) -> None: change_index=None ) + logger.info('Recording snapshot completion for table %s into topic %s', source_table_name, topic_name) self._kafka_client.produce(self._progress_topic_name, progress_entry.key, self._progress_key_schema_id, progress_entry.value, self._progress_value_schema_id, constants.SNAPSHOT_PROGRESS_MESSAGE) diff --git 
a/cdc_kafka/replayer.py b/cdc_kafka/replayer.py new file mode 100755 index 0000000..fd4cce7 --- /dev/null +++ b/cdc_kafka/replayer.py @@ -0,0 +1,302 @@ +#!/usr/bin/env python3 + +""" +A tool for populating a table that has been pre-created in a SQL Server DB based on data from a topic that was +produced by CDC-to-Kafka. + +One example use case would be to create a copy of an existing table (initially with a different name of course) on the +same DB from which a CDC-to-Kafka topic is produced, in order to rearrange indexes, types, etc. on the new copy of the +table. We're using this now to upgrade some INT columns to `BIGINT`s on a few tables that are nearing the 2^31 row +count. When the copy is ready, tables can be renamed so that applications begin using the new table. + +An example invocation follows. This assumes you have already created the Orders_copy table, with changes as desired, +in the DB, and that you have created the CdcTableCopier DB user there as well, presumably with limited permissions +to work only with the Orders_copy table. In this example we assume that 'OrderGuid' is a new column that exists +on the new _copy table only (perhaps populated by a default), and therefore we are not trying to sync that column +since it doesn't exist in the CDC feed/schema: + +./replayer.py \ + --replay-topic 'dbo_Orders_cdc' \ + --schema-registry-url 'http://localhost:8081' \ + --kafka-bootstrap-servers 'localhost:9092' \ + --target-db-conn-string 'DRIVER=ODBC Driver 18 for SQL Server; TrustServerCertificate=yes; SERVER=localhost; DATABASE=mydb; UID=CdcTableCopier; PWD=*****; APP=cdc-to-kafka-replayer;' \ + --target-db-schema 'dbo' \ + --target-db-table 'Orders_copy' \ + --cols-to-not-sync 'OrderGuid' \ + --consumer-group-name 'cdc-to-kafka-replayer-1' + +""" + +import argparse +import logging.config +import os +import time +from datetime import datetime +from typing import Set, Any, List, Dict + +import pyodbc + +from confluent_kafka import Consumer, KafkaError, TopicPartition, Message +from confluent_kafka.serialization import SerializationContext, MessageField +from confluent_kafka.schema_registry import SchemaRegistryClient +from confluent_kafka.schema_registry.avro import AvroDeserializer + +log_level = os.getenv('LOG_LEVEL', 'INFO').upper() + +logging.config.dictConfig({ + 'version': 1, + 'disable_existing_loggers': False, + 'loggers': { + __name__: { + 'handlers': ['console'], + 'level': log_level, + 'propagate': True, + }, + }, + 'handlers': { + 'console': { + 'class': 'logging.StreamHandler', + 'level': log_level, + 'formatter': 'simple', + }, + }, + 'formatters': { + 'simple': { + 'format': '%(asctime)s %(levelname)-8s [%(name)s:%(lineno)s] %(message)s', + }, + }, +}) + +logger = logging.getLogger(__name__) + + +DELETE_BATCH_SIZE = 2000 +MERGE_BATCH_SIZE = 5000 +MAX_BATCH_LATENCY_SECONDS = 10 + + +def commit_cb(err: KafkaError, tps: List[TopicPartition]): + if err is not None: + logger.error(f'Error committing offsets: {err}') + else: + logger.debug(f'Offsets committed for {tps}') + + +def format_coordinates(msg: Message) -> str: + return f'{msg.topic()}:{msg.partition()}@{msg.offset()}, ' \ + f'time {datetime.fromtimestamp(msg.timestamp()[1] / 1000)}' + + +def main() -> None: + p = argparse.ArgumentParser() + p.add_argument('--replay-topic', + default=os.environ.get('REPLAY_TOPIC')) + p.add_argument('--schema-registry-url', + default=os.environ.get('SCHEMA_REGISTRY_URL')) + p.add_argument('--kafka-bootstrap-servers', + default=os.environ.get('KAFKA_BOOTSTRAP_SERVERS')) + 
p.add_argument('--target-db-conn-string', + default=os.environ.get('TARGET_DB_CONN_STRING')) + p.add_argument('--target-db-schema', + default=os.environ.get('TARGET_DB_SCHEMA')) + p.add_argument('--target-db-table', + default=os.environ.get('TARGET_DB_TABLE')) + p.add_argument('--cols-to-not-sync', + default=os.environ.get('COLS_TO_NOT_SYNC', '')) + p.add_argument('--consumer-group-name', + default=os.environ.get('CONSUMER_GROUP_NAME', '')) + + opts = p.parse_args() + + logger.info("Starting CDC replayer.") + proc_start_time = time.perf_counter() + + if not (opts.schema_registry_url and opts.kafka_bootstrap_servers and opts.replay_topic and + opts.target_db_conn_string and opts.target_db_schema and opts.target_db_table and + opts.consumer_group_name): + raise Exception('Arguments replay_topic, target_db_conn_string, target_db_schema, target_db_table, ' + 'schema_registry_url, consumer_group_name, and kafka_bootstrap_servers are all required.') + + schema_registry_client: SchemaRegistryClient = SchemaRegistryClient({'url': opts.schema_registry_url}) + avro_deserializer: AvroDeserializer = AvroDeserializer(schema_registry_client) + consumer_conf = {'bootstrap.servers': opts.kafka_bootstrap_servers, + 'group.id': opts.consumer_group_name, + 'enable.auto.offset.store': False, + 'enable.auto.commit': False, + 'auto.offset.reset': "earliest", + 'on_commit': commit_cb} + consumer: Consumer = Consumer(consumer_conf) + consumer.subscribe([opts.replay_topic]) + logger.info(f'Subscribed to topic {opts.replay_topic}.') + + msg_ctr: int = 0 + del_cnt: int = 0 + merge_cnt: int = 0 + poll_time_acc: float = 0.0 + ser_time_acc: float = 0.0 + sql_time_acc: float = 0.0 + last_commit_time: datetime = datetime.now() + queued_deletes: Set[Any] = set() + queued_merges: Dict[Any, List[Any]] = {} + fq_target_name: str = f'[{opts.target_db_schema.strip()}].[{opts.target_db_table.strip()}]' + temp_table_name: str = f'#replayer_{opts.target_db_schema.strip()}_{opts.target_db_table.strip()}_bulk_load' + cols_to_not_sync: set[str] = set(opts.cols_to_not_sync.split(',')) + + try: + with pyodbc.connect(opts.target_db_conn_string) as db_conn: + with db_conn.cursor() as cursor: + cursor.execute(f''' + SELECT COLUMN_NAME + FROM INFORMATION_SCHEMA.KEY_COLUMN_USAGE + WHERE OBJECTPROPERTY(OBJECT_ID(CONSTRAINT_SCHEMA + '.' + QUOTENAME(CONSTRAINT_NAME)), 'IsPrimaryKey') = 1 + AND TABLE_NAME = ? AND TABLE_SCHEMA = ? + ''', opts.target_db_table, opts.target_db_schema) + + pk_fields = [r[0] for r in cursor.fetchall()] + if len(pk_fields) != 1: + raise Exception(f'Can only handle single-col PKs for now! Found: {pk_fields}') + pk_field = pk_fields[0] + + odbc_columns = tuple(cursor.columns( + schema=opts.target_db_schema, table=opts.target_db_table).fetchall()) + + param_types = [(x[4], x[6], x[8]) for x in odbc_columns] + fields: List[str] = [c[3] for c in odbc_columns if c[3] not in cols_to_not_sync] + quoted_fields_str: str = ', '.join((f'[{c}]' for c in fields)) + datetime_fields: set[str] = {c[3] for c in odbc_columns + if c[4] in (pyodbc.SQL_TYPE_TIMESTAMP,)} # TODO: more cases?? + + cursor.execute(f'DROP TABLE IF EXISTS {temp_table_name};') + # Yep, this looks weird--it's a hack to prevent SQL Server from copying over the IDENTITY property + # of any columns that have it whenever it creates the temp table. 
https://stackoverflow.com/a/57509258 + cursor.execute(f'SELECT TOP 0 * INTO {temp_table_name} FROM {fq_target_name} ' + f'UNION ALL SELECT * FROM {fq_target_name} WHERE 1 <> 1;') + for c in opts.cols_to_not_sync.split(','): + if not c.strip(): + continue + cursor.execute(f'ALTER TABLE {temp_table_name} DROP COLUMN {c.strip()};') + cursor.execute('SELECT TOP 1 1 FROM sys.columns WHERE object_id = OBJECT_ID(?) AND is_identity = 1', + fq_target_name) + set_identity_insert_if_needed = f'SET IDENTITY_INSERT {fq_target_name} ON; ' \ + if bool(cursor.fetchval()) else '' + + logger.info("Starting to consume messages.") + while True: + start_time = time.perf_counter() + msg = consumer.poll(0.5) + poll_time_acc += time.perf_counter() - start_time + + if len(queued_deletes) >= DELETE_BATCH_SIZE or len(queued_merges) >= MERGE_BATCH_SIZE or \ + (datetime.now() - last_commit_time).seconds > MAX_BATCH_LATENCY_SECONDS: + if queued_deletes: + q = None + try: + with db_conn.cursor() as cursor: + start_time = time.perf_counter() + q = f''' + DELETE FROM {fq_target_name} + WHERE [{pk_field}] IN ({'?,' * (len(queued_deletes) - 1)}?); + ''' + cursor.execute(q, list(queued_deletes)) + sql_time_acc += time.perf_counter() - start_time + elapsed_ms = int((time.perf_counter() - start_time) * 1000) + except Exception as e: + print(q) + raise e + + logger.info('Deleted %s items in %s ms', len(queued_deletes), elapsed_ms) + if queued_merges: + q = None + try: + with db_conn.cursor() as cursor: + data = list(queued_merges.values()) + start_time = time.perf_counter() + cursor.fast_executemany = True + cursor.setinputsizes(param_types) + q = f''' + INSERT INTO {temp_table_name} ({quoted_fields_str}) + VALUES({'?,' * (len(fields) - 1)}?); + ''' + cursor.executemany(q, data) + q = f''' + {set_identity_insert_if_needed} + MERGE {fq_target_name} AS tgt + USING {temp_table_name} AS src + ON (tgt.[{pk_field}] = src.[{pk_field}]) + WHEN MATCHED THEN + UPDATE SET {", ".join([f'[{x}] = src.[{x}]' for x in fields if x != pk_field])} + WHEN NOT MATCHED THEN + INSERT ({quoted_fields_str}) VALUES (src.[{'], src.['.join(fields)}]); + TRUNCATE TABLE {temp_table_name}; + ''' + cursor.execute(q) + sql_time_acc += time.perf_counter() - start_time + elapsed_ms = int((time.perf_counter() - start_time) * 1000) + except Exception as e: + print(q) + raise e + + logger.info('Merged %s items in %s ms', len(queued_merges), elapsed_ms) + if queued_merges or queued_deletes: + queued_merges.clear() + queued_deletes.clear() + consumer.commit() + last_commit_time = datetime.now() + + if msg is None: + continue + + msg_ctr += 1 + + if msg_ctr % 100_000 == 0: + logger.info(f'Reached %s', format_coordinates(msg)) + + if msg.error(): + # noinspection PyProtectedMember + if msg.error().code() == KafkaError._PARTITION_EOF: + break + else: + raise Exception(msg.error()) + + start_time = time.perf_counter() + # noinspection PyTypeChecker + key: Dict[str, Any] = avro_deserializer( + msg.key(), SerializationContext(msg.topic(), MessageField.KEY)) + # noinspection PyArgumentList,PyTypeChecker + val: Dict[str, Any] = avro_deserializer( + msg.value(), SerializationContext(msg.topic(), MessageField.VALUE)) + ser_time_acc += time.perf_counter() - start_time + + pk_val = key[pk_field] + + if val is None or val['__operation'] == 'Delete': + queued_deletes.add(pk_val) + queued_merges.pop(pk_val, None) + del_cnt += 1 + else: + vals = [] + for f in fields: + if f in datetime_fields and val[f] is not None: + vals.append(datetime.fromisoformat(val[f])) + else: + 
vals.append(val[f]) + queued_merges[pk_val] = vals + queued_deletes.discard(pk_val) + merge_cnt += 1 + + consumer.store_offsets(msg) + except KeyboardInterrupt: + pass + except Exception as e: + logger.exception(e) + + logger.info(f'Processed {msg_ctr} messages total, {del_cnt} deletes, {merge_cnt} merges.') + overall_time = time.perf_counter() - proc_start_time + logger.info(f'Total times:\nKafka poll: {poll_time_acc:.2f}s\nAvro deserialize: {ser_time_acc:.2f}s\n' + f'SQL execution: {sql_time_acc:.2f}s\nOverall: {overall_time:.2f}s') + logger.info("Closing consumer.") + consumer.close() + + +if __name__ == "__main__": + main() diff --git a/cdc_kafka/sql_queries.py b/cdc_kafka/sql_queries.py index 7c78a53..9b095f4 100644 --- a/cdc_kafka/sql_queries.py +++ b/cdc_kafka/sql_queries.py @@ -35,10 +35,11 @@ def get_cdc_tracked_tables_metadata(capture_instance_names: List[str]) -> \ , cc.column_ordinal AS change_table_ordinal , cc.column_name AS column_name , cc.column_type AS sql_type_name + , cc.is_computed AS is_computed , ic.index_ordinal AS primary_key_ordinal , sc.precision AS decimal_precision , sc.scale AS decimal_scale - , sc.is_identity AS is_identity + , sc.is_nullable AS is_nullable FROM [{constants.CDC_DB_SCHEMA_NAME}].[change_tables] AS ct INNER JOIN [{constants.CDC_DB_SCHEMA_NAME}].[captured_columns] AS cc ON (ct.object_id = cc.object_id) @@ -59,32 +60,60 @@ def get_latest_cdc_entry_time() -> Tuple[str, List[Tuple[int, int, Optional[int] ''', [] -def get_change_rows_per_second(change_table_name: str) -> Tuple[str, List[Tuple[int, int, Optional[int]]]]: +def get_change_rows_per_second(fq_change_table_name: str) -> Tuple[str, List[Tuple[int, int, Optional[int]]]]: return f''' -- cdc-to-kafka: get_change_rows_per_second SELECT ISNULL(COUNT(*) / NULLIF(DATEDIFF(second, MIN(ltm.tran_end_time), MAX(ltm.tran_end_time)), 0), 0) -FROM [{constants.CDC_DB_SCHEMA_NAME}].[{change_table_name}] AS ct WITH (NOLOCK) +FROM {fq_change_table_name} AS ct WITH (NOLOCK) INNER JOIN [{constants.CDC_DB_SCHEMA_NAME}].[lsn_time_mapping] AS ltm WITH (NOLOCK) ON ct.__$start_lsn = ltm.start_lsn ''', [] -def get_change_table_index_cols(change_table_name: str) -> Tuple[str, List[Tuple[int, int, Optional[int]]]]: +def get_change_table_index_cols() -> Tuple[str, List[Tuple[int, int, Optional[int]]]]: return f''' -- cdc-to-kafka: get_change_table_index_cols SELECT COL_NAME(ic.object_id, ic.column_id) FROM sys.indexes AS i INNER JOIN sys.index_columns AS ic ON i.object_id = ic.object_id AND i.index_id = ic.index_id -WHERE i.object_id = OBJECT_ID('{constants.CDC_DB_SCHEMA_NAME}.{change_table_name}') AND type_desc = 'CLUSTERED' +WHERE i.object_id = OBJECT_ID(?) AND type_desc = 'CLUSTERED' ORDER BY key_ordinal - ''', [] + ''', [(pyodbc.SQL_VARCHAR, 255, None)] def get_date() -> Tuple[str, List[Tuple[int, int, Optional[int]]]]: return 'SELECT GETDATE()', [] +def get_indexed_cols() -> Tuple[str, List[Tuple[int, int, Optional[int]]]]: + return f''' +-- cdc-to-kafka: get_indexed_cols +SELECT DISTINCT c.[name] +FROM sys.index_columns AS ic +INNER JOIN sys.indexes AS i + ON ic.[object_id] = i.[object_id] + AND ic.[index_id] = i.[index_id] +INNER JOIN sys.columns AS c + ON ic.[object_id] = c.[object_id] + AND ic.[column_id] = c.[column_id] +WHERE ic.[object_id] = OBJECT_ID(?) 
+ AND ic.[key_ordinal] = 1 + AND i.[is_disabled] = 0 + AND i.[type] != 0 + AND i.has_filter = 0 + ''', [(pyodbc.SQL_VARCHAR, 255, None)] + + +def get_ddl_history_for_capture_table() -> Tuple[str, List[Tuple[int, int, Optional[int]]]]: + return f''' +-- cdc-to-kafka: get_ddl_history_for_capture_table +SELECT ddl_command, DATEDIFF(second, ddl_time, GETDATE()) AS age_seconds +FROM [{constants.CDC_DB_SCHEMA_NAME}].[ddl_history] +WHERE object_id = OBJECT_ID(?) AND required_column_update = 0 + ''', [(pyodbc.SQL_VARCHAR, 255, None)] + + def get_table_count(schema_name: str, table_name: str, pk_cols: Tuple[str], - odbc_columns: Tuple[Tuple]) -> Tuple[str, List[Tuple[int, int, Optional[int]]]]: + odbc_columns: Tuple[pyodbc.Row, ...]) -> Tuple[str, List[Tuple[int, int, Optional[int]]]]: declarations, where_spec, params = _get_snapshot_query_bits(pk_cols, odbc_columns, ('>=', '<=')) return f''' @@ -94,11 +123,23 @@ def get_table_count(schema_name: str, table_name: str, pk_cols: Tuple[str], ; SELECT COUNT(*) -FROM [{schema_name}].[{table_name}] +FROM [{schema_name}].[{table_name}] WITH (NOLOCK) WHERE {where_spec} ''', params +def get_table_rowcount_bounded(table_fq_name: str, max_count: int) -> \ + Tuple[str, List[Tuple[int, int, Optional[int]]]]: + assert max_count > 0 + return f''' +-- cdc-to-kafka: get_table_rowcount_bounded +SELECT COUNT(*) FROM ( + SELECT TOP {max_count} 1 AS nbr + FROM {table_fq_name} WITH (NOLOCK) +) AS ctr + ''', [] + + def get_max_key_value(schema_name: str, table_name: str, pk_cols: Tuple[str]) -> \ Tuple[str, List[Tuple[int, int, Optional[int]]]]: select_spec = ", ".join([f'[{x}]' for x in pk_cols]) @@ -121,7 +162,7 @@ def get_min_key_value(schema_name: str, table_name: str, pk_cols: Tuple[str]) -> ''', [] -def get_change_table_count_by_operation(change_table_name: str) -> Tuple[str, List[Tuple[int, int, Optional[int]]]]: +def get_change_table_count_by_operation(fq_change_table_name: str) -> Tuple[str, List[Tuple[int, int, Optional[int]]]]: return f''' -- cdc-to-kafka: get_change_table_count_by_operation DECLARE @@ -133,7 +174,7 @@ def get_change_table_count_by_operation(change_table_name: str) -> Tuple[str, Li SELECT COUNT(*) , __$operation AS op -FROM [{constants.CDC_DB_SCHEMA_NAME}].[{change_table_name}] WITH (NOLOCK) +FROM {fq_change_table_name} WITH (NOLOCK) WHERE __$operation != 3 AND ( __$start_lsn < @LSN @@ -148,7 +189,16 @@ def get_max_lsn() -> Tuple[str, List[Tuple[int, int, Optional[int]]]]: return 'SELECT sys.fn_cdc_get_max_lsn()', [] -def get_change_rows(batch_size: int, change_table_name: str, field_names: Iterable[str], +def get_max_lsn_for_change_table(fq_change_table_name: str) -> Tuple[str, List[Tuple[int, int, Optional[int]]]]: + return f''' +-- cdc-to-kafka: get_max_lsn_for_change_table +SELECT TOP 1 __$start_lsn, __$command_id, __$seqval, __$operation +FROM {fq_change_table_name} +ORDER BY __$start_lsn DESC, __$command_id DESC, __$seqval DESC, __$operation DESC + ''', [] + + +def get_change_rows(batch_size: int, fq_change_table_name: str, field_names: Iterable[str], ct_index_cols: Iterable[str]) -> Tuple[str, List[Tuple[int, int, Optional[int]]]]: # You may feel tempted to change or simplify this query. TREAD CAREFULLY. There was a lot of iterating here to # craft something that would not induce SQL Server to resort to a full index scan. 
If you change it, run some @@ -169,13 +219,13 @@ def get_change_rows(batch_size: int, change_table_name: str, field_names: Iterab WITH ct AS ( SELECT * - FROM [{constants.CDC_DB_SCHEMA_NAME}].[{change_table_name}] AS ct WITH (NOLOCK) + FROM {fq_change_table_name} AS ct WITH (NOLOCK) WHERE ct.__$start_lsn = @LSN AND ct.__$seqval > @SEQ AND ct.__$start_lsn <= @MAX_LSN UNION ALL SELECT * - FROM [{constants.CDC_DB_SCHEMA_NAME}].[{change_table_name}] AS ct WITH (NOLOCK) + FROM {fq_change_table_name} AS ct WITH (NOLOCK) WHERE ct.__$start_lsn > @LSN AND ct.__$start_lsn <= @MAX_LSN ) SELECT TOP ({batch_size}) @@ -195,7 +245,7 @@ def get_change_rows(batch_size: int, change_table_name: str, field_names: Iterab def get_snapshot_rows( batch_size: int, schema_name: str, table_name: str, field_names: Collection[str], removed_field_names: Collection[str], pk_cols: Collection[str], first_read: bool, - odbc_columns: Collection[Tuple]) -> Tuple[str, List[Tuple[int, int, Optional[int]]]]: + odbc_columns: Tuple[pyodbc.Row, ...]) -> Tuple[str, List[Tuple[int, int, Optional[int]]]]: select_cols = [] for fn in field_names: if fn in removed_field_names: @@ -232,15 +282,15 @@ def get_snapshot_rows( ''', params -def _get_snapshot_query_bits(pk_cols: Collection[str], odbc_columns: Iterable[Tuple], comparators: Iterable[str]) \ - -> Tuple[str, str, List[Tuple[int, int, Optional[int]]]]: +def _get_snapshot_query_bits(pk_cols: Collection[str], odbc_columns: Tuple[pyodbc.Row, ...], + comparators: Iterable[str]) -> Tuple[str, str, List[Tuple[int, int, Optional[int]]]]: # For multi-column primary keys, this builds a WHERE clause of the following form, assuming # for example a PK on (field_a, field_b, field_c): # WHERE (field_a < @K0) # OR (field_a = @K0 AND field_b < @K1) # OR (field_a = @K0 AND field_b = @K1 AND field_c < @K2) - # You may find it odd that this query (as well as the change data query) has DECLARE statements in it. + # You may find it odd that this query (as well as the change data query) has `DECLARE` statements in it. # Why not just pass the parameters with the query like usual? We found that in composite-key cases, # the need to pass the parameter for the bounding value of the non-last column(s) more than once caused # SQL Server to treat those as different values (even though they were actually the same), and this diff --git a/cdc_kafka/sql_query_subprocess.py b/cdc_kafka/sql_query_subprocess.py index 1319326..91da9b0 100644 --- a/cdc_kafka/sql_query_subprocess.py +++ b/cdc_kafka/sql_query_subprocess.py @@ -6,7 +6,7 @@ import re import struct import time -from typing import Any, Tuple, Dict, Optional, Iterable, NamedTuple +from typing import Any, Tuple, Dict, Optional, Iterable, NamedTuple, List import pyodbc @@ -28,7 +28,7 @@ class SQLQueryResult(NamedTuple): reflected_query_request_metadata: Any query_executed_utc: datetime.datetime query_took_sec: float - result_rows: Tuple[Tuple] + result_rows: List[pyodbc.Row] query_params: Optional[Tuple[Any]] @@ -58,7 +58,7 @@ def __enter__(self) -> 'SQLQueryProcessor': logger.debug("SQL query subprocess started.") return self - def __exit__(self, exc_type, value, traceback) -> None: + def __exit__(self, *args) -> None: if not self._ended: self._stop_event.set() self._check_if_ended() diff --git a/cdc_kafka/tracked_tables.py b/cdc_kafka/tracked_tables.py index 8081892..6a2c6f5 100644 --- a/cdc_kafka/tracked_tables.py +++ b/cdc_kafka/tracked_tables.py @@ -7,10 +7,9 @@ import bitarray import pyodbc -from . 
import avro_from_sql, constants, change_index, options, sql_queries, sql_query_subprocess, parsed_row +from . import avro_from_sql, constants, change_index, options, sql_queries, sql_query_subprocess, parsed_row, helpers if TYPE_CHECKING: - from . import clock_sync from .metric_reporting import accumulator logger = logging.getLogger(__name__) @@ -20,17 +19,16 @@ class TrackedField(object): def __init__(self, name: str, sql_type_name: str, change_table_ordinal: int, primary_key_ordinal: int, - decimal_precision: int, decimal_scale: int, is_identity: bool, + decimal_precision: int, decimal_scale: int, force_avro_long: bool, truncate_after: Optional[int] = None) -> None: self.name: str = name self.sql_type_name: str = sql_type_name self.change_table_ordinal: int = change_table_ordinal - self.is_identity: bool = is_identity self.primary_key_ordinal: int = primary_key_ordinal self.avro_schema: Dict[str, Any] = avro_from_sql.avro_schema_from_sql_type( - name, sql_type_name, decimal_precision, decimal_scale, False) + name, sql_type_name, decimal_precision, decimal_scale, False, force_avro_long) self.nullable_avro_schema: Dict[str, Any] = avro_from_sql.avro_schema_from_sql_type( - name, sql_type_name, decimal_precision, decimal_scale, True) + name, sql_type_name, decimal_precision, decimal_scale, True, force_avro_long) self.transform_fn: Optional[Callable[[Any], Any]] = avro_from_sql.avro_transform_fn_from_sql_type(sql_type_name) if truncate_after is not None: @@ -46,13 +44,11 @@ def __init__(self, name: str, sql_type_name: str, change_table_ordinal: int, pri class TrackedTable(object): - def __init__(self, db_conn: pyodbc.Connection, clock_syncer: 'clock_sync.ClockSync', - metrics_accumulator: 'accumulator.Accumulator', + def __init__(self, db_conn: pyodbc.Connection, metrics_accumulator: 'accumulator.Accumulator', sql_query_processor: sql_query_subprocess.SQLQueryProcessor, schema_name: str, table_name: str, capture_instance_name: str, topic_name: str, min_lsn: bytes, snapshot_allowed: bool, db_row_batch_size: int) -> None: self._db_conn: pyodbc.Connection = db_conn - self._clock_syncer: clock_sync.ClockSync = clock_syncer self._metrics_accumulator: 'accumulator.Accumulator' = metrics_accumulator self._sql_query_processor: sql_query_subprocess.SQLQueryProcessor = sql_query_processor @@ -62,7 +58,6 @@ def __init__(self, db_conn: pyodbc.Connection, clock_syncer: 'clock_sync.ClockSy self.topic_name: str = topic_name self.snapshot_allowed: bool = snapshot_allowed self.fq_name: str = f'{schema_name}.{table_name}' - self.change_table_name: str = f'{capture_instance_name}_CT' self.db_row_batch_size: int = db_row_batch_size # Most of the below properties are not set until sometime after `finalize_table` is called: @@ -82,7 +77,7 @@ def __init__(self, db_conn: pyodbc.Connection, clock_syncer: 'clock_sync.ClockSy self._key_field_source_table_ordinals: Tuple[int] = tuple() self._value_field_names: List[str] = [] self._last_read_key_for_snapshot: Optional[Tuple] = None - self._odbc_columns: Tuple[Tuple] = tuple() + self._odbc_columns: Tuple[pyodbc.Row, ...] 
= tuple() self._change_rows_query: Optional[str] = None self._change_rows_query_param_types: List[Tuple[int, int, int]] = [] self._snapshot_rows_query: Optional[str] = None @@ -116,7 +111,8 @@ def get_source_table_count(self, low_key: Tuple, high_key: Tuple) -> int: def get_change_table_counts(self, highest_change_index: change_index.ChangeIndex) -> Tuple[int, int, int]: with self._db_conn.cursor() as cursor: deletes, inserts, updates = 0, 0, 0 - q, p = sql_queries.get_change_table_count_by_operation(self.change_table_name) + q, p = sql_queries.get_change_table_count_by_operation( + helpers.get_fq_change_table_name(self.capture_instance_name)) cursor.setinputsizes(p) cursor.execute(q, (highest_change_index.lsn, highest_change_index.seqval, highest_change_index.operation)) for row in cursor.fetchall(): @@ -134,6 +130,7 @@ def get_change_table_counts(self, highest_change_index: change_index.ChangeIndex # 'Finalizing' mostly means doing the things we can't do until we know all the table's fields have been added def finalize_table( self, start_after_change_table_index: change_index.ChangeIndex, + prior_change_table_max_index: Optional[change_index.ChangeIndex], start_from_key_for_snapshot: Optional[Dict[str, Any]], lsn_gap_handling: str, schema_id_getter: Callable[[str, Dict[str, Any], Dict[str, Any]], Tuple[int, int]] = None, progress_reset_fn: Callable[[str, str], None] = None, @@ -149,7 +146,10 @@ def finalize_table( f'(0x{self.min_lsn.hex()}) is later than the log position we need to start from based on the prior ' f'progress LSN stored in Kafka (0x{start_after_change_table_index.lsn.hex()}).') - if lsn_gap_handling == options.LSN_GAP_HANDLING_IGNORE: + if prior_change_table_max_index and prior_change_table_max_index <= start_after_change_table_index: + logger.info('%s Proceeding anyway, because it appears that no new entries are present in the prior ' + 'capture instance with an LSN higher than the last changes sent to Kafka.', msg) + elif lsn_gap_handling == options.LSN_GAP_HANDLING_IGNORE: logger.warning('%s Proceeding anyway since lsn_gap_handling is set to "%s"!', msg, options.LSN_GAP_HANDLING_IGNORE) elif lsn_gap_handling == options.LSN_GAP_HANDLING_BEGIN_NEW_SNAPSHOT: @@ -204,8 +204,9 @@ def finalize_table( self.value_schema) with self._db_conn.cursor() as cursor: - q, p = sql_queries.get_change_table_index_cols(self.change_table_name) - cursor.execute(q) + q, p = sql_queries.get_change_table_index_cols() + cursor.setinputsizes(p) + cursor.execute(q, helpers.get_fq_change_table_name(self.capture_instance_name)) change_table_clustered_idx_cols = [r[0] for r in cursor.fetchall()] self._odbc_columns = tuple(cursor.columns(schema=self.schema_name, table=self.table_name).fetchall()) @@ -215,12 +216,13 @@ def finalize_table( found_metadata_cols = [c for c in change_table_clustered_idx_cols if c in required_metadata_cols_ordered] if found_metadata_cols != required_metadata_cols_ordered: - raise Exception(f'The index for change table {self.change_table_name} did not contain the expected ' - f'CDC metadata columns, or contained them in the wrong order. Index columns found ' - f'were: {change_table_clustered_idx_cols}') + raise Exception(f'The index for change table {helpers.get_fq_change_table_name(self.capture_instance_name)} ' + f'did not contain the expected CDC metadata columns, or contained them in the wrong order. 
' + f'Index columns found were: {change_table_clustered_idx_cols}') self._change_rows_query, self._change_rows_query_param_types = sql_queries.get_change_rows( - self.db_row_batch_size, self.change_table_name, self._value_field_names, change_table_clustered_idx_cols) + self.db_row_batch_size, helpers.get_fq_change_table_name(self.capture_instance_name), + self._value_field_names, change_table_clustered_idx_cols) if not self.snapshot_allowed: self.snapshot_complete = True @@ -270,7 +272,7 @@ def finalize_table( else: raise Exception( f"Snapshotting was requested for table {self.fq_name}, but it does not appear to have a primary " - f"key (which is required for snapshotting at this time). You can get past this error by adding" + f"key (which is required for snapshotting at this time). You can get past this error by adding " f"the table to the snapshot blacklist") def enqueue_snapshot_query(self) -> None: @@ -358,11 +360,11 @@ def retrieve_changes_query_results(self) -> Generator[parsed_row.ParsedRow, None def get_change_rows_per_second(self) -> int: with self._db_conn.cursor() as cursor: - q, p = sql_queries.get_change_rows_per_second(self.change_table_name) + q, p = sql_queries.get_change_rows_per_second(helpers.get_fq_change_table_name(self.capture_instance_name)) cursor.execute(q) return cursor.fetchval() or 0 - def _parse_db_row(self, db_row: Tuple) -> parsed_row.ParsedRow: + def _parse_db_row(self, db_row: pyodbc.Row) -> parsed_row.ParsedRow: operation_id, event_db_time, lsn, seqval, update_mask, *table_cols = db_row operation_name = constants.CDC_OPERATION_ID_TO_NAME[operation_id] diff --git a/requirements.txt b/requirements.txt index c6d4825..ea96756 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,10 +1,10 @@ avro==1.11.1 -bitarray==2.7.3 -confluent-kafka==2.0.2 -fastavro==1.7.3 +bitarray==2.7.4 +confluent-kafka==2.1.1 +fastavro==1.7.4 Jinja2==3.1.2 -pyodbc==4.0.35 -requests==2.28.2 -sentry-sdk==1.17.0 +pyodbc==4.0.39 +requests==2.31.0 +sentry-sdk==1.24.0 sortedcontainers==2.4.0 tabulate==0.9.0