diff --git a/include/HitManager.hpp b/include/HitManager.hpp index 6eedd42..24a288e 100644 --- a/include/HitManager.hpp +++ b/include/HitManager.hpp @@ -72,9 +72,10 @@ namespace rapmap { template void intersectSAIntervalWithOutput(SAIntervalHit& h, - RapMapIndexT& rmi, - uint32_t intervalCounter, - SAHitMap& outHits); + RapMapIndexT& rmi, + uint32_t intervalCounter, + SAHitMap& outHits); + template void intersectSAIntervalWithOutput2(SAIntervalHit& h, @@ -93,8 +94,9 @@ namespace rapmap { template SAHitMap intersectSAHits( - std::vector>& inHits, - RapMapIndexT& rmi); + std::vector>& inHits, + RapMapIndexT& rmi, + bool strictFilter=false); template std::vector intersectSAHits2( diff --git a/include/RapMapUtils.hpp b/include/RapMapUtils.hpp index 2d5ef60..6239b4c 100644 --- a/include/RapMapUtils.hpp +++ b/include/RapMapUtils.hpp @@ -369,9 +369,49 @@ namespace rapmap { tqvec.emplace_back(txpPosIn, queryPosIn, queryRCIn); } + /** + * This enforces a more stringent consistency check on + * the hits for this transcript. The hits must be co-linear + * with respect to the query and target. + * + * input: numToCheck --- the number of hits to check in sorted order + * hits after the last of these need not be consistent. + * return: numToCheck if the first numToCheck hits are consistent; + * -1 otherwise + **/ + int32_t checkConsistent(int32_t numToCheck) { + auto numHits = tqvec.size(); + + // special case for only 1 or two hits (common) + if (numHits == 1) { + return numToCheck; + } else if (numHits == 2) { + auto& h1 = (tqvec[0].queryPos < tqvec[1].queryPos) ? tqvec[0] : tqvec[1]; + auto& h2 = (tqvec[0].queryPos < tqvec[1].queryPos) ? tqvec[1] : tqvec[2]; + return (h2.pos > h1.pos) ? (numToCheck) : -1; + } else { + // first, sort by query position + std::sort(tqvec.begin(), tqvec.end(), + [](const SATxpQueryPos& q1, const SATxpQueryPos& q2) -> bool { + return q1.queryPos < q2.queryPos; + }); + + int32_t lastRefPos{std::numeric_limits::min()}; + for (size_t i = 0; i < numToCheck; ++i) { + int32_t refPos = static_cast(tqvec[i].pos); + if (refPos > lastRefPos) { + lastRefPos = refPos; + } else { + return i; + } + } + return numToCheck; + } + } + uint32_t tid; std::vector tqvec; - bool active; + bool active; uint32_t numActive; }; diff --git a/include/SACollector.hpp b/include/SACollector.hpp index af40f8d..261b2ce 100644 --- a/include/SACollector.hpp +++ b/include/SACollector.hpp @@ -16,10 +16,11 @@ class SACollector { SACollector(RapMapIndexT* rmi) : rmi_(rmi) {} bool operator()(std::string& read, - std::vector& hits, - SASearcher& saSearcher, - rapmap::utils::MateStatus mateStatus, - bool strictCheck=false) { + std::vector& hits, + SASearcher& saSearcher, + rapmap::utils::MateStatus mateStatus, + bool strictCheck=false, + bool consistentHits=false) { using QuasiAlignment = rapmap::utils::QuasiAlignment; using MateStatus = rapmap::utils::MateStatus; @@ -483,10 +484,10 @@ class SACollector { auto fwdHitsStart = hits.size(); // If we had > 1 forward hit if (fwdSAInts.size() > 1) { - auto processedHits = rapmap::hit_manager::intersectSAHits(fwdSAInts, *rmi_); - rapmap::hit_manager::collectHitsSimpleSA(processedHits, readLen, maxDist, hits, mateStatus); + auto processedHits = rapmap::hit_manager::intersectSAHits(fwdSAInts, *rmi_, consistentHits); + rapmap::hit_manager::collectHitsSimpleSA(processedHits, readLen, maxDist, hits, mateStatus); } else if (fwdSAInts.size() == 1) { // only 1 hit! - auto& saIntervalHit = fwdSAInts.front(); + auto& saIntervalHit = fwdSAInts.front(); auto initialSize = hits.size(); for (OffsetT i = saIntervalHit.begin; i != saIntervalHit.end; ++i) { auto globalPos = SA[i]; @@ -520,7 +521,7 @@ class SACollector { auto rcHitsStart = fwdHitsEnd; // If we had > 1 rc hit if (rcSAInts.size() > 1) { - auto processedHits = rapmap::hit_manager::intersectSAHits(rcSAInts, *rmi_); + auto processedHits = rapmap::hit_manager::intersectSAHits(rcSAInts, *rmi_, consistentHits); rapmap::hit_manager::collectHitsSimpleSA(processedHits, readLen, maxDist, hits, mateStatus); } else if (rcSAInts.size() == 1) { // only 1 hit! auto& saIntervalHit = rcSAInts.front(); diff --git a/src/HitManager.cpp b/src/HitManager.cpp index fd37871..ba1cb67 100644 --- a/src/HitManager.cpp +++ b/src/HitManager.cpp @@ -346,9 +346,9 @@ namespace rapmap { template void intersectSAIntervalWithOutput(SAIntervalHit& h, - RapMapIndexT& rmi, - uint32_t intervalCounter, - SAHitMap& outHits) { + RapMapIndexT& rmi, + uint32_t intervalCounter, + SAHitMap& outHits) { using OffsetT = typename RapMapIndexT::IndexType; // Convenient bindings for variables we'll use auto& SA = rmi.SA; @@ -577,7 +577,8 @@ namespace rapmap { template SAHitMap intersectSAHits( std::vector>& inHits, - RapMapIndexT& rmi + RapMapIndexT& rmi, + bool strictFilter ) { using OffsetT = typename RapMapIndexT::IndexType; // Each inHit is a SAIntervalHit structure that contains @@ -592,7 +593,7 @@ namespace rapmap { // with less than 2 hits. SAHitMap outHits; if (inHits.size() < 2) { - std::cerr << "intersectHitsSA() called with < 2 k-mer " + std::cerr << "intersectHitsSA() called with < 2 hits " " hits; this shouldn't happen\n"; return outHits; } @@ -626,20 +627,21 @@ namespace rapmap { // Now intersect everything in inHits (apart from minHits) // to get the final set of mapping info. - size_t intervalCounter{2}; + size_t intervalCounter{2}; for (auto& h : inHits) { if (&h != minHit) { // don't intersect minHit with itself intersectSAIntervalWithOutput(h, rmi, intervalCounter, outHits); - ++intervalCounter; + ++intervalCounter; } } size_t requiredNumHits = inHits.size(); // Mark as active any transcripts with the required number of hits. for (auto it = outHits.begin(); it != outHits.end(); ++it) { - if (it->second.numActive >= requiredNumHits) { - it->second.active = true; - } + bool enoughHits = (it->second.numActive >= requiredNumHits); + it->second.active = (strictFilter) ? + (enoughHits and it->second.checkConsistent(requiredNumHits)) : + (enoughHits); } return outHits; } @@ -657,36 +659,42 @@ namespace rapmap { template void intersectSAIntervalWithOutput(SAIntervalHit& h, - SAIndex32BitDense& rmi, uint32_t intervalCounter, SAHitMap& outHits); + SAIndex32BitDense& rmi, + uint32_t intervalCounter, + SAHitMap& outHits); template void intersectSAIntervalWithOutput(SAIntervalHit& h, - SAIndex64BitDense& rmi, uint32_t intervalCounter, SAHitMap& outHits); + SAIndex64BitDense& rmi, + uint32_t intervalCounter, + SAHitMap& outHits); template SAHitMap intersectSAHits(std::vector>& inHits, - SAIndex32BitDense& rmi); + SAIndex32BitDense& rmi, bool strictFilter); template SAHitMap intersectSAHits(std::vector>& inHits, - SAIndex64BitDense& rmi); + SAIndex64BitDense& rmi, bool strictFilter); template void intersectSAIntervalWithOutput(SAIntervalHit& h, - SAIndex32BitPerfect& rmi, uint32_t intervalCounter, SAHitMap& outHits); + SAIndex32BitPerfect& rmi, + uint32_t intervalCounter, + SAHitMap& outHits); template void intersectSAIntervalWithOutput(SAIntervalHit& h, - SAIndex64BitPerfect& rmi, uint32_t intervalCounter, SAHitMap& outHits); + SAIndex64BitPerfect& rmi, + uint32_t intervalCounter, + SAHitMap& outHits); template SAHitMap intersectSAHits(std::vector>& inHits, - SAIndex32BitPerfect& rmi); + SAIndex32BitPerfect& rmi, bool strictFilter); template SAHitMap intersectSAHits(std::vector>& inHits, - SAIndex64BitPerfect& rmi); - - + SAIndex64BitPerfect& rmi, bool strictFilter); } } diff --git a/src/RapMapSAMapper.cpp b/src/RapMapSAMapper.cpp index 4560c96..d0888d4 100644 --- a/src/RapMapSAMapper.cpp +++ b/src/RapMapSAMapper.cpp @@ -94,20 +94,20 @@ using FixedWriter = rapmap::utils::FixedWriter; template void processReadsSingleSA(single_parser * parser, - RapMapIndexT& rmi, - CollectorT& hitCollector, - MutexT* iomutex, - std::shared_ptr outQueue, - HitCounters& hctr, - uint32_t maxNumHits, - bool noOutput, - bool strictCheck) { + RapMapIndexT& rmi, + CollectorT& hitCollector, + MutexT* iomutex, + std::shared_ptr outQueue, + HitCounters& hctr, + uint32_t maxNumHits, + bool noOutput, + bool strictCheck, + bool consistentHits) { using OffsetT = typename RapMapIndexT::IndexType; auto& txpNames = rmi.txpNames; auto& txpLens = rmi.txpLens; uint32_t n{0}; - constexpr char bases[] = {'A', 'C', 'G', 'T'}; auto logger = spdlog::get("stderrLog"); @@ -131,7 +131,7 @@ void processReadsSingleSA(single_parser * parser, readLen = j->data[i].seq.length(); ++hctr.numReads; hits.clear(); - hitCollector(j->data[i].seq, hits, saSearcher, MateStatus::SINGLE_END, strictCheck); + hitCollector(j->data[i].seq, hits, saSearcher, MateStatus::SINGLE_END, strictCheck, consistentHits); auto numHits = hits.size(); hctr.totHits += numHits; @@ -196,22 +196,22 @@ void processReadsSingleSA(single_parser * parser, */ template void processReadsPairSA(paired_parser* parser, - RapMapIndexT& rmi, - CollectorT& hitCollector, - MutexT* iomutex, - std::shared_ptr outQueue, - HitCounters& hctr, - uint32_t maxNumHits, - bool noOutput, - bool strictCheck, - bool nonStrictMerge) { + RapMapIndexT& rmi, + CollectorT& hitCollector, + MutexT* iomutex, + std::shared_ptr outQueue, + HitCounters& hctr, + uint32_t maxNumHits, + bool noOutput, + bool strictCheck, + bool nonStrictMerge, + bool consistentHits) { using OffsetT = typename RapMapIndexT::IndexType; auto& txpNames = rmi.txpNames; auto& txpLens = rmi.txpLens; uint32_t n{0}; - constexpr char bases[] = {'A', 'C', 'G', 'T'}; auto logger = spdlog::get("stderrLog"); @@ -243,13 +243,16 @@ void processReadsPairSA(paired_parser* parser, rightHits.clear(); bool lh = hitCollector(j->data[i].first.seq, - leftHits, saSearcher, - MateStatus::PAIRED_END_LEFT, - strictCheck); + leftHits, saSearcher, + MateStatus::PAIRED_END_LEFT, + strictCheck, + consistentHits); + bool rh = hitCollector(j->data[i].second.seq, - rightHits, saSearcher, - MateStatus::PAIRED_END_RIGHT, - strictCheck); + rightHits, saSearcher, + MateStatus::PAIRED_END_RIGHT, + strictCheck, + consistentHits); if (nonStrictMerge) { rapmap::utils::mergeLeftRightHitsFuzzy( @@ -310,31 +313,33 @@ void processReadsPairSA(paired_parser* parser, template bool spawnProcessReadsThreads( - uint32_t nthread, - paired_parser* parser, - RapMapIndexT& rmi, - MutexT& iomutex, - std::shared_ptr outQueue, - HitCounters& hctr, - uint32_t maxNumHits, - bool noOutput, - bool strictCheck, - bool fuzzy) { + uint32_t nthread, + paired_parser* parser, + RapMapIndexT& rmi, + MutexT& iomutex, + std::shared_ptr outQueue, + HitCounters& hctr, + uint32_t maxNumHits, + bool noOutput, + bool strictCheck, + bool fuzzy, + bool consistentHits) { std::vector threads; SACollector saCollector(&rmi); for (size_t i = 0; i < nthread; ++i) { threads.emplace_back(processReadsPairSA, MutexT>, - parser, - std::ref(rmi), - std::ref(saCollector), - &iomutex, - outQueue, - std::ref(hctr), - maxNumHits, - noOutput, - strictCheck, - fuzzy); + parser, + std::ref(rmi), + std::ref(saCollector), + &iomutex, + outQueue, + std::ref(hctr), + maxNumHits, + noOutput, + strictCheck, + fuzzy, + consistentHits); } for (auto& t : threads) { t.join(); } @@ -343,29 +348,31 @@ bool spawnProcessReadsThreads( template bool spawnProcessReadsThreads( - uint32_t nthread, - single_parser* parser, - RapMapIndexT& rmi, - MutexT& iomutex, - std::shared_ptr outQueue, - HitCounters& hctr, - uint32_t maxNumHits, - bool noOutput, - bool strictCheck) { + uint32_t nthread, + single_parser* parser, + RapMapIndexT& rmi, + MutexT& iomutex, + std::shared_ptr outQueue, + HitCounters& hctr, + uint32_t maxNumHits, + bool noOutput, + bool strictCheck, + bool consistentHits) { std::vector threads; SACollector saCollector(&rmi); for (size_t i = 0; i < nthread; ++i) { threads.emplace_back(processReadsSingleSA, MutexT>, - parser, - std::ref(rmi), - std::ref(saCollector), - &iomutex, - outQueue, - std::ref(hctr), - maxNumHits, - noOutput, - strictCheck); + parser, + std::ref(rmi), + std::ref(saCollector), + &iomutex, + outQueue, + std::ref(hctr), + maxNumHits, + noOutput, + strictCheck, + consistentHits); } for (auto& t : threads) { t.join(); } return true; @@ -383,7 +390,8 @@ bool mapReads(RapMapIndexT& rmi, TCLAP::ValueArg& outname, TCLAP::SwitchArg& noout, TCLAP::SwitchArg& strict, - TCLAP::SwitchArg& fuzzy) { + TCLAP::SwitchArg& fuzzy, + TCLAP::SwitchArg& consistent) { std::cerr << "\n\n\n\n"; @@ -421,6 +429,7 @@ bool mapReads(RapMapIndexT& rmi, bool strictCheck = strict.getValue(); bool fuzzyIntersection = fuzzy.getValue(); + bool consistentHits = consistent.getValue(); SpinLockT iomutex; { ScopedTimer timer; @@ -449,7 +458,8 @@ bool mapReads(RapMapIndexT& rmi, pairFileList, pairFileList+numFiles)); spawnProcessReadsThreads(nthread, pairParserPtr.get(), rmi, iomutex, - outLog, hctrs, maxNumHits.getValue(), noout.getValue(), strictCheck, fuzzyIntersection); + outLog, hctrs, maxNumHits.getValue(), noout.getValue(), strictCheck, + fuzzyIntersection, consistentHits); delete [] pairFileList; } else { std::vector unmatedReadVec = rapmap::utils::tokenize(unmatedReads.getValue(), ','); @@ -464,7 +474,8 @@ bool mapReads(RapMapIndexT& rmi, /** Create the threads depending on the collector type **/ spawnProcessReadsThreads(nthread, singleParserPtr.get(), rmi, iomutex, - outLog, hctrs, maxNumHits.getValue(), noout.getValue(), strictCheck); + outLog, hctrs, maxNumHits.getValue(), noout.getValue(), + strictCheck, consistentHits); } std::cerr << "\n\n"; @@ -508,6 +519,7 @@ int rapMapSAMap(int argc, char* argv[]) { TCLAP::SwitchArg noout("n", "noOutput", "Don't write out any alignments (for speed testing purposes)", false); TCLAP::SwitchArg strict("s", "strictCheck", "Perform extra checks to try and assure that only equally \"best\" mappings for a read are reported", false); TCLAP::SwitchArg fuzzy("f", "fuzzyIntersection", "Find paired-end mapping locations using fuzzy intersection", false); + TCLAP::SwitchArg consistent("c", "consistentHits", "Ensure that the hits collected are consistent (co-linear)", false); cmd.add(index); cmd.add(noout); @@ -519,6 +531,7 @@ int rapMapSAMap(int argc, char* argv[]) { cmd.add(maxNumHits); cmd.add(strict); cmd.add(fuzzy); + cmd.add(consistent); auto consoleSink = std::make_shared(); auto consoleLog = spdlog::create("stderrLog", {consoleSink}); @@ -576,44 +589,43 @@ int rapMapSAMap(int argc, char* argv[]) { bool success{false}; if (h.bigSA()) { - std::cerr << "Loading 64-bit suffix array index: \n"; + //std::cerr << "Loading 64-bit suffix array index: \n"; //BigSAIdxPtr.reset(new RapMapSAIndex); //BigSAIdxPtr->load(indexPrefix, h.kmerLen()); if (h.perfectHash()) { - RapMapSAIndex>>> rmi; - rmi.load(indexPrefix); - success = mapReads(rmi, consoleLog, index, read1, read2, - unmatedReads, numThreads, maxNumHits, - outname, noout, strict, fuzzy); + RapMapSAIndex>>> rmi; + rmi.load(indexPrefix); + success = mapReads(rmi, consoleLog, index, read1, read2, + unmatedReads, numThreads, maxNumHits, + outname, noout, strict, fuzzy, consistent); } else { - RapMapSAIndex, - rapmap::utils::KmerKeyHasher>> rmi; - rmi.load(indexPrefix); - success = mapReads(rmi, consoleLog, index, read1, read2, - unmatedReads, numThreads, maxNumHits, - outname, noout, strict, fuzzy); + RapMapSAIndex, + rapmap::utils::KmerKeyHasher>> rmi; + rmi.load(indexPrefix); + success = mapReads(rmi, consoleLog, index, read1, read2, + unmatedReads, numThreads, maxNumHits, + outname, noout, strict, fuzzy, consistent); } - } else { - std::cerr << "Loading 32-bit suffix array index: \n"; + //std::cerr << "Loading 32-bit suffix array index: \n"; //SAIdxPtr.reset(new RapMapSAIndex); //SAIdxPtr->load(indexPrefix, h.kmerLen()); - if (h.perfectHash()) { - RapMapSAIndex>>> rmi; - rmi.load(indexPrefix); - success = mapReads(rmi, consoleLog, index, read1, read2, - unmatedReads, numThreads, maxNumHits, - outname, noout, strict, fuzzy); - } else { - RapMapSAIndex, - rapmap::utils::KmerKeyHasher>> rmi; - rmi.load(indexPrefix); - success = mapReads(rmi, consoleLog, index, read1, read2, - unmatedReads, numThreads, maxNumHits, - outname, noout, strict, fuzzy); - } + if (h.perfectHash()) { + RapMapSAIndex>>> rmi; + rmi.load(indexPrefix); + success = mapReads(rmi, consoleLog, index, read1, read2, + unmatedReads, numThreads, maxNumHits, + outname, noout, strict, fuzzy, consistent); + } else { + RapMapSAIndex, + rapmap::utils::KmerKeyHasher>> rmi; + rmi.load(indexPrefix); + success = mapReads(rmi, consoleLog, index, read1, read2, + unmatedReads, numThreads, maxNumHits, + outname, noout, strict, fuzzy, consistent); + } } return success ? 0 : 1;