Skip to content

Commit

Permalink
make 'current' and 'segment_tree' private, create methods for initializing them, making sure the lock is grabbed when the tree to scan from is assigned
Browse files Browse the repository at this point in the history
  • Loading branch information
Tishj committed Jan 2, 2025
1 parent 6706701 commit f56287e
Show file tree
Hide file tree
Showing 9 changed files with 129 additions and 71 deletions.
8 changes: 4 additions & 4 deletions src/include/duckdb/storage/table/column_data_checkpointer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,14 @@ class ColumnDataCheckpointer {
RowGroup &GetRowGroup();
ColumnCheckpointState &GetCheckpointState();

void Checkpoint(const column_segment_vector_t &nodes);
void Checkpoint(const ColumnSegmentTree &tree);
void FinalizeCheckpoint(column_segment_vector_t &&nodes);
CompressionFunction &GetCompressionFunction(CompressionType type);

private:
void ScanSegments(const column_segment_vector_t &nodes, const std::function<void(Vector &, idx_t)> &callback);
unique_ptr<AnalyzeState> DetectBestCompressionMethod(const column_segment_vector_t &nodes, idx_t &compression_idx);
void WriteToDisk(const column_segment_vector_t &nodes);
void ScanSegments(const ColumnSegmentTree &tree, const std::function<void(Vector &, idx_t)> &callback);
unique_ptr<AnalyzeState> DetectBestCompressionMethod(const ColumnSegmentTree &tree, idx_t &compression_idx);
void WriteToDisk(const ColumnSegmentTree &tree);
bool HasChanges(const column_segment_vector_t &nodes);
void WritePersistentSegments(column_segment_vector_t nodes);

Expand Down
20 changes: 15 additions & 5 deletions src/include/duckdb/storage/table/scan_state.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,13 +93,17 @@ struct ColumnScanState {
//! Move ONLY this state forward by "count" rows (i.e. not the child states)
void NextInternal(idx_t count);

void InitializeSegmentTree(const ColumnSegmentTree &tree, SegmentLock &lock);
void InitializeSegmentTree(const ColumnSegmentTree &tree);
void InitializeSegment(const ColumnSegment &segment);
void ResetSegment();
bool HasSegment() const;
const ColumnSegment &GetSegment() const;
idx_t RemainingInSegment() const;

public:
SegmentLock owned_lock;
SegmentLock &lock;
//! The column segment that is currently being scanned
optional_ptr<const ColumnSegment> current;
//! Column segment tree
optional_ptr<const ColumnSegmentTree> segment_tree;
reference<SegmentLock> lock;
//! The current row index of the scan
idx_t row_index = 0;
//! The internal row index (i.e. the position of the SegmentScanState)
Expand All @@ -121,6 +125,12 @@ struct ColumnScanState {
vector<bool> scan_child_column;
//! Contains TableScan level config for scanning
optional_ptr<TableScanOptions> scan_options;

private:
//! The column segment that is currently being scanned
optional_ptr<const ColumnSegment> current;
//! Column segment tree
optional_ptr<const ColumnSegmentTree> segment_tree;
};

struct ColumnFetchState {
Expand Down
4 changes: 2 additions & 2 deletions src/include/duckdb/storage/table/segment_tree.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,11 +80,11 @@ class SegmentTree {
return MoveSegments(l);
}

const vector<SegmentNode<T>> &ReferenceSegments(SegmentLock &l) {
const vector<SegmentNode<T>> &ReferenceSegments(SegmentLock &l) const {
LoadAllSegments(l);
return nodes;
}
const vector<SegmentNode<T>> &ReferenceSegments() {
const vector<SegmentNode<T>> &ReferenceSegments() const {
auto l = Lock();
return ReferenceSegments(l);
}
Expand Down
4 changes: 2 additions & 2 deletions src/storage/table/array_column_data.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ void ArrayColumnData::InitializeScan(ColumnScanState &state) const {
D_ASSERT(state.child_states.size() == 2);

state.row_index = 0;
state.current = nullptr;
state.ResetSegment();

validity.InitializeScan(state.child_states[0]);

Expand All @@ -60,7 +60,7 @@ void ArrayColumnData::InitializeScanWithOffset(ColumnScanState &state, idx_t row
}

state.row_index = row_idx;
state.current = nullptr;
state.ResetSegment();

// initialize the validity segment
validity.InitializeScanWithOffset(state.child_states[0], row_idx);
Expand Down
88 changes: 50 additions & 38 deletions src/storage/table/column_data.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,15 @@ idx_t ColumnData::GetMaxEntry() {
}

void ColumnData::InitializeScan(ColumnScanState &state, SegmentLock &lock) const {
state.current = data.GetRootSegment(lock);
state.segment_tree = &data;
state.row_index = state.current ? state.current->start : 0;
state.InitializeSegmentTree(data, lock);
auto root_segment = data.GetRootSegment(lock);
if (root_segment) {
state.InitializeSegment(*root_segment);
state.row_index = root_segment->start;
} else {
state.row_index = 0;
state.ResetSegment();
}
state.internal_index = state.row_index;
state.initialized = false;
state.scan_state.reset();
Expand All @@ -102,10 +108,11 @@ void ColumnData::InitializeScan(ColumnScanState &state) const {
}

void ColumnData::InitializeScanWithOffset(ColumnScanState &state, idx_t row_idx, SegmentLock &lock) const {
state.current = data.GetSegment(lock, row_idx);
state.segment_tree = &data;
state.InitializeSegmentTree(data, lock);
auto segment = data.GetSegment(lock, row_idx);
state.InitializeSegment(*segment);
state.row_index = row_idx;
state.internal_index = state.current->start;
state.internal_index = segment->start;
state.initialized = false;
state.scan_state.reset();
state.last_offset = 0;
Expand All @@ -125,7 +132,7 @@ ScanVectorType ColumnData::GetVectorScanType(ColumnScanState &state, idx_t scan_
return ScanVectorType::SCAN_FLAT_VECTOR;
}
// check if the current segment has enough data remaining
idx_t remaining_in_segment = state.current->start + state.current->count - state.row_index;
idx_t remaining_in_segment = state.RemainingInSegment();
if (remaining_in_segment < scan_count) {
// there is not enough data remaining in the current segment so we need to scan across segments
// we need flat vectors here
Expand All @@ -135,13 +142,13 @@ ScanVectorType ColumnData::GetVectorScanType(ColumnScanState &state, idx_t scan_
}

void ColumnData::InitializePrefetch(PrefetchState &prefetch_state, ColumnScanState &scan_state, idx_t remaining) const {
auto current_segment = scan_state.current;
if (!current_segment) {
if (!scan_state.HasSegment()) {
return;
}
optional_ptr<const ColumnSegment> current_segment = scan_state.GetSegment();
if (!scan_state.initialized) {
// need to prefetch for the current segment if we have not yet initialized the scan for this segment
scan_state.current->InitializePrefetch(prefetch_state, scan_state);
current_segment->InitializePrefetch(prefetch_state, scan_state);
}
idx_t row_index = scan_state.row_index;
while (remaining > 0) {
Expand All @@ -161,18 +168,19 @@ void ColumnData::InitializePrefetch(PrefetchState &prefetch_state, ColumnScanSta

void ColumnData::BeginScanVectorInternal(ColumnScanState &state) const {
state.previous_states.clear();
D_ASSERT(state.HasSegment());
auto &current = state.GetSegment();
if (!state.initialized) {
D_ASSERT(state.current);
state.current->InitializeScan(state);
state.internal_index = state.current->start;
current.InitializeScan(state);
state.internal_index = current.start;
state.initialized = true;
}
D_ASSERT(data.HasSegment(state.lock, state.current.get()));
D_ASSERT(data.HasSegment(state.lock, &state.GetSegment()));
D_ASSERT(state.internal_index <= state.row_index);
if (state.internal_index < state.row_index) {
state.current->Skip(state);
current.Skip(state);
}
D_ASSERT(state.current->type == type);
D_ASSERT(current.type == type);
}

idx_t ColumnData::ScanVector(ColumnScanState &state, Vector &result, idx_t remaining, ScanVectorType scan_type) const {
Expand All @@ -182,36 +190,35 @@ idx_t ColumnData::ScanVector(ColumnScanState &state, Vector &result, idx_t remai
BeginScanVectorInternal(state);
idx_t initial_remaining = remaining;
while (remaining > 0) {
D_ASSERT(state.row_index >= state.current->start &&
state.row_index <= state.current->start + state.current->count);
idx_t scan_count = MinValue<idx_t>(remaining, state.current->start + state.current->count - state.row_index);
auto &current = state.GetSegment();
D_ASSERT(state.row_index >= current.start && state.row_index <= current.start + current.count);
idx_t scan_count = MinValue<idx_t>(remaining, current.start + current.count - state.row_index);
idx_t result_offset = initial_remaining - remaining;
if (scan_count > 0) {
if (state.scan_options && state.scan_options->force_fetch_row) {
for (idx_t i = 0; i < scan_count; i++) {
ColumnFetchState fetch_state;
state.current->FetchRow(fetch_state, UnsafeNumericCast<row_t>(state.row_index + i), result,
result_offset + i);
current.FetchRow(fetch_state, UnsafeNumericCast<row_t>(state.row_index + i), result,
result_offset + i);
}
} else {
state.current->Scan(state, scan_count, result, result_offset, scan_type);
current.Scan(state, scan_count, result, result_offset, scan_type);
}

state.row_index += scan_count;
remaining -= scan_count;
}

if (remaining > 0) {
auto next = data.GetNextSegment(state.current.get());
auto next = data.GetNextSegment(&current);
if (!next) {
break;
}
state.previous_states.emplace_back(std::move(state.scan_state));
state.current = next;
state.current->InitializeScan(state);
state.InitializeSegment(*next);
next->InitializeScan(state);
state.segment_checked = false;
D_ASSERT(state.row_index >= state.current->start &&
state.row_index <= state.current->start + state.current->count);
D_ASSERT(state.row_index >= next->start && state.row_index <= next->start + next->count);
}
}
state.internal_index = state.row_index;
Expand All @@ -221,17 +228,18 @@ idx_t ColumnData::ScanVector(ColumnScanState &state, Vector &result, idx_t remai
void ColumnData::SelectVector(ColumnScanState &state, Vector &result, idx_t target_count, const SelectionVector &sel,
idx_t sel_count) const {
BeginScanVectorInternal(state);
if (state.current->start + state.current->count - state.row_index < target_count) {
auto &current = state.GetSegment();
if (current.start + current.count - state.row_index < target_count) {
throw InternalException("ColumnData::SelectVector should be able to fetch everything from one segment");
}
if (state.scan_options && state.scan_options->force_fetch_row) {
for (idx_t i = 0; i < sel_count; i++) {
auto source_idx = sel.get_index(i);
ColumnFetchState fetch_state;
state.current->FetchRow(fetch_state, UnsafeNumericCast<row_t>(state.row_index + source_idx), result, i);
current.FetchRow(fetch_state, UnsafeNumericCast<row_t>(state.row_index + source_idx), result, i);
}
} else {
state.current->Select(state, target_count, result, sel, sel_count);
current.Select(state, target_count, result, sel, sel_count);
}
state.row_index += target_count;
state.internal_index = state.row_index;
Expand All @@ -240,10 +248,11 @@ void ColumnData::SelectVector(ColumnScanState &state, Vector &result, idx_t targ
void ColumnData::FilterVector(ColumnScanState &state, Vector &result, idx_t target_count, SelectionVector &sel,
idx_t &sel_count, const TableFilter &filter) const {
BeginScanVectorInternal(state);
if (state.current->start + state.current->count - state.row_index < target_count) {
auto &current = state.GetSegment();
if (current.start + current.count - state.row_index < target_count) {
throw InternalException("ColumnData::Filter should be able to fetch everything from one segment");
}
state.current->Filter(state, target_count, result, sel, sel_count, filter);
current.Filter(state, target_count, result, sel, sel_count, filter);
state.row_index += target_count;
state.internal_index = state.row_index;
}
Expand Down Expand Up @@ -397,14 +406,15 @@ FilterPropagateResult ColumnData::CheckZonemap(ColumnScanState &state, TableFilt
if (state.segment_checked) {
return FilterPropagateResult::NO_PRUNING_POSSIBLE;
}
if (!state.current) {
if (!state.HasSegment()) {
return FilterPropagateResult::NO_PRUNING_POSSIBLE;
}
auto &current = state.GetSegment();
state.segment_checked = true;
FilterPropagateResult prune_result;
{
lock_guard<mutex> l(stats_lock);
prune_result = filter.CheckStatistics(state.current->stats.statistics);
prune_result = filter.CheckStatistics(current.stats.statistics);
if (prune_result == FilterPropagateResult::NO_PRUNING_POSSIBLE) {
return FilterPropagateResult::NO_PRUNING_POSSIBLE;
}
Expand Down Expand Up @@ -530,8 +540,10 @@ idx_t ColumnData::Fetch(ColumnScanState &state, row_t row_id, Vector &result) {
// perform the fetch within the segment
state.row_index =
start + ((UnsafeNumericCast<idx_t>(row_id) - start) / STANDARD_VECTOR_SIZE * STANDARD_VECTOR_SIZE);
state.current = data.GetSegment(state.lock, state.row_index);
state.internal_index = state.current->start;
auto segment = data.GetSegment(state.row_index);
state.InitializeSegmentTree(data);
state.InitializeSegment(*segment);
state.internal_index = segment->start;
return ScanVector(state, result, STANDARD_VECTOR_SIZE, ScanVectorType::SCAN_FLAT_VECTOR);
}

Expand All @@ -549,7 +561,7 @@ void ColumnData::FetchRow(TransactionData transaction, ColumnFetchState &state,
void ColumnData::Update(TransactionData transaction, idx_t column_index, Vector &update_vector, row_t *row_ids,
idx_t update_count) {
Vector base_vector(type);
ColumnScanState state(data.Lock()); // FIXME: SEGMENT LOCK
ColumnScanState state;
auto fetch_count = Fetch(state, row_ids[0], base_vector);

base_vector.Flatten(fetch_count);
Expand Down Expand Up @@ -654,7 +666,7 @@ unique_ptr<ColumnCheckpointState> ColumnData::Checkpoint(RowGroup &row_group, Co
}

ColumnDataCheckpointer checkpointer(*this, row_group, *checkpoint_state, checkpoint_info);
checkpointer.Checkpoint(nodes);
checkpointer.Checkpoint(data);
checkpointer.FinalizeCheckpoint(data.MoveSegments(l));

// reset the compression function
Expand Down
24 changes: 14 additions & 10 deletions src/storage/table/column_data_checkpointer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,15 @@ ColumnCheckpointState &ColumnDataCheckpointer::GetCheckpointState() {
return state;
}

void ColumnDataCheckpointer::ScanSegments(const column_segment_vector_t &nodes,
void ColumnDataCheckpointer::ScanSegments(const ColumnSegmentTree &tree,
const std::function<void(Vector &, idx_t)> &callback) {
Vector scan_vector(intermediate.GetType(), nullptr);
auto &nodes = tree.ReferenceSegments(state.lock);
for (auto &node : nodes) {
auto &segment = *node.node;
ColumnScanState scan_state(state.lock);
scan_state.current = &segment;
ColumnScanState scan_state;
scan_state.InitializeSegmentTree(tree, state.lock);
scan_state.InitializeSegment(segment);
segment.InitializeScan(scan_state);

for (idx_t base_row_index = 0; base_row_index < segment.count; base_row_index += STANDARD_VECTOR_SIZE) {
Expand Down Expand Up @@ -94,7 +96,7 @@ CompressionType ForceCompression(vector<optional_ptr<CompressionFunction>> &comp
return found ? compression_type : CompressionType::COMPRESSION_AUTO;
}

unique_ptr<AnalyzeState> ColumnDataCheckpointer::DetectBestCompressionMethod(const column_segment_vector_t &nodes,
unique_ptr<AnalyzeState> ColumnDataCheckpointer::DetectBestCompressionMethod(const ColumnSegmentTree &tree,
idx_t &compression_idx) {
D_ASSERT(!compression_functions.empty());
auto &config = DBConfig::GetConfig(GetDatabase());
Expand All @@ -120,7 +122,7 @@ unique_ptr<AnalyzeState> ColumnDataCheckpointer::DetectBestCompressionMethod(con
}

// scan over all the segments and run the analyze step
ScanSegments(nodes, [&](Vector &scan_vector, idx_t count) {
ScanSegments(tree, [&](Vector &scan_vector, idx_t count) {
for (idx_t i = 0; i < compression_functions.size(); i++) {
if (!compression_functions[i]) {
continue;
Expand Down Expand Up @@ -172,13 +174,14 @@ unique_ptr<AnalyzeState> ColumnDataCheckpointer::DetectBestCompressionMethod(con
return state;
}

void ColumnDataCheckpointer::WriteToDisk(const column_segment_vector_t &nodes) {
void ColumnDataCheckpointer::WriteToDisk(const ColumnSegmentTree &tree) {
// there were changes or transient segments
// we need to rewrite the column segments to disk

// first we check the current segments
// if there are any persistent segments, we will mark their old block ids as modified
// since the segments will be rewritten their old on disk data is no longer required
auto &nodes = tree.ReferenceSegments(state.lock);
for (idx_t segment_idx = 0; segment_idx < nodes.size(); segment_idx++) {
auto segment = nodes[segment_idx].node.get();
segment->CommitDropSegment();
Expand All @@ -187,7 +190,7 @@ void ColumnDataCheckpointer::WriteToDisk(const column_segment_vector_t &nodes) {
// now we need to write our segment
// we will first run an analyze step that determines which compression function to use
idx_t compression_idx;
auto analyze_state = DetectBestCompressionMethod(nodes, compression_idx);
auto analyze_state = DetectBestCompressionMethod(tree, compression_idx);

if (!analyze_state) {
throw FatalException("No suitable compression/storage method found to store column");
Expand All @@ -198,7 +201,7 @@ void ColumnDataCheckpointer::WriteToDisk(const column_segment_vector_t &nodes) {
auto compress_state = best_function->init_compression(*this, std::move(analyze_state));

ScanSegments(
nodes, [&](Vector &scan_vector, idx_t count) { best_function->compress(*compress_state, scan_vector, count); });
tree, [&](Vector &scan_vector, idx_t count) { best_function->compress(*compress_state, scan_vector, count); });
best_function->compress_finalize(*compress_state);
}

Expand Down Expand Up @@ -236,12 +239,13 @@ void ColumnDataCheckpointer::WritePersistentSegments(column_segment_vector_t nod
}
}

void ColumnDataCheckpointer::Checkpoint(const column_segment_vector_t &nodes) {
void ColumnDataCheckpointer::Checkpoint(const ColumnSegmentTree &tree) {
auto &nodes = tree.ReferenceSegments(state.lock);
D_ASSERT(!nodes.empty());
has_changes = HasChanges(nodes);
// first check if any of the segments have changes
if (has_changes) {
WriteToDisk(nodes);
WriteToDisk(tree);
}
}

Expand Down
Loading

0 comments on commit f56287e

Please sign in to comment.