From f56287e0dccff67dfc5486eb642dbc1244695a09 Mon Sep 17 00:00:00 2001 From: Tishj Date: Thu, 2 Jan 2025 13:42:37 +0100 Subject: [PATCH] make 'current' and 'segment_tree' private, create methods for initializing them, making sure the lock is grabbed when the tree to scan from is assigned --- .../table/column_data_checkpointer.hpp | 8 +- .../duckdb/storage/table/scan_state.hpp | 20 +++-- .../duckdb/storage/table/segment_tree.hpp | 4 +- src/storage/table/array_column_data.cpp | 4 +- src/storage/table/column_data.cpp | 88 +++++++++++-------- .../table/column_data_checkpointer.cpp | 24 ++--- src/storage/table/row_group.cpp | 42 +++++++-- src/storage/table/standard_column_data.cpp | 6 +- src/storage/table/struct_column_data.cpp | 4 +- 9 files changed, 129 insertions(+), 71 deletions(-) diff --git a/src/include/duckdb/storage/table/column_data_checkpointer.hpp b/src/include/duckdb/storage/table/column_data_checkpointer.hpp index 2bdc41dd8b25..fb9d13fb61d5 100644 --- a/src/include/duckdb/storage/table/column_data_checkpointer.hpp +++ b/src/include/duckdb/storage/table/column_data_checkpointer.hpp @@ -27,14 +27,14 @@ class ColumnDataCheckpointer { RowGroup &GetRowGroup(); ColumnCheckpointState &GetCheckpointState(); - void Checkpoint(const column_segment_vector_t &nodes); + void Checkpoint(const ColumnSegmentTree &tree); void FinalizeCheckpoint(column_segment_vector_t &&nodes); CompressionFunction &GetCompressionFunction(CompressionType type); private: - void ScanSegments(const column_segment_vector_t &nodes, const std::function &callback); - unique_ptr DetectBestCompressionMethod(const column_segment_vector_t &nodes, idx_t &compression_idx); - void WriteToDisk(const column_segment_vector_t &nodes); + void ScanSegments(const ColumnSegmentTree &tree, const std::function &callback); + unique_ptr DetectBestCompressionMethod(const ColumnSegmentTree &tree, idx_t &compression_idx); + void WriteToDisk(const ColumnSegmentTree &tree); bool HasChanges(const column_segment_vector_t &nodes); void WritePersistentSegments(column_segment_vector_t nodes); diff --git a/src/include/duckdb/storage/table/scan_state.hpp b/src/include/duckdb/storage/table/scan_state.hpp index 1131768f5693..b96c6d8f65c2 100644 --- a/src/include/duckdb/storage/table/scan_state.hpp +++ b/src/include/duckdb/storage/table/scan_state.hpp @@ -93,13 +93,17 @@ struct ColumnScanState { //! Move ONLY this state forward by "count" rows (i.e. not the child states) void NextInternal(idx_t count); + void InitializeSegmentTree(const ColumnSegmentTree &tree, SegmentLock &lock); + void InitializeSegmentTree(const ColumnSegmentTree &tree); + void InitializeSegment(const ColumnSegment &segment); + void ResetSegment(); + bool HasSegment() const; + const ColumnSegment &GetSegment() const; + idx_t RemainingInSegment() const; + public: SegmentLock owned_lock; - SegmentLock &lock; - //! The column segment that is currently being scanned - optional_ptr current; - //! Column segment tree - optional_ptr segment_tree; + reference lock; //! The current row index of the scan idx_t row_index = 0; //! The internal row index (i.e. the position of the SegmentScanState) @@ -121,6 +125,12 @@ struct ColumnScanState { vector scan_child_column; //! Contains TableScan level config for scanning optional_ptr scan_options; + +private: + //! The column segment that is currently being scanned + optional_ptr current; + //! Column segment tree + optional_ptr segment_tree; }; struct ColumnFetchState { diff --git a/src/include/duckdb/storage/table/segment_tree.hpp b/src/include/duckdb/storage/table/segment_tree.hpp index 51bcdad7149e..65bd4c13beeb 100644 --- a/src/include/duckdb/storage/table/segment_tree.hpp +++ b/src/include/duckdb/storage/table/segment_tree.hpp @@ -80,11 +80,11 @@ class SegmentTree { return MoveSegments(l); } - const vector> &ReferenceSegments(SegmentLock &l) { + const vector> &ReferenceSegments(SegmentLock &l) const { LoadAllSegments(l); return nodes; } - const vector> &ReferenceSegments() { + const vector> &ReferenceSegments() const { auto l = Lock(); return ReferenceSegments(l); } diff --git a/src/storage/table/array_column_data.cpp b/src/storage/table/array_column_data.cpp index 7a94b84129cb..8960c6d33f1d 100644 --- a/src/storage/table/array_column_data.cpp +++ b/src/storage/table/array_column_data.cpp @@ -42,7 +42,7 @@ void ArrayColumnData::InitializeScan(ColumnScanState &state) const { D_ASSERT(state.child_states.size() == 2); state.row_index = 0; - state.current = nullptr; + state.ResetSegment(); validity.InitializeScan(state.child_states[0]); @@ -60,7 +60,7 @@ void ArrayColumnData::InitializeScanWithOffset(ColumnScanState &state, idx_t row } state.row_index = row_idx; - state.current = nullptr; + state.ResetSegment(); // initialize the validity segment validity.InitializeScanWithOffset(state.child_states[0], row_idx); diff --git a/src/storage/table/column_data.cpp b/src/storage/table/column_data.cpp index a3e957b8cc80..0e300a2bcfcc 100644 --- a/src/storage/table/column_data.cpp +++ b/src/storage/table/column_data.cpp @@ -88,9 +88,15 @@ idx_t ColumnData::GetMaxEntry() { } void ColumnData::InitializeScan(ColumnScanState &state, SegmentLock &lock) const { - state.current = data.GetRootSegment(lock); - state.segment_tree = &data; - state.row_index = state.current ? state.current->start : 0; + state.InitializeSegmentTree(data, lock); + auto root_segment = data.GetRootSegment(lock); + if (root_segment) { + state.InitializeSegment(*root_segment); + state.row_index = root_segment->start; + } else { + state.row_index = 0; + state.ResetSegment(); + } state.internal_index = state.row_index; state.initialized = false; state.scan_state.reset(); @@ -102,10 +108,11 @@ void ColumnData::InitializeScan(ColumnScanState &state) const { } void ColumnData::InitializeScanWithOffset(ColumnScanState &state, idx_t row_idx, SegmentLock &lock) const { - state.current = data.GetSegment(lock, row_idx); - state.segment_tree = &data; + state.InitializeSegmentTree(data, lock); + auto segment = data.GetSegment(lock, row_idx); + state.InitializeSegment(*segment); state.row_index = row_idx; - state.internal_index = state.current->start; + state.internal_index = segment->start; state.initialized = false; state.scan_state.reset(); state.last_offset = 0; @@ -125,7 +132,7 @@ ScanVectorType ColumnData::GetVectorScanType(ColumnScanState &state, idx_t scan_ return ScanVectorType::SCAN_FLAT_VECTOR; } // check if the current segment has enough data remaining - idx_t remaining_in_segment = state.current->start + state.current->count - state.row_index; + idx_t remaining_in_segment = state.RemainingInSegment(); if (remaining_in_segment < scan_count) { // there is not enough data remaining in the current segment so we need to scan across segments // we need flat vectors here @@ -135,13 +142,13 @@ ScanVectorType ColumnData::GetVectorScanType(ColumnScanState &state, idx_t scan_ } void ColumnData::InitializePrefetch(PrefetchState &prefetch_state, ColumnScanState &scan_state, idx_t remaining) const { - auto current_segment = scan_state.current; - if (!current_segment) { + if (!scan_state.HasSegment()) { return; } + optional_ptr current_segment = scan_state.GetSegment(); if (!scan_state.initialized) { // need to prefetch for the current segment if we have not yet initialized the scan for this segment - scan_state.current->InitializePrefetch(prefetch_state, scan_state); + current_segment->InitializePrefetch(prefetch_state, scan_state); } idx_t row_index = scan_state.row_index; while (remaining > 0) { @@ -161,18 +168,19 @@ void ColumnData::InitializePrefetch(PrefetchState &prefetch_state, ColumnScanSta void ColumnData::BeginScanVectorInternal(ColumnScanState &state) const { state.previous_states.clear(); + D_ASSERT(state.HasSegment()); + auto ¤t = state.GetSegment(); if (!state.initialized) { - D_ASSERT(state.current); - state.current->InitializeScan(state); - state.internal_index = state.current->start; + current.InitializeScan(state); + state.internal_index = current.start; state.initialized = true; } - D_ASSERT(data.HasSegment(state.lock, state.current.get())); + D_ASSERT(data.HasSegment(state.lock, &state.GetSegment())); D_ASSERT(state.internal_index <= state.row_index); if (state.internal_index < state.row_index) { - state.current->Skip(state); + current.Skip(state); } - D_ASSERT(state.current->type == type); + D_ASSERT(current.type == type); } idx_t ColumnData::ScanVector(ColumnScanState &state, Vector &result, idx_t remaining, ScanVectorType scan_type) const { @@ -182,19 +190,19 @@ idx_t ColumnData::ScanVector(ColumnScanState &state, Vector &result, idx_t remai BeginScanVectorInternal(state); idx_t initial_remaining = remaining; while (remaining > 0) { - D_ASSERT(state.row_index >= state.current->start && - state.row_index <= state.current->start + state.current->count); - idx_t scan_count = MinValue(remaining, state.current->start + state.current->count - state.row_index); + auto ¤t = state.GetSegment(); + D_ASSERT(state.row_index >= current.start && state.row_index <= current.start + current.count); + idx_t scan_count = MinValue(remaining, current.start + current.count - state.row_index); idx_t result_offset = initial_remaining - remaining; if (scan_count > 0) { if (state.scan_options && state.scan_options->force_fetch_row) { for (idx_t i = 0; i < scan_count; i++) { ColumnFetchState fetch_state; - state.current->FetchRow(fetch_state, UnsafeNumericCast(state.row_index + i), result, - result_offset + i); + current.FetchRow(fetch_state, UnsafeNumericCast(state.row_index + i), result, + result_offset + i); } } else { - state.current->Scan(state, scan_count, result, result_offset, scan_type); + current.Scan(state, scan_count, result, result_offset, scan_type); } state.row_index += scan_count; @@ -202,16 +210,15 @@ idx_t ColumnData::ScanVector(ColumnScanState &state, Vector &result, idx_t remai } if (remaining > 0) { - auto next = data.GetNextSegment(state.current.get()); + auto next = data.GetNextSegment(¤t); if (!next) { break; } state.previous_states.emplace_back(std::move(state.scan_state)); - state.current = next; - state.current->InitializeScan(state); + state.InitializeSegment(*next); + next->InitializeScan(state); state.segment_checked = false; - D_ASSERT(state.row_index >= state.current->start && - state.row_index <= state.current->start + state.current->count); + D_ASSERT(state.row_index >= next->start && state.row_index <= next->start + next->count); } } state.internal_index = state.row_index; @@ -221,17 +228,18 @@ idx_t ColumnData::ScanVector(ColumnScanState &state, Vector &result, idx_t remai void ColumnData::SelectVector(ColumnScanState &state, Vector &result, idx_t target_count, const SelectionVector &sel, idx_t sel_count) const { BeginScanVectorInternal(state); - if (state.current->start + state.current->count - state.row_index < target_count) { + auto ¤t = state.GetSegment(); + if (current.start + current.count - state.row_index < target_count) { throw InternalException("ColumnData::SelectVector should be able to fetch everything from one segment"); } if (state.scan_options && state.scan_options->force_fetch_row) { for (idx_t i = 0; i < sel_count; i++) { auto source_idx = sel.get_index(i); ColumnFetchState fetch_state; - state.current->FetchRow(fetch_state, UnsafeNumericCast(state.row_index + source_idx), result, i); + current.FetchRow(fetch_state, UnsafeNumericCast(state.row_index + source_idx), result, i); } } else { - state.current->Select(state, target_count, result, sel, sel_count); + current.Select(state, target_count, result, sel, sel_count); } state.row_index += target_count; state.internal_index = state.row_index; @@ -240,10 +248,11 @@ void ColumnData::SelectVector(ColumnScanState &state, Vector &result, idx_t targ void ColumnData::FilterVector(ColumnScanState &state, Vector &result, idx_t target_count, SelectionVector &sel, idx_t &sel_count, const TableFilter &filter) const { BeginScanVectorInternal(state); - if (state.current->start + state.current->count - state.row_index < target_count) { + auto ¤t = state.GetSegment(); + if (current.start + current.count - state.row_index < target_count) { throw InternalException("ColumnData::Filter should be able to fetch everything from one segment"); } - state.current->Filter(state, target_count, result, sel, sel_count, filter); + current.Filter(state, target_count, result, sel, sel_count, filter); state.row_index += target_count; state.internal_index = state.row_index; } @@ -397,14 +406,15 @@ FilterPropagateResult ColumnData::CheckZonemap(ColumnScanState &state, TableFilt if (state.segment_checked) { return FilterPropagateResult::NO_PRUNING_POSSIBLE; } - if (!state.current) { + if (!state.HasSegment()) { return FilterPropagateResult::NO_PRUNING_POSSIBLE; } + auto ¤t = state.GetSegment(); state.segment_checked = true; FilterPropagateResult prune_result; { lock_guard l(stats_lock); - prune_result = filter.CheckStatistics(state.current->stats.statistics); + prune_result = filter.CheckStatistics(current.stats.statistics); if (prune_result == FilterPropagateResult::NO_PRUNING_POSSIBLE) { return FilterPropagateResult::NO_PRUNING_POSSIBLE; } @@ -530,8 +540,10 @@ idx_t ColumnData::Fetch(ColumnScanState &state, row_t row_id, Vector &result) { // perform the fetch within the segment state.row_index = start + ((UnsafeNumericCast(row_id) - start) / STANDARD_VECTOR_SIZE * STANDARD_VECTOR_SIZE); - state.current = data.GetSegment(state.lock, state.row_index); - state.internal_index = state.current->start; + auto segment = data.GetSegment(state.row_index); + state.InitializeSegmentTree(data); + state.InitializeSegment(*segment); + state.internal_index = segment->start; return ScanVector(state, result, STANDARD_VECTOR_SIZE, ScanVectorType::SCAN_FLAT_VECTOR); } @@ -549,7 +561,7 @@ void ColumnData::FetchRow(TransactionData transaction, ColumnFetchState &state, void ColumnData::Update(TransactionData transaction, idx_t column_index, Vector &update_vector, row_t *row_ids, idx_t update_count) { Vector base_vector(type); - ColumnScanState state(data.Lock()); // FIXME: SEGMENT LOCK + ColumnScanState state; auto fetch_count = Fetch(state, row_ids[0], base_vector); base_vector.Flatten(fetch_count); @@ -654,7 +666,7 @@ unique_ptr ColumnData::Checkpoint(RowGroup &row_group, Co } ColumnDataCheckpointer checkpointer(*this, row_group, *checkpoint_state, checkpoint_info); - checkpointer.Checkpoint(nodes); + checkpointer.Checkpoint(data); checkpointer.FinalizeCheckpoint(data.MoveSegments(l)); // reset the compression function diff --git a/src/storage/table/column_data_checkpointer.cpp b/src/storage/table/column_data_checkpointer.cpp index b1e55cbba94f..217eb3098199 100644 --- a/src/storage/table/column_data_checkpointer.cpp +++ b/src/storage/table/column_data_checkpointer.cpp @@ -38,13 +38,15 @@ ColumnCheckpointState &ColumnDataCheckpointer::GetCheckpointState() { return state; } -void ColumnDataCheckpointer::ScanSegments(const column_segment_vector_t &nodes, +void ColumnDataCheckpointer::ScanSegments(const ColumnSegmentTree &tree, const std::function &callback) { Vector scan_vector(intermediate.GetType(), nullptr); + auto &nodes = tree.ReferenceSegments(state.lock); for (auto &node : nodes) { auto &segment = *node.node; - ColumnScanState scan_state(state.lock); - scan_state.current = &segment; + ColumnScanState scan_state; + scan_state.InitializeSegmentTree(tree, state.lock); + scan_state.InitializeSegment(segment); segment.InitializeScan(scan_state); for (idx_t base_row_index = 0; base_row_index < segment.count; base_row_index += STANDARD_VECTOR_SIZE) { @@ -94,7 +96,7 @@ CompressionType ForceCompression(vector> &comp return found ? compression_type : CompressionType::COMPRESSION_AUTO; } -unique_ptr ColumnDataCheckpointer::DetectBestCompressionMethod(const column_segment_vector_t &nodes, +unique_ptr ColumnDataCheckpointer::DetectBestCompressionMethod(const ColumnSegmentTree &tree, idx_t &compression_idx) { D_ASSERT(!compression_functions.empty()); auto &config = DBConfig::GetConfig(GetDatabase()); @@ -120,7 +122,7 @@ unique_ptr ColumnDataCheckpointer::DetectBestCompressionMethod(con } // scan over all the segments and run the analyze step - ScanSegments(nodes, [&](Vector &scan_vector, idx_t count) { + ScanSegments(tree, [&](Vector &scan_vector, idx_t count) { for (idx_t i = 0; i < compression_functions.size(); i++) { if (!compression_functions[i]) { continue; @@ -172,13 +174,14 @@ unique_ptr ColumnDataCheckpointer::DetectBestCompressionMethod(con return state; } -void ColumnDataCheckpointer::WriteToDisk(const column_segment_vector_t &nodes) { +void ColumnDataCheckpointer::WriteToDisk(const ColumnSegmentTree &tree) { // there were changes or transient segments // we need to rewrite the column segments to disk // first we check the current segments // if there are any persistent segments, we will mark their old block ids as modified // since the segments will be rewritten their old on disk data is no longer required + auto &nodes = tree.ReferenceSegments(state.lock); for (idx_t segment_idx = 0; segment_idx < nodes.size(); segment_idx++) { auto segment = nodes[segment_idx].node.get(); segment->CommitDropSegment(); @@ -187,7 +190,7 @@ void ColumnDataCheckpointer::WriteToDisk(const column_segment_vector_t &nodes) { // now we need to write our segment // we will first run an analyze step that determines which compression function to use idx_t compression_idx; - auto analyze_state = DetectBestCompressionMethod(nodes, compression_idx); + auto analyze_state = DetectBestCompressionMethod(tree, compression_idx); if (!analyze_state) { throw FatalException("No suitable compression/storage method found to store column"); @@ -198,7 +201,7 @@ void ColumnDataCheckpointer::WriteToDisk(const column_segment_vector_t &nodes) { auto compress_state = best_function->init_compression(*this, std::move(analyze_state)); ScanSegments( - nodes, [&](Vector &scan_vector, idx_t count) { best_function->compress(*compress_state, scan_vector, count); }); + tree, [&](Vector &scan_vector, idx_t count) { best_function->compress(*compress_state, scan_vector, count); }); best_function->compress_finalize(*compress_state); } @@ -236,12 +239,13 @@ void ColumnDataCheckpointer::WritePersistentSegments(column_segment_vector_t nod } } -void ColumnDataCheckpointer::Checkpoint(const column_segment_vector_t &nodes) { +void ColumnDataCheckpointer::Checkpoint(const ColumnSegmentTree &tree) { + auto &nodes = tree.ReferenceSegments(state.lock); D_ASSERT(!nodes.empty()); has_changes = HasChanges(nodes); // first check if any of the segments have changes if (has_changes) { - WriteToDisk(nodes); + WriteToDisk(tree); } } diff --git a/src/storage/table/row_group.cpp b/src/storage/table/row_group.cpp index 3e4bb2b59a0e..b2792dfca247 100644 --- a/src/storage/table/row_group.cpp +++ b/src/storage/table/row_group.cpp @@ -165,6 +165,38 @@ void RowGroup::InitializeEmpty(const vector &types) { } } +void ColumnScanState::InitializeSegmentTree(const ColumnSegmentTree &tree, SegmentLock &segment_lock) { + segment_tree = &tree; + lock = segment_lock; +} + +void ColumnScanState::InitializeSegmentTree(const ColumnSegmentTree &tree) { + owned_lock = tree.Lock(); + InitializeSegmentTree(tree, owned_lock); +} + +idx_t ColumnScanState::RemainingInSegment() const { + return current->start + current->count - row_index; +} + +bool ColumnScanState::HasSegment() const { + return current != nullptr; +} + +const ColumnSegment &ColumnScanState::GetSegment() const { + D_ASSERT(HasSegment()); + return *current; +} + +void ColumnScanState::ResetSegment() { + current = nullptr; +} + +void ColumnScanState::InitializeSegment(const ColumnSegment &segment) { + D_ASSERT(segment_tree && segment_tree->HasSegment(lock.get(), &segment)); + current = &segment; +} + void ColumnScanState::Initialize(const LogicalType &type, const vector &children, optional_ptr options) { // Register the options in the state @@ -255,7 +287,7 @@ bool RowGroup::InitializeScanWithOffset(CollectionScanState &state, idx_t vector column_data.InitializeScanWithOffset(state.column_scans[i], row_number); state.column_scans[i].scan_options = &state.GetOptions(); } else { - state.column_scans[i].current = nullptr; + state.column_scans[i].ResetSegment(); } } return true; @@ -282,7 +314,7 @@ bool RowGroup::InitializeScan(CollectionScanState &state) const { column_data.InitializeScan(state.column_scans[i]); state.column_scans[i].scan_options = &state.GetOptions(); } else { - state.column_scans[i].current = nullptr; + state.column_scans[i].ResetSegment(); } } return true; @@ -479,12 +511,12 @@ bool RowGroup::CheckZonemapSegments(CollectionScanState &state) const { // check zone map segment. auto &column_scan_state = state.column_scans[column_idx]; - auto current_segment = column_scan_state.current; - if (!current_segment) { + if (!column_scan_state.HasSegment()) { // no segment to skip continue; } - idx_t target_row = current_segment->start + current_segment->count; + auto ¤t_segment = column_scan_state.GetSegment(); + idx_t target_row = current_segment.start + current_segment.count; if (target_row >= state.max_row) { target_row = state.max_row; } diff --git a/src/storage/table/standard_column_data.cpp b/src/storage/table/standard_column_data.cpp index 845a44146aee..f5a6036e44e7 100644 --- a/src/storage/table/standard_column_data.cpp +++ b/src/storage/table/standard_column_data.cpp @@ -141,7 +141,7 @@ void StandardColumnData::RevertAppend(row_t start_row) { idx_t StandardColumnData::Fetch(ColumnScanState &state, row_t row_id, Vector &result) { // fetch validity mask if (state.child_states.empty()) { - ColumnScanState child_state(validity.data.Lock()); // FIXME: SEGMENT LOCK + ColumnScanState child_state; child_state.scan_options = state.scan_options; state.child_states.push_back(std::move(child_state)); } @@ -263,10 +263,10 @@ unique_ptr StandardColumnData::Checkpoint(RowGroup &row_g D_ASSERT(!validity_nodes.empty()); ColumnDataCheckpointer base_checkpointer(*this, row_group, checkpoint_state, checkpoint_info); - base_checkpointer.Checkpoint(base_nodes); + base_checkpointer.Checkpoint(data); ColumnDataCheckpointer validity_checkpointer(validity, row_group, validity_state, checkpoint_info); - validity_checkpointer.Checkpoint(validity_nodes); + validity_checkpointer.Checkpoint(validity.data); base_checkpointer.FinalizeCheckpoint(data.MoveSegments(base_lock)); validity_checkpointer.FinalizeCheckpoint(validity.data.MoveSegments(validity_lock)); diff --git a/src/storage/table/struct_column_data.cpp b/src/storage/table/struct_column_data.cpp index 7d9ea76b0816..c0402f04e3d8 100644 --- a/src/storage/table/struct_column_data.cpp +++ b/src/storage/table/struct_column_data.cpp @@ -54,7 +54,7 @@ void StructColumnData::InitializePrefetch(PrefetchState &prefetch_state, ColumnS void StructColumnData::InitializeScan(ColumnScanState &state) const { D_ASSERT(state.child_states.size() == sub_columns.size() + 1); state.row_index = 0; - state.current = nullptr; + state.ResetSegment(); // initialize the validity segment validity.InitializeScan(state.child_states[0]); @@ -71,7 +71,7 @@ void StructColumnData::InitializeScan(ColumnScanState &state) const { void StructColumnData::InitializeScanWithOffset(ColumnScanState &state, idx_t row_idx) const { D_ASSERT(state.child_states.size() == sub_columns.size() + 1); state.row_index = row_idx; - state.current = nullptr; + state.ResetSegment(); // initialize the validity segment validity.InitializeScanWithOffset(state.child_states[0], row_idx);