Skip to content

Commit

Permalink
Merge pull request #285 from waveygang/parallel-filter
Browse files Browse the repository at this point in the history
Parallel filter
  • Loading branch information
ekg authored Oct 23, 2024
2 parents eef9040 + 47a34c3 commit 8672261
Showing 1 changed file with 116 additions and 28 deletions.
144 changes: 116 additions & 28 deletions src/map/include/computeMap.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ namespace fs = std::filesystem;
#include <thread>
#include <condition_variable>
#include <mutex>
#include <sstream>
#include "common/atomic_queue/atomic_queue.h"

//Own includes
Expand Down Expand Up @@ -140,7 +141,11 @@ namespace skch
// Atomic queues for input and output
typedef atomic_queue::AtomicQueue<InputSeqProgContainer*, 1024, nullptr, true, true, false, false> input_atomic_queue_t;
typedef atomic_queue::AtomicQueue<QueryMappingOutput*, 1024, nullptr, true, true, false, false> merged_mappings_queue_t;
typedef atomic_queue::AtomicQueue<MapModuleOutput*, 1024, nullptr, true, true, false, false> output_atomic_queue_t;
typedef atomic_queue::AtomicQueue<std::pair<seqno_t, MappingResultsVector_t*>*, 1024> aggregate_atomic_queue_t;
typedef atomic_queue::AtomicQueue<std::string*, 1024> writer_atomic_queue_t;
typedef atomic_queue::AtomicQueue<QueryMappingOutput*, 1024, nullptr, true, true, false, false> query_output_atomic_queue_t;
typedef atomic_queue::AtomicQueue<FragmentData*, 8192, nullptr, true, true, false, false> fragment_atomic_queue_t;


void processFragment(FragmentData* fragment,
std::vector<IntervalPoint>& intervalPoints,
Expand Down Expand Up @@ -354,9 +359,6 @@ namespace skch
reader_done.store(true);
}

typedef atomic_queue::AtomicQueue<QueryMappingOutput*, 1024, nullptr, true, true, false, false> query_output_atomic_queue_t;
typedef atomic_queue::AtomicQueue<FragmentData*, 8192, nullptr, true, true, false, false> fragment_atomic_queue_t;

void worker_thread(input_atomic_queue_t& input_queue,
fragment_atomic_queue_t& fragment_queue,
merged_mappings_queue_t& merged_queue,
Expand Down Expand Up @@ -399,14 +401,13 @@ namespace skch
reportReadMappings(output->results, output->queryName, outstrm);
}
delete output;
} else if (workers_done.load() && output_queue.was_empty()) {
++wait_count;
if (wait_count < 5) {
std::this_thread::sleep_for(std::chrono::milliseconds(10));
} else {
} else {
if (workers_done.load() && output_queue.was_empty()) {
++wait_count;
}
if (wait_count > 10) {
break;
}
} else {
std::this_thread::sleep_for(std::chrono::milliseconds(10));
}
}
Expand Down Expand Up @@ -450,9 +451,7 @@ namespace skch
input_atomic_queue_t input_queue;
merged_mappings_queue_t merged_queue;
fragment_atomic_queue_t fragment_queue;
std::atomic<bool> reader_done(false);
std::atomic<bool> workers_done(false);
std::atomic<bool> fragments_done(false);
writer_atomic_queue_t writer_queue;

this->querySequenceNames = idManager->getQuerySequenceNames();
this->targetSequenceNames = idManager->getTargetSequenceNames();
Expand Down Expand Up @@ -505,6 +504,9 @@ namespace skch
std::cerr << "[mashmap::skch::Map::mapQuery] Building index for subset " << subset_count << " with " << target_subset.size() << " sequences" << std::endl;
refSketch = new skch::Sketch(param, *idManager, target_subset);
}
std::atomic<bool> reader_done(false);
std::atomic<bool> workers_done(false);
std::atomic<bool> fragments_done(false);
processSubset(subset_count, target_subsets.size(), total_seq_length, input_queue, merged_queue,
fragment_queue, reader_done, workers_done, fragments_done, combinedMappings);
}
Expand All @@ -525,21 +527,51 @@ namespace skch
}

// Process combined mappings
std::atomic<bool> processing_done(false);
std::atomic<bool> workers_done(false);
std::atomic<bool> output_done(false);
aggregate_atomic_queue_t aggregate_queue;

// Initialize progress logger
progress_meter::ProgressMeter progress(
combinedMappings.size(),
"[mashmap::skch::Map::mapQuery] filtering");

// Start worker threads
std::vector<std::thread> workers;
for (int i = 0; i < param.threads; ++i) {
workers.emplace_back(&Map::processCombinedMappingsThread, this, std::ref(aggregate_queue), std::ref(writer_queue), std::ref(processing_done));
}

