diff --git a/src/map/include/computeMap.hpp b/src/map/include/computeMap.hpp index 95b64c8b..21258edd 100644 --- a/src/map/include/computeMap.hpp +++ b/src/map/include/computeMap.hpp @@ -26,6 +26,7 @@ namespace fs = std::filesystem; #include #include #include +#include #include "common/atomic_queue/atomic_queue.h" //Own includes @@ -140,7 +141,11 @@ namespace skch // Atomic queues for input and output typedef atomic_queue::AtomicQueue input_atomic_queue_t; typedef atomic_queue::AtomicQueue merged_mappings_queue_t; - typedef atomic_queue::AtomicQueue output_atomic_queue_t; + typedef atomic_queue::AtomicQueue*, 1024> aggregate_atomic_queue_t; + typedef atomic_queue::AtomicQueue writer_atomic_queue_t; + typedef atomic_queue::AtomicQueue query_output_atomic_queue_t; + typedef atomic_queue::AtomicQueue fragment_atomic_queue_t; + void processFragment(FragmentData* fragment, std::vector& intervalPoints, @@ -354,9 +359,6 @@ namespace skch reader_done.store(true); } - typedef atomic_queue::AtomicQueue query_output_atomic_queue_t; - typedef atomic_queue::AtomicQueue fragment_atomic_queue_t; - void worker_thread(input_atomic_queue_t& input_queue, fragment_atomic_queue_t& fragment_queue, merged_mappings_queue_t& merged_queue, @@ -399,14 +401,13 @@ namespace skch reportReadMappings(output->results, output->queryName, outstrm); } delete output; - } else if (workers_done.load() && output_queue.was_empty()) { - ++wait_count; - if (wait_count < 5) { - std::this_thread::sleep_for(std::chrono::milliseconds(10)); - } else { + } else { + if (workers_done.load() && output_queue.was_empty()) { + ++wait_count; + } + if (wait_count > 10) { break; } - } else { std::this_thread::sleep_for(std::chrono::milliseconds(10)); } } @@ -450,9 +451,7 @@ namespace skch input_atomic_queue_t input_queue; merged_mappings_queue_t merged_queue; fragment_atomic_queue_t fragment_queue; - std::atomic reader_done(false); - std::atomic workers_done(false); - std::atomic fragments_done(false); + writer_atomic_queue_t writer_queue; this->querySequenceNames = idManager->getQuerySequenceNames(); this->targetSequenceNames = idManager->getTargetSequenceNames(); @@ -505,6 +504,9 @@ namespace skch std::cerr << "[mashmap::skch::Map::mapQuery] Building index for subset " << subset_count << " with " << target_subset.size() << " sequences" << std::endl; refSketch = new skch::Sketch(param, *idManager, target_subset); } + std::atomic reader_done(false); + std::atomic workers_done(false); + std::atomic fragments_done(false); processSubset(subset_count, target_subsets.size(), total_seq_length, input_queue, merged_queue, fragment_queue, reader_done, workers_done, fragments_done, combinedMappings); } @@ -525,21 +527,51 @@ namespace skch } // Process combined mappings + std::atomic processing_done(false); + std::atomic workers_done(false); + std::atomic output_done(false); + aggregate_atomic_queue_t aggregate_queue; + + // Initialize progress logger + progress_meter::ProgressMeter progress( + combinedMappings.size(), + "[mashmap::skch::Map::mapQuery] filtering"); + + // Start worker threads + std::vector workers; + for (int i = 0; i < param.threads; ++i) { + workers.emplace_back(&Map::processCombinedMappingsThread, this, std::ref(aggregate_queue), std::ref(writer_queue), std::ref(processing_done)); + } + + // Start output thread + std::thread output_thread(&Map::outputThread, this, std::ref(outstrm), std::ref(writer_queue), std::ref(processing_done), std::ref(workers_done), std::ref(output_done), std::ref(progress)); + + // Enqueue tasks for (auto& [querySeqId, mappings] : combinedMappings) { - // Sort mappings by query position, then reference sequence id, then reference position - std::sort( - mappings.begin(), mappings.end(), - [](const MappingResult &a, const MappingResult &b) { - return std::tie(a.queryStartPos, a.refSeqId, a.refStartPos, a.strand) - < std::tie(b.queryStartPos, b.refSeqId, b.refStartPos, b.strand); - } - ); + auto* task = new std::pair(querySeqId, &mappings); + while (!aggregate_queue.try_push(task)) { + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + } + } + + // Signal that all tasks have been enqueued + processing_done.store(true); - std::string queryName = idManager->getSequenceName(querySeqId); - processAggregatedMappings(queryName, mappings, outstrm); - totalReadsMapped += !mappings.empty(); + // Wait for worker threads to finish + for (auto& worker : workers) { + worker.join(); } + workers_done.store(true); + + // Wait for output thread to finish + while (!output_done.load()) { + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + } + output_thread.join(); + + progress.finish(); + std::cerr << "[mashmap::skch::Map::mapQuery] " << "count of mapped reads = " << totalReadsMapped << ", reads qualified for mapping = " << totalReadsPickedForMapping @@ -555,7 +587,7 @@ namespace skch { progress_meter::ProgressMeter progress( total_seq_length, - "[mashmap::skch::Map::mapQuery] mapped (" + "[mashmap::skch::Map::mapQuery] mapping (" + std::to_string(subset_count + 1) + "/" + std::to_string(total_subsets) + ")"); // Launch reader thread @@ -838,7 +870,7 @@ namespace skch } } - void processAggregatedMappings(const std::string& queryName, MappingResultsVector_t& mappings, std::ofstream& outstrm) { + void processAggregatedMappings(const std::string& queryName, MappingResultsVector_t& mappings) { // XXX we should fix this combined condition if (param.mergeMappings && param.split) { @@ -871,7 +903,7 @@ namespace skch mappings = std::move(filteredMappings); } - reportReadMappings(mappings, queryName, outstrm); + // Removed reportReadMappings call } void aggregator_thread(merged_mappings_queue_t& merged_queue, @@ -2107,7 +2139,7 @@ namespace skch * @param[in] outstrm file output stream object */ void reportReadMappings(MappingResultsVector_t &readMappings, const std::string &queryName, - std::ofstream &outstrm) + std::ostream &outstrm) { //Print the results for(auto &e : readMappings) @@ -2155,6 +2187,62 @@ namespace skch } } + private: + void processCombinedMappingsThread(aggregate_atomic_queue_t& aggregate_queue, writer_atomic_queue_t& writer_queue, std::atomic& processing_done) { + int wait_count = 0; + while (true) { + std::pair* task = nullptr; + if (aggregate_queue.try_pop(task)) { + wait_count = 0; + auto querySeqId = task->first; + auto& mappings = *(task->second); + + std::string queryName = idManager->getSequenceName(querySeqId); + processAggregatedMappings(queryName, mappings); + + std::stringstream ss; + reportReadMappings(mappings, queryName, ss); + + auto* output = new std::string(ss.str()); + while (!writer_queue.try_push(output)) { + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + } + delete task; + } else { + if (processing_done.load() && aggregate_queue.was_empty()) { + ++wait_count; + } + if (wait_count > 10) { + break; + } + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + } + } + } + + void outputThread(std::ofstream& outstrm, writer_atomic_queue_t& writer_queue, std::atomic& processing_done, + std::atomic& workers_done, std::atomic& output_done, progress_meter::ProgressMeter& progress) { + int wait_count = 0; + while (!output_done.load()) { + std::string* result = nullptr; + if (writer_queue.try_pop(result)) { + wait_count = 0; + outstrm << *result; + delete result; + // Increment progress + progress.increment(1); + } else { + if (processing_done.load() && workers_done.load() && writer_queue.was_empty()) { + ++wait_count; + } + if (wait_count > 10) { + output_done.store(true); + } + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + } + } + } + public: /**