Commit

Fix schema and table name quoting in SQL queries (#26)
woodlee authored Jun 2, 2023
1 parent f42f18e commit ea09c01
Showing 4 changed files with 18 additions and 7 deletions.
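For context on the fix: these queries interpolate schema and table names directly into dynamic SQL, and SQL Server only accepts such names unquoted when they contain no special characters or reserved words. Below is a minimal sketch of the failure mode, using a made-up table name (not one from the repo) and the quote_name helper added to cdc_kafka/helpers.py in this commit.

from cdc_kafka import helpers

# Hypothetical dotted name with a hyphen in one part; interpolated unquoted,
# SQL Server parses the hyphen as a minus sign and raises a syntax error.
table = 'cdc.dbo_Order-History_CT'
bad_sql = f"SELECT TOP 1 1 FROM {table} WITH (NOLOCK)"

# Bracket-quoting each dot-separated part keeps the statement valid:
good_sql = f"SELECT TOP 1 1 FROM {helpers.quote_name(table)} WITH (NOLOCK)"
# -> SELECT TOP 1 1 FROM [cdc].[dbo_Order-History_CT] WITH (NOLOCK)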
8 changes: 5 additions & 3 deletions cdc_kafka/build_startup_state.py
@@ -160,7 +160,8 @@ def determine_start_points_and_finalize_tables(
cursor.execute("SELECT 1 FROM sys.tables WHERE object_id = OBJECT_ID(?)",
changes_progress.change_table_name)
if cursor.fetchval() is not None:
-q, p = sql_queries.get_max_lsn_for_change_table(changes_progress.change_table_name)
+q, p = sql_queries.get_max_lsn_for_change_table(
+    helpers.quote_name(changes_progress.change_table_name))
cursor.execute(q)
res = cursor.fetchone()
if res:
@@ -272,7 +273,8 @@ def ddl_change_requires_new_snapshot(db_conn: pyodbc.Connection, old_capture_ins
source_table_fq_name, added_col_name)
return True

-q, p = sql_queries.get_table_rowcount_bounded(source_table_fq_name, constants.SMALL_TABLE_THRESHOLD)
+quoted_fq_name = helpers.quote_name(source_table_fq_name)
+q, p = sql_queries.get_table_rowcount_bounded(quoted_fq_name, constants.SMALL_TABLE_THRESHOLD)
cursor.execute(q)
bounded_row_count = cursor.fetchval()
logger.debug('Bounded row count for %s was: %s', source_table_fq_name, bounded_row_count)
@@ -287,7 +289,7 @@ def ddl_change_requires_new_snapshot(db_conn: pyodbc.Connection, old_capture_ins

for added_col_name in added_col_names:
if table_is_small or added_col_name in indexed_cols:
cursor.execute(f"SELECT TOP 1 1 FROM {source_table_fq_name} WITH (NOLOCK) "
cursor.execute(f"SELECT TOP 1 1 FROM {quoted_fq_name} WITH (NOLOCK) "
f"WHERE [{added_col_name}] IS NOT NULL")
if cursor.fetchval() is not None:
logger.info('Requiring re-snapshot for %s because a direct scan of newly-tracked column %s '
7 changes: 7 additions & 0 deletions cdc_kafka/helpers.py
@@ -24,3 +24,10 @@ def get_capture_instance_name(change_table_name: str) -> str:
change_table_name = change_table_name.replace(constants.CDC_DB_SCHEMA_NAME + '.', '')
assert change_table_name.endswith('_CT')
return change_table_name[:-3]


+def quote_name(name: str) -> str:
+    name = name.replace('[', '')
+    name = name.replace(']', '')
+    parts = name.split('.')
+    return '.'.join([f"[{p}]" for p in parts])
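
For reference, the helper strips any brackets already present and then re-brackets each dot-separated part, so the result is the same whether or not the input was already quoted. A quick sketch with hypothetical names:

helpers.quote_name('dbo.MyTable')             # -> '[dbo].[MyTable]'
helpers.quote_name('[cdc].[dbo_MyTable_CT]')  # -> '[cdc].[dbo_MyTable_CT]'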
3 changes: 2 additions & 1 deletion cdc_kafka/main.py
@@ -335,7 +335,8 @@ def better_json_serialize(obj):
new_ci_min_index = change_index.ChangeIndex(new_ci['start_lsn'], b'\x00' * 10, 0)
if current_idx < new_ci_min_index:
with db_conn.cursor() as cursor:
-change_table_name = helpers.get_fq_change_table_name(current_ci['capture_instance_name'])
+change_table_name = helpers.quote_name(
+    helpers.get_fq_change_table_name(current_ci['capture_instance_name']))
cursor.execute(f"SELECT TOP 1 1 FROM {change_table_name} WITH (NOLOCK)")
has_rows = cursor.fetchval() is not None
if has_rows:
7 changes: 4 additions & 3 deletions cdc_kafka/tracked_tables.py
@@ -112,7 +112,7 @@ def get_change_table_counts(self, highest_change_index: change_index.ChangeIndex
with self._db_conn.cursor() as cursor:
deletes, inserts, updates = 0, 0, 0
q, p = sql_queries.get_change_table_count_by_operation(
-    helpers.get_fq_change_table_name(self.capture_instance_name))
+    helpers.quote_name(helpers.get_fq_change_table_name(self.capture_instance_name)))
cursor.setinputsizes(p)
cursor.execute(q, (highest_change_index.lsn, highest_change_index.seqval, highest_change_index.operation))
for row in cursor.fetchall():
@@ -221,7 +221,7 @@ def finalize_table(
f'Index columns found were: {change_table_clustered_idx_cols}')

self._change_rows_query, self._change_rows_query_param_types = sql_queries.get_change_rows(
-    self.db_row_batch_size, helpers.get_fq_change_table_name(self.capture_instance_name),
+    self.db_row_batch_size, helpers.quote_name(helpers.get_fq_change_table_name(self.capture_instance_name)),
self._value_field_names, change_table_clustered_idx_cols)

if not self.snapshot_allowed:
@@ -360,7 +360,8 @@ def retrieve_changes_query_results(self) -> Generator[parsed_row.ParsedRow, None

def get_change_rows_per_second(self) -> int:
with self._db_conn.cursor() as cursor:
-q, p = sql_queries.get_change_rows_per_second(helpers.get_fq_change_table_name(self.capture_instance_name))
+q, p = sql_queries.get_change_rows_per_second(
+    helpers.quote_name(helpers.get_fq_change_table_name(self.capture_instance_name)))
cursor.execute(q)
return cursor.fetchval() or 0