// Start output thread
std::thread output_thread(&Map::outputThread, this, std::ref(outstrm), std::ref(writer_queue), std::ref(processing_done), std::ref(workers_done), std::ref(output_done), std::ref(progress));

// Enqueue tasks
for (auto& [querySeqId, mappings] : combinedMappings) {
// Sort mappings by query position, then reference sequence id, then reference position
std::sort(
mappings.begin(), mappings.end(),
[](const MappingResult &a, const MappingResult &b) {
return std::tie(a.queryStartPos, a.refSeqId, a.refStartPos, a.strand)
< std::tie(b.queryStartPos, b.refSeqId, b.refStartPos, b.strand);
}
);
auto* task = new std::pair<seqno_t, MappingResultsVector_t*>(querySeqId, &mappings);
while (!aggregate_queue.try_push(task)) {
std::this_thread::sleep_for(std::chrono::milliseconds(10));
}
}

// Signal that all tasks have been enqueued
processing_done.store(true);

std::string queryName = idManager->getSequenceName(querySeqId);
processAggregatedMappings(queryName, mappings, outstrm);
totalReadsMapped += !mappings.empty();
// Wait for worker threads to finish
for (auto& worker : workers) {
worker.join();
}

workers_done.store(true);

// Wait for output thread to finish
while (!output_done.load()) {
std::this_thread::sleep_for(std::chrono::milliseconds(10));
}
output_thread.join();

progress.finish();

std::cerr << "[mashmap::skch::Map::mapQuery] "
<< "count of mapped reads = " << totalReadsMapped
<< ", reads qualified for mapping = " << totalReadsPickedForMapping
Expand All @@ -555,7 +587,7 @@ namespace skch
{
progress_meter::ProgressMeter progress(
total_seq_length,
"[mashmap::skch::Map::mapQuery] mapped ("
"[mashmap::skch::Map::mapQuery] mapping ("
+ std::to_string(subset_count + 1) + "/" + std::to_string(total_subsets) + ")");

// Launch reader thread
Expand Down Expand Up @@ -838,7 +870,7 @@ namespace skch
}
}

void processAggregatedMappings(const std::string& queryName, MappingResultsVector_t& mappings, std::ofstream& outstrm) {
void processAggregatedMappings(const std::string& queryName, MappingResultsVector_t& mappings) {

// XXX we should fix this combined condition
if (param.mergeMappings && param.split) {
Expand Down Expand Up @@ -871,7 +903,7 @@ namespace skch
mappings = std::move(filteredMappings);
}

reportReadMappings(mappings, queryName, outstrm);
// Removed reportReadMappings call
}

void aggregator_thread(merged_mappings_queue_t& merged_queue,
Expand Down Expand Up @@ -2107,7 +2139,7 @@ namespace skch
* @param[in] outstrm file output stream object
*/
void reportReadMappings(MappingResultsVector_t &readMappings, const std::string &queryName,
std::ofstream &outstrm)
std::ostream &outstrm)
{
//Print the results
for(auto &e : readMappings)
Expand Down Expand Up @@ -2155,6 +2187,62 @@ namespace skch
}
}

private:
void processCombinedMappingsThread(aggregate_atomic_queue_t& aggregate_queue, writer_atomic_queue_t& writer_queue, std::atomic<bool>& processing_done) {
int wait_count = 0;
while (true) {
std::pair<seqno_t, MappingResultsVector_t*>* task = nullptr;
if (aggregate_queue.try_pop(task)) {
wait_count = 0;
auto querySeqId = task->first;
auto& mappings = *(task->second);

std::string queryName = idManager->getSequenceName(querySeqId);
processAggregatedMappings(queryName, mappings);

std::stringstream ss;
reportReadMappings(mappings, queryName, ss);

auto* output = new std::string(ss.str());
while (!writer_queue.try_push(output)) {
std::this_thread::sleep_for(std::chrono::milliseconds(10));
}
delete task;
} else {
if (processing_done.load() && aggregate_queue.was_empty()) {
++wait_count;
}
if (wait_count > 10) {
break;
}
std::this_thread::sleep_for(std::chrono::milliseconds(10));
}
}
}

void outputThread(std::ofstream& outstrm, writer_atomic_queue_t& writer_queue, std::atomic<bool>& processing_done,
std::atomic<bool>& workers_done, std::atomic<bool>& output_done, progress_meter::ProgressMeter& progress) {
int wait_count = 0;
while (!output_done.load()) {
std::string* result = nullptr;
if (writer_queue.try_pop(result)) {
wait_count = 0;
outstrm << *result;
delete result;
// Increment progress
progress.increment(1);
} else {
if (processing_done.load() && workers_done.load() && writer_queue.was_empty()) {
++wait_count;
}
if (wait_count > 10) {
output_done.store(true);
}
std::this_thread::sleep_for(std::chrono::milliseconds(10));
}
}
}

public:

/**
Expand Down

0 comments on commit 8672261

Please sign in to comment.