From c23adcde4ef83402e49634de3c544c8be93d0ef0 Mon Sep 17 00:00:00 2001 From: hiraksarkar Date: Mon, 27 Apr 2020 23:24:47 -0400 Subject: [PATCH 01/52] dimping equivalence classes --- include/GZipWriter.hpp | 6 +++++ src/CollapsedCellOptimizer.cpp | 14 ++++++++++ src/GZipWriter.cpp | 49 ++++++++++++++++++++++++++++++++++ 3 files changed, 69 insertions(+) diff --git a/include/GZipWriter.hpp b/include/GZipWriter.hpp index 94629b71a..8ea847aca 100644 --- a/include/GZipWriter.hpp +++ b/include/GZipWriter.hpp @@ -92,6 +92,11 @@ class GZipWriter { bool writeCellEQVec(size_t barcode, const std::vector& offsets, const std::vector& counts, bool quiet = true); + bool writeDedupCellEQVec(size_t barcode, + const std::vector>& labels, + const std::vector& counts, + bool quiet = true); + bool writeUmiGraph(alevin::graph::Graph& g, std::string& trueBarcodeStr); bool setSamplingPath(const SalmonOpts& sopt); @@ -112,6 +117,7 @@ class GZipWriter { std::unique_ptr tierMatrixStream_{nullptr}; std::unique_ptr umiGraphStream_{nullptr}; std::unique_ptr cellEQStream_{nullptr}; + std::unique_ptr cellDedupEQStream_{nullptr}; std::unique_ptr arboMatrixStream_{nullptr}; std::unique_ptr bcNameStream_{nullptr}; std::unique_ptr bcFeaturesStream_{nullptr}; diff --git a/src/CollapsedCellOptimizer.cpp b/src/CollapsedCellOptimizer.cpp index bbdefc781..d08ff791e 100644 --- a/src/CollapsedCellOptimizer.cpp +++ b/src/CollapsedCellOptimizer.cpp @@ -449,6 +449,20 @@ void optimizeCell(std::vector& trueBarcodes, std::exit(74); } + + bool writeDedupEqClasses{true}; + if ( writeDedupEqClasses ){ + std::vector> labelsEq ; + std::vector countsEq ; + labelsEq.resize(salmonEqclasses.size()); + countsEq.resize(salmonEqclasses.size()); + for(size_t salEqId = 0 ; salEqId < salmonEqclasses.size(); ++salEqId){ + labelsEq[salEqId] = salmonEqclasses[salEqId].labels; + countsEq[salEqId] = salmonEqclasses[salEqId].count; + } + gzw.writeDedupCellEQVec(trueBarcodeIdx, labelsEq, countsEq, true); + } + if ( numBootstraps and noEM ) { jointlog->error("Cannot perform bootstrapping with noEM"); jointlog->flush(); diff --git a/src/GZipWriter.cpp b/src/GZipWriter.cpp index 1f696c497..85c604437 100644 --- a/src/GZipWriter.cpp +++ b/src/GZipWriter.cpp @@ -30,6 +30,9 @@ GZipWriter::~GZipWriter() { if (cellEQStream_){ cellEQStream_->reset(); } + if (cellDedupEQStream_){ + cellDedupEQStream_->reset(); + } if (umiGraphStream_){ umiGraphStream_->reset(); } @@ -69,6 +72,9 @@ void GZipWriter::close_all_streams(){ if (cellEQStream_){ cellEQStream_->reset(); } + if (cellDedupEQStream_){ + cellDedupEQStream_->reset(); + } if (umiGraphStream_){ umiGraphStream_->reset(); } @@ -1481,6 +1487,49 @@ bool GZipWriter::writeCellEQVec(size_t barcode, const std::vector& off return true; } +// TODO: complete this method +bool GZipWriter::writeDedupCellEQVec( + size_t barcode, + const std::vector>& labels, + const std::vector& counts, + bool quiet + ) { +#if defined __APPLE__ + spin_lock::scoped_lock sl(writeMutex_); +#else + std::lock_guard lock(writeMutex_); +#endif + if (!cellDedupEQStream_) { + cellDedupEQStream_.reset(new boost::iostreams::filtering_ostream); + cellDedupEQStream_->push(boost::iostreams::gzip_compressor(6)); + auto ceqFilename = path_ / "alevin" / "cell_dedup_eq_mat.gz"; + cellDedupEQStream_->push(boost::iostreams::file_sink(ceqFilename.string(), + std::ios_base::out | std::ios_base::binary)); + } + + boost::iostreams::filtering_ostream& ofile = *cellDedupEQStream_; + size_t num = labels.size(); + size_t elSize = sizeof(typename std::vector::value_type); + // write the barcode + ofile.write(reinterpret_cast(&barcode), sizeof(barcode)); + // write the number of elements in the list + ofile.write(reinterpret_cast(&num), sizeof(barcode)); + // write the individual labels + for (size_t i = 0 ; i < num ; ++i) { + size_t labelLength = labels[i].size(); + // write the length of the equivalence class vector + ofile.write(reinterpret_cast(&labelLength), sizeof(barcode)); + // write the label vector + ofile.write(reinterpret_cast(const_cast(labels[i].data())), elSize * labelLength); + } + ofile.write(reinterpret_cast(const_cast(counts.data())), elSize * num); + // write the offsets and counts + if (!quiet) { + logger_->info("wrote EQ vector for barcode ID {}", barcode); + } + return true; +} + bool GZipWriter::writeUmiGraph(alevin::graph::Graph& g, std::string& cbString) { #if defined __APPLE__ From c4669a1c2b6c3c1f43a3bdf5aba23098865db98f Mon Sep 17 00:00:00 2001 From: hiraksarkar Date: Tue, 28 Apr 2020 12:28:26 -0400 Subject: [PATCH 02/52] adding gibbs sampler to alevin --- src/CollapsedCellOptimizer.cpp | 245 ++++++++++++++++++++++++++++++++- 1 file changed, 244 insertions(+), 1 deletion(-) diff --git a/src/CollapsedCellOptimizer.cpp b/src/CollapsedCellOptimizer.cpp index d08ff791e..032a58e1f 100644 --- a/src/CollapsedCellOptimizer.cpp +++ b/src/CollapsedCellOptimizer.cpp @@ -1,6 +1,9 @@ -#include "CollapsedCellOptimizer.hpp" #include #include +#include + +#include "CollapsedCellOptimizer.hpp" +#include "ezETAProgressBar.hpp" CollapsedCellOptimizer::CollapsedCellOptimizer() {} @@ -216,6 +219,216 @@ bool runPerCellEM(double& totalNumFrags, size_t numGenes, return true; } +bool runGibbsSamples(size_t numGenes, + CollapsedCellOptimizer::SerialVecType& geneAlphas, + std::vector& sampleMean, + std::vector& sampleVariance, + std::vector& salmonEqclasses, + std::shared_ptr& jointlog, + uint32_t numSamples, + std::vector>& sampleEstimates, + bool quiet = true){ + + + constexpr double minEQClassWeight = std::numeric_limits::denorm_min(); + constexpr double minWeight = std::numeric_limits::denorm_min(); + + uint32_t numInternalRounds{16}; + size_t numClasses = salmonEqclasses.size(); + + // gibbs related + CollapsedCellOptimizer::SerialVecType alphas(numGenes, 0.0); + CollapsedCellOptimizer::SerialVecType mean(numGenes, 0.0); + CollapsedCellOptimizer::SerialVecType squareMean(numGenes, 0.0); + CollapsedCellOptimizer::SerialVecType alphasIn(numGenes, 0.0); + CollapsedCellOptimizer::SerialVecType alphasInit(numGenes, 0.0); + + //extracting weight of eqclasses for making discrete distribution + uint32_t totalNumFrags = 0; + std::vector eqCounts; + + std::vector active(numGenes, false); + for (auto& eqclass: salmonEqclasses) { + for (size_t j = 0; j < eqclass.labels.size(); ++j) { + auto geneIdx = eqclass.labels[j]; + active[geneIdx] = true; + } + totalNumFrags += eqclass.count; + eqCounts.emplace_back(eqclass.count); + } + + // make active list (genes that are present in equivalence classes) + std::vector activeList; + activeList.reserve(numGenes); + for (size_t i = 0; i < numGenes; ++i) { + if (active[i]) { + activeList.push_back(i); + } + alphasIn[i] = geneAlphas[i]; + alphasInit[i] = geneAlphas[i]; + } + + double prior = 1e-3; + std::vector priorAlphas(numGenes, prior); + + std::vector offsetMap(numClasses, 0); + size_t countMapSize{0}; + for (size_t i = 0; i < numClasses ; ++i){ + countMapSize += salmonEqclasses[i].labels.size(); + if(i < numClasses - 1) { + offsetMap[i + 1] = countMapSize; + } + } + + // hold the estimated counts, active list + std::vector mu(numGenes, 0.0); + std::vector countMap(countMapSize, 0); + std::vector probMap(countMapSize, 0.0); + + uint32_t nchains{1}; + if (numSamples >= 50) { + nchains = 2; + } + if (numSamples >= 100) { + nchains = 4; + } + if (numSamples >= 200) { + nchains = 8; + } + + std::random_device rd; + std::mt19937 gen(rd()); + + std::vector newChainIter{0}; + if (nchains > 1) { + auto step = numSamples / nchains; + for (size_t i = 1; i < nchains; ++i) { + newChainIter.push_back(i * step); + } + } + auto nextChainStart = newChainIter.begin(); + + // For each sample this thread should generate + std::unique_ptr pbar{nullptr}; + if (!quiet) { + pbar.reset(new ez::ezETAProgressBar(numSamples)); + pbar->start(); + } + + for (size_t sampleID = 0; sampleID < numSamples; ++sampleID) { + + if (pbar) { + ++(*pbar); + } + // If we should start a new chain here, then do it! + if (nextChainStart < newChainIter.end() and sampleID == *nextChainStart) { + alphasIn = alphasInit; + ++nextChainStart; + } + + // Since for single cell data we don't estimate the exact fragment length + // essentially it will be treated as a single end read, there for from the + // definition of effective length, + // l_e = l_i - l_f + 1 + // we assume l_e = 1 + // The rest of the gibbs sample would pretty much follow from + // the principle of bulk RNA-seq + + // The mean transcript fraction are sampled from + // ~ Gam( prior[i] + geneAlphas[i], \Beta + 1 ) + // Given these transcript fractions, the reads are + // re-assigned within each equivalence class by sampling from + // a multinomial distribution according to these means + for (size_t roundIdx = 0; roundIdx < numInternalRounds; ++roundIdx) { + double beta = 0.1; + double norm = 0.0; + + // first phase: Calculate mean transcript fraction from Gamma + for(size_t activeIdx = 0; activeIdx < activeList.size(); ++activeIdx) { + auto i = activeList[activeIdx]; + double ci = static_cast(alphas[i] + priorAlphas[i]); + std::gamma_distribution d(ci, 1.0 / (beta + 1.0)); + mu[i] = d(gen); + alphas[i] = 0.0 ; + } + + // second phase: sample from the trandcript fractions + // re-assign them back to equivalence classes + for (size_t eqId = 0; eqId < numClasses; ++eqId) { + size_t offset = offsetMap[eqId]; + size_t classCount = salmonEqclasses[eqId].count ; + const std::vector& geneLabels = salmonEqclasses[eqId].labels; + size_t groupSize = geneLabels.size(); + + double muSum{0.0}; + double denom{0.0}; + if (groupSize > 1) { + double uniformWeight = 1.0 / static_cast(groupSize); + for (size_t i = 0; i < groupSize; ++i) { + auto gid = geneLabels[i]; + size_t globalIndex = offset + i; + probMap[globalIndex] = (1000.0 * mu[gid]) * uniformWeight ; + muSum += probMap[globalIndex]; + denom += probMap[globalIndex]; + } + // we might be working with tiny values and + // the denominator can become very small + if (denom <= minEQClassWeight) { + denom = 0.0; + muSum = 0.0; + for (size_t i = 0; i < groupSize; ++i) { + auto gid = geneLabels[i]; + size_t globalIndex = offset + i; + probMap[globalIndex] = 1.0 ; + muSum += probMap[globalIndex]; + denom += probMap[globalIndex]; + } + } + // Assuming previous step worked + // re-sample from a multinomial + // fill in only subpart of alpha + if (denom > minEQClassWeight) { + std::discrete_distribution dist(probMap.begin() + offset, + probMap.begin() + offset + groupSize + ); + for (size_t s = 0; s < classCount ; ++s){ + auto ind = dist(gen); + ++alphas[geneLabels[ind]]; + } + }else{ + jointlog->warn("the probabilities are too small " + "Make sure you ran salmon correclty."); + jointlog->flush(); + } + }else{ + auto gid = geneLabels[0]; + alphas[gid] += static_cast(classCount); + } + } + } + + // internal rounds are done + // TODO: can extrapolate the counts + // but as effective length is not at play, but imo it + // would hardly affect anything + + // updated values are in alpha let's put them in the estimation vector + // Calculate mean and square mean for later + for (size_t i=0; i& salmonEqclasses, @@ -680,6 +893,36 @@ void optimizeCell(std::vector& trueBarcodes, salmon::utils::incLoop(totalDedupCounts, totalCount); totalExpGeneCounts += totalExpGenes; + uint32_t numGibbsSamples = 10 ; + if ( numGibbsSamples > 0 ) { + std::vector> sampleEstimates; + std::vector sampleVariance(numGenes, 0.0); + std::vector sampleMean(numGenes, 0.0); + + bool isGibbsOk = runGibbsSamples( + numGenes, + geneAlphas, + sampleMean, + sampleVariance, + salmonEqclasses, + jointlog, + numGibbsSamples, + sampleEstimates + ); + + if( not isGibbsOk or (sampleEstimates.size()!=numGibbsSamples)){ + jointlog->error("Gibbs failed failed \n" + "Please Report this on github."); + jointlog->flush(); + std::exit(74); + } + + // write the abundance for the cell + gzw.writeSparseBootstraps( trueBarcodeStr, + sampleMean, sampleVariance, + true, sampleEstimates); + }//end-gibbs-if + if ( numBootstraps > 0 ){ std::vector> sampleEstimates; std::vector bootVariance(numGenes, 0.0); From 6e30ed5f21a072cf935bee5afc6d5b6cacfcfe08 Mon Sep 17 00:00:00 2001 From: hiraksarkar Date: Tue, 28 Apr 2020 13:30:40 -0400 Subject: [PATCH 03/52] gibbs samples added --- include/AlevinOpts.hpp | 2 ++ include/CollapsedCellOptimizer.hpp | 3 ++- include/SalmonDefaults.hpp | 1 + src/AlevinUtils.cpp | 1 + src/CollapsedCellOptimizer.cpp | 7 ++++--- src/ProgramOptionsGenerator.cpp | 4 ++++ 6 files changed, 14 insertions(+), 4 deletions(-) diff --git a/include/AlevinOpts.hpp b/include/AlevinOpts.hpp index cf08e041f..95f6da132 100644 --- a/include/AlevinOpts.hpp +++ b/include/AlevinOpts.hpp @@ -80,6 +80,8 @@ struct AlevinOpts { uint32_t maxNumBarcodes; // number of bootstraps to perform uint32_t numBootstraps; + // number of gibbs samples to perform + uint32_t numGibbsSamples; // force the number of cells uint32_t forceCells; // define a close upper bound on expected number of cells diff --git a/include/CollapsedCellOptimizer.hpp b/include/CollapsedCellOptimizer.hpp index 4ae0c08a4..db5acae08 100644 --- a/include/CollapsedCellOptimizer.hpp +++ b/include/CollapsedCellOptimizer.hpp @@ -80,7 +80,8 @@ void optimizeCell(std::vector& trueBarcodes, bool quiet, tbb::atomic& totalDedupCounts, tbb::atomic& totalExpGeneCounts, double priorWeight, spp::sparse_hash_map& txpToGeneMap, - uint32_t numGenes, uint32_t umiLength, uint32_t numBootstraps, + uint32_t numGenes, uint32_t umiLength, + uint32_t numBootstraps, uint32_t numGibbsSamples, bool naiveEqclass, bool dumpUmiGraph, bool useAllBootstraps, bool initUniform, CFreqMapT& freqCounter, bool dumpArboFragCounts, spp::sparse_hash_set& mRnaGenes, diff --git a/include/SalmonDefaults.hpp b/include/SalmonDefaults.hpp index dff3d66e5..6fb49dd0e 100644 --- a/include/SalmonDefaults.hpp +++ b/include/SalmonDefaults.hpp @@ -148,6 +148,7 @@ namespace defaults { constexpr const bool debug{true}; constexpr const uint32_t trimRight{0}; constexpr const uint32_t numBootstraps{0}; + constexpr const uint32_t numGibbsSamples{0}; constexpr const uint32_t lowRegionMinNumBarcodes{200}; constexpr const uint32_t maxNumBarcodes{100000}; constexpr const uint32_t expectCells{0}; diff --git a/src/AlevinUtils.cpp b/src/AlevinUtils.cpp index 3dda707b2..eec578d7d 100644 --- a/src/AlevinUtils.cpp +++ b/src/AlevinUtils.cpp @@ -580,6 +580,7 @@ namespace alevin { aopt.dumpUmiGraph = vm["dumpUmiGraph"].as(); aopt.trimRight = vm["trimRight"].as(); aopt.numBootstraps = vm["numCellBootstraps"].as(); + aopt.numGibbsSamples = vm["numCellGibbsSamples"].as(); aopt.lowRegionMinNumBarcodes = vm["lowRegionMinNumBarcodes"].as(); aopt.maxNumBarcodes = vm["maxNumBarcodes"].as(); aopt.freqThreshold = vm["freqThreshold"].as(); diff --git a/src/CollapsedCellOptimizer.cpp b/src/CollapsedCellOptimizer.cpp index 032a58e1f..7fb513bde 100644 --- a/src/CollapsedCellOptimizer.cpp +++ b/src/CollapsedCellOptimizer.cpp @@ -559,7 +559,8 @@ void optimizeCell(std::vector& trueBarcodes, bool quiet, tbb::atomic& totalDedupCounts, tbb::atomic& totalExpGeneCounts, double priorWeight, spp::sparse_hash_map& txpToGeneMap, - uint32_t numGenes, uint32_t umiLength, uint32_t numBootstraps, + uint32_t numGenes, uint32_t umiLength, + uint32_t numBootstraps, uint32_t numGibbsSamples, bool naiveEqclass, bool dumpUmiGraph, bool useAllBootstraps, bool initUniform, CFreqMapT& freqCounter, bool dumpArborescences, spp::sparse_hash_set& mRnaGenes, @@ -892,8 +893,7 @@ void optimizeCell(std::vector& trueBarcodes, // maintaining count for total number of predicted UMI salmon::utils::incLoop(totalDedupCounts, totalCount); totalExpGeneCounts += totalExpGenes; - - uint32_t numGibbsSamples = 10 ; + if ( numGibbsSamples > 0 ) { std::vector> sampleEstimates; std::vector sampleVariance(numGenes, 0.0); @@ -1263,6 +1263,7 @@ bool CollapsedCellOptimizer::optimize(EqMapT& fullEqMap, numGenes, aopt.protocol.umiLength, aopt.numBootstraps, + aopt.numGibbsSamples, aopt.naiveEqclass, aopt.dumpUmiGraph, aopt.dumpfeatures, diff --git a/src/ProgramOptionsGenerator.cpp b/src/ProgramOptionsGenerator.cpp index 169febc75..2483411d6 100644 --- a/src/ProgramOptionsGenerator.cpp +++ b/src/ProgramOptionsGenerator.cpp @@ -390,6 +390,10 @@ namespace salmon { "numCellBootstraps",po::value()->default_value(alevin::defaults::numBootstraps), "Generate mean and variance for cell x gene matrix quantification" " estimates.") + ( + "numCellGibbsSamples",po::value()->default_value(alevin::defaults::numGibbsSamples), + "Generate mean and variance for cell x gene matrix quantification by running gibbs chain" + " estimates.") ( "forceCells",po::value()->default_value(alevin::defaults::forceCells), "Explicitly specify the number of cells.") From 73d38eebf0398779054670ab443b5df8c9c8ab1b Mon Sep 17 00:00:00 2001 From: hiraksarkar Date: Tue, 28 Apr 2020 14:01:12 -0400 Subject: [PATCH 04/52] option `--dumpCellEq` for dumping deduplicated cell level equivalence classes --- include/AlevinOpts.hpp | 3 +++ include/CollapsedCellOptimizer.hpp | 3 ++- include/SalmonDefaults.hpp | 1 + src/AlevinUtils.cpp | 1 + src/CollapsedCellOptimizer.cpp | 8 ++++---- src/ProgramOptionsGenerator.cpp | 3 +++ 6 files changed, 14 insertions(+), 5 deletions(-) diff --git a/include/AlevinOpts.hpp b/include/AlevinOpts.hpp index 95f6da132..10855b1bb 100644 --- a/include/AlevinOpts.hpp +++ b/include/AlevinOpts.hpp @@ -40,6 +40,9 @@ struct AlevinOpts { bool dumpBFH; // dump per cell level umi-graph bool dumpUmiGraph; + // dump per cell level de-duplicated + // equivalence class + bool dumpCellEq; //Stop progress sumps bool quiet; //flag for deduplication diff --git a/include/CollapsedCellOptimizer.hpp b/include/CollapsedCellOptimizer.hpp index db5acae08..6ca46bb24 100644 --- a/include/CollapsedCellOptimizer.hpp +++ b/include/CollapsedCellOptimizer.hpp @@ -82,7 +82,8 @@ void optimizeCell(std::vector& trueBarcodes, spp::sparse_hash_map& txpToGeneMap, uint32_t numGenes, uint32_t umiLength, uint32_t numBootstraps, uint32_t numGibbsSamples, - bool naiveEqclass, bool dumpUmiGraph, bool useAllBootstraps, + bool naiveEqclass, bool dumpUmiGraph, + bool dumpCellEq, bool useAllBootstraps, bool initUniform, CFreqMapT& freqCounter, bool dumpArboFragCounts, spp::sparse_hash_set& mRnaGenes, spp::sparse_hash_set& rRnaGenes, diff --git a/include/SalmonDefaults.hpp b/include/SalmonDefaults.hpp index 6fb49dd0e..25402f86b 100644 --- a/include/SalmonDefaults.hpp +++ b/include/SalmonDefaults.hpp @@ -143,6 +143,7 @@ namespace defaults { constexpr const bool dumpFeatures{false}; constexpr const bool dumpBFH{false}; constexpr const bool dumpUmiGraph{false}; + constexpr const bool dumpCellEq{false}; constexpr const bool dumpMtx{false}; constexpr const bool noEM{false}; constexpr const bool debug{true}; diff --git a/src/AlevinUtils.cpp b/src/AlevinUtils.cpp index eec578d7d..bae0be0e9 100644 --- a/src/AlevinUtils.cpp +++ b/src/AlevinUtils.cpp @@ -578,6 +578,7 @@ namespace alevin { aopt.dumpBarcodeEq = vm["dumpBarcodeEq"].as(); aopt.dumpBFH = vm["dumpBfh"].as(); aopt.dumpUmiGraph = vm["dumpUmiGraph"].as(); + aopt.dumpCellEq = vm["dumpCellEq"].as(); aopt.trimRight = vm["trimRight"].as(); aopt.numBootstraps = vm["numCellBootstraps"].as(); aopt.numGibbsSamples = vm["numCellGibbsSamples"].as(); diff --git a/src/CollapsedCellOptimizer.cpp b/src/CollapsedCellOptimizer.cpp index 7fb513bde..bf5df8f2f 100644 --- a/src/CollapsedCellOptimizer.cpp +++ b/src/CollapsedCellOptimizer.cpp @@ -561,7 +561,8 @@ void optimizeCell(std::vector& trueBarcodes, spp::sparse_hash_map& txpToGeneMap, uint32_t numGenes, uint32_t umiLength, uint32_t numBootstraps, uint32_t numGibbsSamples, - bool naiveEqclass, bool dumpUmiGraph, bool useAllBootstraps, + bool naiveEqclass, bool dumpUmiGraph, + bool dumpCellEq, bool useAllBootstraps, bool initUniform, CFreqMapT& freqCounter, bool dumpArborescences, spp::sparse_hash_set& mRnaGenes, spp::sparse_hash_set& rRnaGenes, @@ -663,9 +664,7 @@ void optimizeCell(std::vector& trueBarcodes, std::exit(74); } - - bool writeDedupEqClasses{true}; - if ( writeDedupEqClasses ){ + if ( dumpCellEq ){ std::vector> labelsEq ; std::vector countsEq ; labelsEq.resize(salmonEqclasses.size()); @@ -1266,6 +1265,7 @@ bool CollapsedCellOptimizer::optimize(EqMapT& fullEqMap, aopt.numGibbsSamples, aopt.naiveEqclass, aopt.dumpUmiGraph, + aopt.dumpCellEq, aopt.dumpfeatures, aopt.initUniform, std::ref(freqCounter), diff --git a/src/ProgramOptionsGenerator.cpp b/src/ProgramOptionsGenerator.cpp index 2483411d6..66e82d465 100644 --- a/src/ProgramOptionsGenerator.cpp +++ b/src/ProgramOptionsGenerator.cpp @@ -445,6 +445,9 @@ namespace salmon { ( "dumpUmiGraph", po::bool_switch()->default_value(alevin::defaults::dumpUmiGraph), "dump the per cell level Umi Graph.") + ( + "dumpCellEq", po::bool_switch()->default_value(alevin::defaults::dumpCellEq), + "dump the per cell level deduplicated equivalence classes.") ( "dumpFeatures", po::bool_switch()->default_value(alevin::defaults::dumpFeatures), "Dump features for whitelist and downstream analysis.") From 204a842f4218fce79faa379928ca579321d594c1 Mon Sep 17 00:00:00 2001 From: hiraksarkar Date: Tue, 28 Apr 2020 16:33:01 -0400 Subject: [PATCH 05/52] either of bootstrap/gibbs can be turned on --- src/AlevinUtils.cpp | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/AlevinUtils.cpp b/src/AlevinUtils.cpp index bae0be0e9..e320627c0 100644 --- a/src/AlevinUtils.cpp +++ b/src/AlevinUtils.cpp @@ -604,6 +604,12 @@ namespace alevin { return false; } + if (aopt.numGibbsSamples > 0 and aopt.numBootstraps > 0) { + aopt.jointLog->error("Either of --numCellGibbsSamples or --numCellBootstraps " + "can be used"); + return false; + } + if ( aopt.numBootstraps > 0 and aopt.noEM ) { aopt.jointLog->error("cannot perform bootstrapping with noEM option."); return false; From b9f5a8fdf87e8d1b84eb023d08c2a665cb2bb266 Mon Sep 17 00:00:00 2001 From: hiraksarkar Date: Tue, 28 Apr 2020 20:07:31 -0400 Subject: [PATCH 06/52] gibbs sample generation needs EM to be run --- src/CollapsedCellOptimizer.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/CollapsedCellOptimizer.cpp b/src/CollapsedCellOptimizer.cpp index bf5df8f2f..9fdc15dae 100644 --- a/src/CollapsedCellOptimizer.cpp +++ b/src/CollapsedCellOptimizer.cpp @@ -676,8 +676,8 @@ void optimizeCell(std::vector& trueBarcodes, gzw.writeDedupCellEQVec(trueBarcodeIdx, labelsEq, countsEq, true); } - if ( numBootstraps and noEM ) { - jointlog->error("Cannot perform bootstrapping with noEM"); + if ( (numBootstraps and noEM) or (numGibbsSamples and noEM) ) { + jointlog->error("Cannot perform bootstrapping/gibbs with noEM"); jointlog->flush(); exit(1); } From d488552b8bb84accc717b52688e18c989941577a Mon Sep 17 00:00:00 2001 From: Rob Patro Date: Tue, 28 Apr 2020 22:54:52 -0400 Subject: [PATCH 07/52] MASSIVELY simplify version checker, add cpp-httplib --- include/SalmonConfig.hpp | 6 +- include/VersionChecker.hpp | 31 - include/httplib.hpp | 5123 ++++++++++++++++++++++++++++++++++++ src/VersionChecker.cpp | 193 +- 4 files changed, 5141 insertions(+), 212 deletions(-) create mode 100644 include/httplib.hpp diff --git a/include/SalmonConfig.hpp b/include/SalmonConfig.hpp index 27ea70b15..2ead95ae9 100644 --- a/include/SalmonConfig.hpp +++ b/include/SalmonConfig.hpp @@ -26,9 +26,9 @@ namespace salmon { constexpr char majorVersion[] = "1"; -constexpr char minorVersion[] = "2"; -constexpr char patchVersion[] = "1"; -constexpr char version[] = "1.2.1"; +constexpr char minorVersion[] = "3"; +constexpr char patchVersion[] = "0"; +constexpr char version[] = "1.3.0"; constexpr uint32_t indexVersion = 5; constexpr char requiredQuasiIndexVersion[] = "p7"; } // namespace salmon diff --git a/include/VersionChecker.hpp b/include/VersionChecker.hpp index 0ae49039f..d999472c0 100644 --- a/include/VersionChecker.hpp +++ b/include/VersionChecker.hpp @@ -11,41 +11,10 @@ #ifndef VERSION_CHECKER_HPP #define VERSION_CHECKER_HPP -#include -#include -#include #include -#include -#include #include #include -using boost::asio::ip::tcp; - -class VersionChecker { -public: - VersionChecker(boost::asio::io_service& io_service, const std::string& server, - const std::string& path); - std::string message(); - -private: - void cancel_upgrade_check(const boost::system::error_code& err); - void handle_resolve(const boost::system::error_code& err, - tcp::resolver::iterator endpoint_iterator); - void handle_connect(const boost::system::error_code& err); - void handle_write_request(const boost::system::error_code& err); - void handle_read_status_line(const boost::system::error_code& err); - void handle_read_headers(const boost::system::error_code& err); - void handle_read_content(const boost::system::error_code& err); - - tcp::resolver resolver_; - tcp::socket socket_; - boost::asio::streambuf request_; - boost::asio::streambuf response_; - boost::asio::deadline_timer deadline_; - std::stringstream messageStream_; -}; - std::string getVersionMessage(); #endif // VERSION_CHECKER_HPP \ No newline at end of file diff --git a/include/httplib.hpp b/include/httplib.hpp new file mode 100644 index 000000000..03ef11a03 --- /dev/null +++ b/include/httplib.hpp @@ -0,0 +1,5123 @@ +// +// httplib.h +// +// Copyright (c) 2020 Yuji Hirose. All rights reserved. +// MIT License +// + +#ifndef CPPHTTPLIB_HTTPLIB_H +#define CPPHTTPLIB_HTTPLIB_H + +/* + * Configuration + */ + +#ifndef CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND +#define CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND 5 +#endif + +#ifndef CPPHTTPLIB_KEEPALIVE_TIMEOUT_USECOND +#define CPPHTTPLIB_KEEPALIVE_TIMEOUT_USECOND 0 +#endif + +#ifndef CPPHTTPLIB_KEEPALIVE_MAX_COUNT +#define CPPHTTPLIB_KEEPALIVE_MAX_COUNT 5 +#endif + +#ifndef CPPHTTPLIB_READ_TIMEOUT_SECOND +#define CPPHTTPLIB_READ_TIMEOUT_SECOND 5 +#endif + +#ifndef CPPHTTPLIB_READ_TIMEOUT_USECOND +#define CPPHTTPLIB_READ_TIMEOUT_USECOND 0 +#endif + +#ifndef CPPHTTPLIB_REQUEST_URI_MAX_LENGTH +#define CPPHTTPLIB_REQUEST_URI_MAX_LENGTH 8192 +#endif + +#ifndef CPPHTTPLIB_REDIRECT_MAX_COUNT +#define CPPHTTPLIB_REDIRECT_MAX_COUNT 20 +#endif + +#ifndef CPPHTTPLIB_PAYLOAD_MAX_LENGTH +#define CPPHTTPLIB_PAYLOAD_MAX_LENGTH ((std::numeric_limits::max)()) +#endif + +#ifndef CPPHTTPLIB_RECV_BUFSIZ +#define CPPHTTPLIB_RECV_BUFSIZ size_t(4096u) +#endif + +#ifndef CPPHTTPLIB_THREAD_POOL_COUNT +#define CPPHTTPLIB_THREAD_POOL_COUNT \ + ((std::max)(1u, std::thread::hardware_concurrency() - 1)) +#endif + +/* + * Headers + */ + +#ifdef _WIN32 +#ifndef _CRT_SECURE_NO_WARNINGS +#define _CRT_SECURE_NO_WARNINGS +#endif //_CRT_SECURE_NO_WARNINGS + +#ifndef _CRT_NONSTDC_NO_DEPRECATE +#define _CRT_NONSTDC_NO_DEPRECATE +#endif //_CRT_NONSTDC_NO_DEPRECATE + +#if defined(_MSC_VER) +#ifdef _WIN64 +using ssize_t = __int64; +#else +using ssize_t = int; +#endif + +#if _MSC_VER < 1900 +#define snprintf _snprintf_s +#endif +#endif // _MSC_VER + +#ifndef S_ISREG +#define S_ISREG(m) (((m)&S_IFREG) == S_IFREG) +#endif // S_ISREG + +#ifndef S_ISDIR +#define S_ISDIR(m) (((m)&S_IFDIR) == S_IFDIR) +#endif // S_ISDIR + +#ifndef NOMINMAX +#define NOMINMAX +#endif // NOMINMAX + +#include +#include +#include + +#ifndef WSA_FLAG_NO_HANDLE_INHERIT +#define WSA_FLAG_NO_HANDLE_INHERIT 0x80 +#endif + +#ifdef _MSC_VER +#pragma comment(lib, "ws2_32.lib") +#endif + +#ifndef strcasecmp +#define strcasecmp _stricmp +#endif // strcasecmp + +using socket_t = SOCKET; +#ifdef CPPHTTPLIB_USE_POLL +#define poll(fds, nfds, timeout) WSAPoll(fds, nfds, timeout) +#endif + +#else // not _WIN32 + +#include +#include +#include +#include +#include +#ifdef CPPHTTPLIB_USE_POLL +#include +#endif +#include +#include +#include +#include +#include + +using socket_t = int; +#define INVALID_SOCKET (-1) +#endif //_WIN32 + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +#include +#include +#include +#include + +#include +#include +#include + +// #if OPENSSL_VERSION_NUMBER < 0x1010100fL +// #error Sorry, OpenSSL versions prior to 1.1.1 are not supported +// #endif + +#if OPENSSL_VERSION_NUMBER < 0x10100000L +#include +inline const unsigned char *ASN1_STRING_get0_data(const ASN1_STRING *asn1) { + return M_ASN1_STRING_data(asn1); +} +#endif +#endif + +#ifdef CPPHTTPLIB_ZLIB_SUPPORT +#include +#endif +/* + * Declaration + */ +namespace httplib { + +namespace detail { + +struct ci { + bool operator()(const std::string &s1, const std::string &s2) const { + return std::lexicographical_compare( + s1.begin(), s1.end(), s2.begin(), s2.end(), + [](char c1, char c2) { return ::tolower(c1) < ::tolower(c2); }); + } +}; + +} // namespace detail + +using Headers = std::multimap; + +using Params = std::multimap; +using Match = std::smatch; + +using Progress = std::function; + +struct Response; +using ResponseHandler = std::function; + +struct MultipartFormData { + std::string name; + std::string content; + std::string filename; + std::string content_type; +}; +using MultipartFormDataItems = std::vector; +using MultipartFormDataMap = std::multimap; + +class DataSink { +public: + DataSink() = default; + DataSink(const DataSink &) = delete; + DataSink &operator=(const DataSink &) = delete; + DataSink(DataSink &&) = delete; + DataSink &operator=(DataSink &&) = delete; + + std::function write; + std::function done; + std::function is_writable; +}; + +using ContentProvider = + std::function; + +using ContentReceiver = + std::function; + +using MultipartContentHeader = + std::function; + +class ContentReader { +public: + using Reader = std::function; + using MultipartReader = std::function; + + ContentReader(Reader reader, MultipartReader muitlpart_reader) + : reader_(reader), muitlpart_reader_(muitlpart_reader) {} + + bool operator()(MultipartContentHeader header, + ContentReceiver receiver) const { + return muitlpart_reader_(header, receiver); + } + + bool operator()(ContentReceiver receiver) const { return reader_(receiver); } + + Reader reader_; + MultipartReader muitlpart_reader_; +}; + +using Range = std::pair; +using Ranges = std::vector; + +struct Request { + std::string method; + std::string path; + Headers headers; + std::string body; + + std::string remote_addr; + int remote_port = -1; + + // for server + std::string version; + std::string target; + Params params; + MultipartFormDataMap files; + Ranges ranges; + Match matches; + + // for client + size_t redirect_count = CPPHTTPLIB_REDIRECT_MAX_COUNT; + ResponseHandler response_handler; + ContentReceiver content_receiver; + Progress progress; + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + const SSL *ssl; +#endif + + bool has_header(const char *key) const; + std::string get_header_value(const char *key, size_t id = 0) const; + size_t get_header_value_count(const char *key) const; + void set_header(const char *key, const char *val); + void set_header(const char *key, const std::string &val); + + bool has_param(const char *key) const; + std::string get_param_value(const char *key, size_t id = 0) const; + size_t get_param_value_count(const char *key) const; + + bool is_multipart_form_data() const; + + bool has_file(const char *key) const; + MultipartFormData get_file_value(const char *key) const; + + // private members... + size_t content_length; + ContentProvider content_provider; +}; + +struct Response { + std::string version; + int status = -1; + Headers headers; + std::string body; + + bool has_header(const char *key) const; + std::string get_header_value(const char *key, size_t id = 0) const; + size_t get_header_value_count(const char *key) const; + void set_header(const char *key, const char *val); + void set_header(const char *key, const std::string &val); + + void set_redirect(const char *url, int status = 302); + void set_content(const char *s, size_t n, const char *content_type); + void set_content(std::string s, const char *content_type); + + void set_content_provider( + size_t length, + std::function + provider, + std::function resource_releaser = [] {}); + + void set_chunked_content_provider( + std::function provider, + std::function resource_releaser = [] {}); + + Response() = default; + Response(const Response &) = default; + Response &operator=(const Response &) = default; + Response(Response &&) = default; + Response &operator=(Response &&) = default; + ~Response() { + if (content_provider_resource_releaser) { + content_provider_resource_releaser(); + } + } + + // private members... + size_t content_length = 0; + ContentProvider content_provider; + std::function content_provider_resource_releaser; +}; + +class Stream { +public: + virtual ~Stream() = default; + + virtual bool is_readable() const = 0; + virtual bool is_writable() const = 0; + + virtual ssize_t read(char *ptr, size_t size) = 0; + virtual ssize_t write(const char *ptr, size_t size) = 0; + virtual void get_remote_ip_and_port(std::string &ip, int &port) const = 0; + + template + ssize_t write_format(const char *fmt, const Args &... args); + ssize_t write(const char *ptr); + ssize_t write(const std::string &s); +}; + +class TaskQueue { +public: + TaskQueue() = default; + virtual ~TaskQueue() = default; + + virtual void enqueue(std::function fn) = 0; + virtual void shutdown() = 0; + + virtual void on_idle(){}; +}; + +class ThreadPool : public TaskQueue { +public: + explicit ThreadPool(size_t n) : shutdown_(false) { + while (n) { + threads_.emplace_back(worker(*this)); + n--; + } + } + + ThreadPool(const ThreadPool &) = delete; + ~ThreadPool() override = default; + + void enqueue(std::function fn) override { + std::unique_lock lock(mutex_); + jobs_.push_back(fn); + cond_.notify_one(); + } + + void shutdown() override { + // Stop all worker threads... + { + std::unique_lock lock(mutex_); + shutdown_ = true; + } + + cond_.notify_all(); + + // Join... + for (auto &t : threads_) { + t.join(); + } + } + +private: + struct worker { + explicit worker(ThreadPool &pool) : pool_(pool) {} + + void operator()() { + for (;;) { + std::function fn; + { + std::unique_lock lock(pool_.mutex_); + + pool_.cond_.wait( + lock, [&] { return !pool_.jobs_.empty() || pool_.shutdown_; }); + + if (pool_.shutdown_ && pool_.jobs_.empty()) { break; } + + fn = pool_.jobs_.front(); + pool_.jobs_.pop_front(); + } + + assert(true == static_cast(fn)); + fn(); + } + } + + ThreadPool &pool_; + }; + friend struct worker; + + std::vector threads_; + std::list> jobs_; + + bool shutdown_; + + std::condition_variable cond_; + std::mutex mutex_; +}; + +using Logger = std::function; + +class Server { +public: + using Handler = std::function; + using HandlerWithContentReader = std::function; + using Expect100ContinueHandler = + std::function; + + Server(); + + virtual ~Server(); + + virtual bool is_valid() const; + + Server &Get(const char *pattern, Handler handler); + Server &Post(const char *pattern, Handler handler); + Server &Post(const char *pattern, HandlerWithContentReader handler); + Server &Put(const char *pattern, Handler handler); + Server &Put(const char *pattern, HandlerWithContentReader handler); + Server &Patch(const char *pattern, Handler handler); + Server &Patch(const char *pattern, HandlerWithContentReader handler); + Server &Delete(const char *pattern, Handler handler); + Server &Delete(const char *pattern, HandlerWithContentReader handler); + Server &Options(const char *pattern, Handler handler); + + [[deprecated]] bool set_base_dir(const char *dir, + const char *mount_point = nullptr); + bool set_mount_point(const char *mount_point, const char *dir); + bool remove_mount_point(const char *mount_point); + void set_file_extension_and_mimetype_mapping(const char *ext, + const char *mime); + void set_file_request_handler(Handler handler); + + void set_error_handler(Handler handler); + void set_logger(Logger logger); + + void set_expect_100_continue_handler(Expect100ContinueHandler handler); + + void set_keep_alive_max_count(size_t count); + void set_read_timeout(time_t sec, time_t usec); + void set_payload_max_length(size_t length); + + bool bind_to_port(const char *host, int port, int socket_flags = 0); + int bind_to_any_port(const char *host, int socket_flags = 0); + bool listen_after_bind(); + + bool listen(const char *host, int port, int socket_flags = 0); + + bool is_running() const; + void stop(); + + std::function new_task_queue; + +protected: + bool process_request(Stream &strm, bool last_connection, + bool &connection_close, + const std::function &setup_request); + + size_t keep_alive_max_count_; + time_t read_timeout_sec_; + time_t read_timeout_usec_; + size_t payload_max_length_; + +private: + using Handlers = std::vector>; + using HandlersForContentReader = + std::vector>; + + socket_t create_server_socket(const char *host, int port, + int socket_flags) const; + int bind_internal(const char *host, int port, int socket_flags); + bool listen_internal(); + + bool routing(Request &req, Response &res, Stream &strm); + bool handle_file_request(Request &req, Response &res, bool head = false); + bool dispatch_request(Request &req, Response &res, Handlers &handlers); + bool dispatch_request_for_content_reader(Request &req, Response &res, + ContentReader content_reader, + HandlersForContentReader &handlers); + + bool parse_request_line(const char *s, Request &req); + bool write_response(Stream &strm, bool last_connection, const Request &req, + Response &res); + bool write_content_with_provider(Stream &strm, const Request &req, + Response &res, const std::string &boundary, + const std::string &content_type); + bool read_content(Stream &strm, Request &req, Response &res); + bool + read_content_with_content_receiver(Stream &strm, Request &req, Response &res, + ContentReceiver receiver, + MultipartContentHeader multipart_header, + ContentReceiver multipart_receiver); + bool read_content_core(Stream &strm, Request &req, Response &res, + ContentReceiver receiver, + MultipartContentHeader mulitpart_header, + ContentReceiver multipart_receiver); + + virtual bool process_and_close_socket(socket_t sock); + + std::atomic is_running_; + std::atomic svr_sock_; + std::vector> base_dirs_; + std::map file_extension_and_mimetype_map_; + Handler file_request_handler_; + Handlers get_handlers_; + Handlers post_handlers_; + HandlersForContentReader post_handlers_for_content_reader_; + Handlers put_handlers_; + HandlersForContentReader put_handlers_for_content_reader_; + Handlers patch_handlers_; + HandlersForContentReader patch_handlers_for_content_reader_; + Handlers delete_handlers_; + HandlersForContentReader delete_handlers_for_content_reader_; + Handlers options_handlers_; + Handler error_handler_; + Logger logger_; + Expect100ContinueHandler expect_100_continue_handler_; +}; + +class Client { +public: + explicit Client(const std::string &host, int port = 80, + const std::string &client_cert_path = std::string(), + const std::string &client_key_path = std::string()); + + virtual ~Client(); + + virtual bool is_valid() const; + + std::shared_ptr Get(const char *path); + + std::shared_ptr Get(const char *path, const Headers &headers); + + std::shared_ptr Get(const char *path, Progress progress); + + std::shared_ptr Get(const char *path, const Headers &headers, + Progress progress); + + std::shared_ptr Get(const char *path, + ContentReceiver content_receiver); + + std::shared_ptr Get(const char *path, const Headers &headers, + ContentReceiver content_receiver); + + std::shared_ptr + Get(const char *path, ContentReceiver content_receiver, Progress progress); + + std::shared_ptr Get(const char *path, const Headers &headers, + ContentReceiver content_receiver, + Progress progress); + + std::shared_ptr Get(const char *path, const Headers &headers, + ResponseHandler response_handler, + ContentReceiver content_receiver); + + std::shared_ptr Get(const char *path, const Headers &headers, + ResponseHandler response_handler, + ContentReceiver content_receiver, + Progress progress); + + std::shared_ptr Head(const char *path); + + std::shared_ptr Head(const char *path, const Headers &headers); + + std::shared_ptr Post(const char *path); + + std::shared_ptr Post(const char *path, const std::string &body, + const char *content_type); + + std::shared_ptr Post(const char *path, const Headers &headers, + const std::string &body, + const char *content_type); + + std::shared_ptr Post(const char *path, size_t content_length, + ContentProvider content_provider, + const char *content_type); + + std::shared_ptr Post(const char *path, const Headers &headers, + size_t content_length, + ContentProvider content_provider, + const char *content_type); + + std::shared_ptr Post(const char *path, const Params ¶ms); + + std::shared_ptr Post(const char *path, const Headers &headers, + const Params ¶ms); + + std::shared_ptr Post(const char *path, + const MultipartFormDataItems &items); + + std::shared_ptr Post(const char *path, const Headers &headers, + const MultipartFormDataItems &items); + + std::shared_ptr Put(const char *path); + + std::shared_ptr Put(const char *path, const std::string &body, + const char *content_type); + + std::shared_ptr Put(const char *path, const Headers &headers, + const std::string &body, + const char *content_type); + + std::shared_ptr Put(const char *path, size_t content_length, + ContentProvider content_provider, + const char *content_type); + + std::shared_ptr Put(const char *path, const Headers &headers, + size_t content_length, + ContentProvider content_provider, + const char *content_type); + + std::shared_ptr Put(const char *path, const Params ¶ms); + + std::shared_ptr Put(const char *path, const Headers &headers, + const Params ¶ms); + + std::shared_ptr Patch(const char *path, const std::string &body, + const char *content_type); + + std::shared_ptr Patch(const char *path, const Headers &headers, + const std::string &body, + const char *content_type); + + std::shared_ptr Patch(const char *path, size_t content_length, + ContentProvider content_provider, + const char *content_type); + + std::shared_ptr Patch(const char *path, const Headers &headers, + size_t content_length, + ContentProvider content_provider, + const char *content_type); + + std::shared_ptr Delete(const char *path); + + std::shared_ptr Delete(const char *path, const std::string &body, + const char *content_type); + + std::shared_ptr Delete(const char *path, const Headers &headers); + + std::shared_ptr Delete(const char *path, const Headers &headers, + const std::string &body, + const char *content_type); + + std::shared_ptr Options(const char *path); + + std::shared_ptr Options(const char *path, const Headers &headers); + + bool send(const Request &req, Response &res); + + bool send(const std::vector &requests, + std::vector &responses); + + void stop(); + + void set_timeout_sec(time_t timeout_sec); + + void set_read_timeout(time_t sec, time_t usec); + + void set_keep_alive_max_count(size_t count); + + void set_basic_auth(const char *username, const char *password); + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + void set_digest_auth(const char *username, const char *password); +#endif + + void set_follow_location(bool on); + + void set_compress(bool on); + + void set_interface(const char *intf); + + void set_proxy(const char *host, int port); + + void set_proxy_basic_auth(const char *username, const char *password); + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + void set_proxy_digest_auth(const char *username, const char *password); +#endif + + void set_logger(Logger logger); + +protected: + bool process_request(Stream &strm, const Request &req, Response &res, + bool last_connection, bool &connection_close); + + std::atomic sock_; + + const std::string host_; + const int port_; + const std::string host_and_port_; + + // Settings + std::string client_cert_path_; + std::string client_key_path_; + + time_t timeout_sec_ = 300; + time_t read_timeout_sec_ = CPPHTTPLIB_READ_TIMEOUT_SECOND; + time_t read_timeout_usec_ = CPPHTTPLIB_READ_TIMEOUT_USECOND; + + size_t keep_alive_max_count_ = CPPHTTPLIB_KEEPALIVE_MAX_COUNT; + + std::string basic_auth_username_; + std::string basic_auth_password_; +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + std::string digest_auth_username_; + std::string digest_auth_password_; +#endif + + bool follow_location_ = false; + + bool compress_ = false; + + std::string interface_; + + std::string proxy_host_; + int proxy_port_; + + std::string proxy_basic_auth_username_; + std::string proxy_basic_auth_password_; +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + std::string proxy_digest_auth_username_; + std::string proxy_digest_auth_password_; +#endif + + Logger logger_; + + void copy_settings(const Client &rhs) { + client_cert_path_ = rhs.client_cert_path_; + client_key_path_ = rhs.client_key_path_; + timeout_sec_ = rhs.timeout_sec_; + read_timeout_sec_ = rhs.read_timeout_sec_; + read_timeout_usec_ = rhs.read_timeout_usec_; + keep_alive_max_count_ = rhs.keep_alive_max_count_; + basic_auth_username_ = rhs.basic_auth_username_; + basic_auth_password_ = rhs.basic_auth_password_; +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + digest_auth_username_ = rhs.digest_auth_username_; + digest_auth_password_ = rhs.digest_auth_password_; +#endif + follow_location_ = rhs.follow_location_; + compress_ = rhs.compress_; + interface_ = rhs.interface_; + proxy_host_ = rhs.proxy_host_; + proxy_port_ = rhs.proxy_port_; + proxy_basic_auth_username_ = rhs.proxy_basic_auth_username_; + proxy_basic_auth_password_ = rhs.proxy_basic_auth_password_; +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + proxy_digest_auth_username_ = rhs.proxy_digest_auth_username_; + proxy_digest_auth_password_ = rhs.proxy_digest_auth_password_; +#endif + logger_ = rhs.logger_; + } + +private: + socket_t create_client_socket() const; + bool read_response_line(Stream &strm, Response &res); + bool write_request(Stream &strm, const Request &req, bool last_connection); + bool redirect(const Request &req, Response &res); + bool handle_request(Stream &strm, const Request &req, Response &res, + bool last_connection, bool &connection_close); +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + bool connect(socket_t sock, Response &res, bool &error); +#endif + + std::shared_ptr send_with_content_provider( + const char *method, const char *path, const Headers &headers, + const std::string &body, size_t content_length, + ContentProvider content_provider, const char *content_type); + + virtual bool process_and_close_socket( + socket_t sock, size_t request_count, + std::function + callback); + + virtual bool is_ssl() const; +}; + +inline void Get(std::vector &requests, const char *path, + const Headers &headers) { + Request req; + req.method = "GET"; + req.path = path; + req.headers = headers; + requests.emplace_back(std::move(req)); +} + +inline void Get(std::vector &requests, const char *path) { + Get(requests, path, Headers()); +} + +inline void Post(std::vector &requests, const char *path, + const Headers &headers, const std::string &body, + const char *content_type) { + Request req; + req.method = "POST"; + req.path = path; + req.headers = headers; + if (content_type) { req.headers.emplace("Content-Type", content_type); } + req.body = body; + requests.emplace_back(std::move(req)); +} + +inline void Post(std::vector &requests, const char *path, + const std::string &body, const char *content_type) { + Post(requests, path, Headers(), body, content_type); +} + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +class SSLServer : public Server { +public: + SSLServer(const char *cert_path, const char *private_key_path, + const char *client_ca_cert_file_path = nullptr, + const char *client_ca_cert_dir_path = nullptr); + + SSLServer(X509 *cert, EVP_PKEY *private_key, + X509_STORE *client_ca_cert_store = nullptr); + + ~SSLServer() override; + + bool is_valid() const override; + +private: + bool process_and_close_socket(socket_t sock) override; + + SSL_CTX *ctx_; + std::mutex ctx_mutex_; +}; + +class SSLClient : public Client { +public: + explicit SSLClient(const std::string &host, int port = 443, + const std::string &client_cert_path = std::string(), + const std::string &client_key_path = std::string()); + + SSLClient(const std::string &host, int port, X509 *client_cert, + EVP_PKEY *client_key); + + ~SSLClient() override; + + bool is_valid() const override; + + void set_ca_cert_path(const char *ca_ceert_file_path, + const char *ca_cert_dir_path = nullptr); + + void set_ca_cert_store(X509_STORE *ca_cert_store); + + void enable_server_certificate_verification(bool enabled); + + long get_openssl_verify_result() const; + + SSL_CTX *ssl_context() const; + +private: + bool process_and_close_socket( + socket_t sock, size_t request_count, + std::function + callback) override; + bool is_ssl() const override; + + bool verify_host(X509 *server_cert) const; + bool verify_host_with_subject_alt_name(X509 *server_cert) const; + bool verify_host_with_common_name(X509 *server_cert) const; + bool check_host_name(const char *pattern, size_t pattern_len) const; + + SSL_CTX *ctx_; + std::mutex ctx_mutex_; + std::vector host_components_; + + std::string ca_cert_file_path_; + std::string ca_cert_dir_path_; + X509_STORE *ca_cert_store_ = nullptr; + bool server_certificate_verification_ = false; + long verify_result_ = 0; +}; +#endif + +// ---------------------------------------------------------------------------- + +/* + * Implementation + */ + +namespace detail { + +inline bool is_hex(char c, int &v) { + if (0x20 <= c && isdigit(c)) { + v = c - '0'; + return true; + } else if ('A' <= c && c <= 'F') { + v = c - 'A' + 10; + return true; + } else if ('a' <= c && c <= 'f') { + v = c - 'a' + 10; + return true; + } + return false; +} + +inline bool from_hex_to_i(const std::string &s, size_t i, size_t cnt, + int &val) { + if (i >= s.size()) { return false; } + + val = 0; + for (; cnt; i++, cnt--) { + if (!s[i]) { return false; } + int v = 0; + if (is_hex(s[i], v)) { + val = val * 16 + v; + } else { + return false; + } + } + return true; +} + +inline std::string from_i_to_hex(size_t n) { + const char *charset = "0123456789abcdef"; + std::string ret; + do { + ret = charset[n & 15] + ret; + n >>= 4; + } while (n > 0); + return ret; +} + +inline size_t to_utf8(int code, char *buff) { + if (code < 0x0080) { + buff[0] = (code & 0x7F); + return 1; + } else if (code < 0x0800) { + buff[0] = static_cast(0xC0 | ((code >> 6) & 0x1F)); + buff[1] = static_cast(0x80 | (code & 0x3F)); + return 2; + } else if (code < 0xD800) { + buff[0] = static_cast(0xE0 | ((code >> 12) & 0xF)); + buff[1] = static_cast(0x80 | ((code >> 6) & 0x3F)); + buff[2] = static_cast(0x80 | (code & 0x3F)); + return 3; + } else if (code < 0xE000) { // D800 - DFFF is invalid... + return 0; + } else if (code < 0x10000) { + buff[0] = static_cast(0xE0 | ((code >> 12) & 0xF)); + buff[1] = static_cast(0x80 | ((code >> 6) & 0x3F)); + buff[2] = static_cast(0x80 | (code & 0x3F)); + return 3; + } else if (code < 0x110000) { + buff[0] = static_cast(0xF0 | ((code >> 18) & 0x7)); + buff[1] = static_cast(0x80 | ((code >> 12) & 0x3F)); + buff[2] = static_cast(0x80 | ((code >> 6) & 0x3F)); + buff[3] = static_cast(0x80 | (code & 0x3F)); + return 4; + } + + // NOTREACHED + return 0; +} + +// NOTE: This code came up with the following stackoverflow post: +// https://stackoverflow.com/questions/180947/base64-decode-snippet-in-c +inline std::string base64_encode(const std::string &in) { + static const auto lookup = + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; + + std::string out; + out.reserve(in.size()); + + int val = 0; + int valb = -6; + + for (auto c : in) { + val = (val << 8) + static_cast(c); + valb += 8; + while (valb >= 0) { + out.push_back(lookup[(val >> valb) & 0x3F]); + valb -= 6; + } + } + + if (valb > -6) { out.push_back(lookup[((val << 8) >> (valb + 8)) & 0x3F]); } + + while (out.size() % 4) { + out.push_back('='); + } + + return out; +} + +inline bool is_file(const std::string &path) { + struct stat st; + return stat(path.c_str(), &st) >= 0 && S_ISREG(st.st_mode); +} + +inline bool is_dir(const std::string &path) { + struct stat st; + return stat(path.c_str(), &st) >= 0 && S_ISDIR(st.st_mode); +} + +inline bool is_valid_path(const std::string &path) { + size_t level = 0; + size_t i = 0; + + // Skip slash + while (i < path.size() && path[i] == '/') { + i++; + } + + while (i < path.size()) { + // Read component + auto beg = i; + while (i < path.size() && path[i] != '/') { + i++; + } + + auto len = i - beg; + assert(len > 0); + + if (!path.compare(beg, len, ".")) { + ; + } else if (!path.compare(beg, len, "..")) { + if (level == 0) { return false; } + level--; + } else { + level++; + } + + // Skip slash + while (i < path.size() && path[i] == '/') { + i++; + } + } + + return true; +} + +inline void read_file(const std::string &path, std::string &out) { + std::ifstream fs(path, std::ios_base::binary); + fs.seekg(0, std::ios_base::end); + auto size = fs.tellg(); + fs.seekg(0); + out.resize(static_cast(size)); + fs.read(&out[0], size); +} + +inline std::string file_extension(const std::string &path) { + std::smatch m; + static auto re = std::regex("\\.([a-zA-Z0-9]+)$"); + if (std::regex_search(path, m, re)) { return m[1].str(); } + return std::string(); +} + +template void split(const char *b, const char *e, char d, Fn fn) { + int i = 0; + int beg = 0; + + while (e ? (b + i != e) : (b[i] != '\0')) { + if (b[i] == d) { + fn(&b[beg], &b[i]); + beg = i + 1; + } + i++; + } + + if (i) { fn(&b[beg], &b[i]); } +} + +// NOTE: until the read size reaches `fixed_buffer_size`, use `fixed_buffer` +// to store data. The call can set memory on stack for performance. +class stream_line_reader { +public: + stream_line_reader(Stream &strm, char *fixed_buffer, size_t fixed_buffer_size) + : strm_(strm), fixed_buffer_(fixed_buffer), + fixed_buffer_size_(fixed_buffer_size) {} + + const char *ptr() const { + if (glowable_buffer_.empty()) { + return fixed_buffer_; + } else { + return glowable_buffer_.data(); + } + } + + size_t size() const { + if (glowable_buffer_.empty()) { + return fixed_buffer_used_size_; + } else { + return glowable_buffer_.size(); + } + } + + bool end_with_crlf() const { + auto end = ptr() + size(); + return size() >= 2 && end[-2] == '\r' && end[-1] == '\n'; + } + + bool getline() { + fixed_buffer_used_size_ = 0; + glowable_buffer_.clear(); + + for (size_t i = 0;; i++) { + char byte; + auto n = strm_.read(&byte, 1); + + if (n < 0) { + return false; + } else if (n == 0) { + if (i == 0) { + return false; + } else { + break; + } + } + + append(byte); + + if (byte == '\n') { break; } + } + + return true; + } + +private: + void append(char c) { + if (fixed_buffer_used_size_ < fixed_buffer_size_ - 1) { + fixed_buffer_[fixed_buffer_used_size_++] = c; + fixed_buffer_[fixed_buffer_used_size_] = '\0'; + } else { + if (glowable_buffer_.empty()) { + assert(fixed_buffer_[fixed_buffer_used_size_] == '\0'); + glowable_buffer_.assign(fixed_buffer_, fixed_buffer_used_size_); + } + glowable_buffer_ += c; + } + } + + Stream &strm_; + char *fixed_buffer_; + const size_t fixed_buffer_size_; + size_t fixed_buffer_used_size_ = 0; + std::string glowable_buffer_; +}; + +inline int close_socket(socket_t sock) { +#ifdef _WIN32 + return closesocket(sock); +#else + return close(sock); +#endif +} + +inline int select_read(socket_t sock, time_t sec, time_t usec) { +#ifdef CPPHTTPLIB_USE_POLL + struct pollfd pfd_read; + pfd_read.fd = sock; + pfd_read.events = POLLIN; + + auto timeout = static_cast(sec * 1000 + usec / 1000); + + return HANDLE_EINTR(poll, &pfd_read, 1, timeout); +#else + fd_set fds; + FD_ZERO(&fds); + FD_SET(sock, &fds); + + timeval tv; + tv.tv_sec = static_cast(sec); + tv.tv_usec = static_cast(usec); + + return select(static_cast(sock + 1), &fds, nullptr, nullptr, &tv); +#endif +} + +inline int select_write(socket_t sock, time_t sec, time_t usec) { +#ifdef CPPHTTPLIB_USE_POLL + struct pollfd pfd_read; + pfd_read.fd = sock; + pfd_read.events = POLLOUT; + + auto timeout = static_cast(sec * 1000 + usec / 1000); + + return HANDLE_EINTR(poll, &pfd_read, 1, timeout); +#else + fd_set fds; + FD_ZERO(&fds); + FD_SET(sock, &fds); + + timeval tv; + tv.tv_sec = static_cast(sec); + tv.tv_usec = static_cast(usec); + + return select(static_cast(sock + 1), nullptr, &fds, nullptr, &tv); +#endif +} + +inline bool wait_until_socket_is_ready(socket_t sock, time_t sec, time_t usec) { +#ifdef CPPHTTPLIB_USE_POLL + struct pollfd pfd_read; + pfd_read.fd = sock; + pfd_read.events = POLLIN | POLLOUT; + + auto timeout = static_cast(sec * 1000 + usec / 1000); + + int poll_res = HANDLE_EINTR(poll, &pfd_read, 1, timeout); + if (poll_res > 0 && pfd_read.revents & (POLLIN | POLLOUT)) { + int error = 0; + socklen_t len = sizeof(error); + auto res = getsockopt(sock, SOL_SOCKET, SO_ERROR, + reinterpret_cast(&error), &len); + return res >= 0 && !error; + } + return false; +#else + fd_set fdsr; + FD_ZERO(&fdsr); + FD_SET(sock, &fdsr); + + auto fdsw = fdsr; + auto fdse = fdsr; + + timeval tv; + tv.tv_sec = static_cast(sec); + tv.tv_usec = static_cast(usec); + + if (select(static_cast(sock + 1), &fdsr, &fdsw, &fdse, &tv) > 0 && + (FD_ISSET(sock, &fdsr) || FD_ISSET(sock, &fdsw))) { + int error = 0; + socklen_t len = sizeof(error); + return getsockopt(sock, SOL_SOCKET, SO_ERROR, + reinterpret_cast(&error), &len) >= 0 && + !error; + } + return false; +#endif +} + +class SocketStream : public Stream { +public: + SocketStream(socket_t sock, time_t read_timeout_sec, + time_t read_timeout_usec); + ~SocketStream() override; + + bool is_readable() const override; + bool is_writable() const override; + ssize_t read(char *ptr, size_t size) override; + ssize_t write(const char *ptr, size_t size) override; + void get_remote_ip_and_port(std::string &ip, int &port) const override; + +private: + socket_t sock_; + time_t read_timeout_sec_; + time_t read_timeout_usec_; +}; + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +class SSLSocketStream : public Stream { +public: + SSLSocketStream(socket_t sock, SSL *ssl, time_t read_timeout_sec, + time_t read_timeout_usec); + ~SSLSocketStream() override; + + bool is_readable() const override; + bool is_writable() const override; + ssize_t read(char *ptr, size_t size) override; + ssize_t write(const char *ptr, size_t size) override; + void get_remote_ip_and_port(std::string &ip, int &port) const override; + +private: + socket_t sock_; + SSL *ssl_; + time_t read_timeout_sec_; + time_t read_timeout_usec_; +}; +#endif + +class BufferStream : public Stream { +public: + BufferStream() = default; + ~BufferStream() override = default; + + bool is_readable() const override; + bool is_writable() const override; + ssize_t read(char *ptr, size_t size) override; + ssize_t write(const char *ptr, size_t size) override; + void get_remote_ip_and_port(std::string &ip, int &port) const override; + + const std::string &get_buffer() const; + +private: + std::string buffer; + size_t position = 0; +}; + +template +inline bool process_socket(bool is_client_request, socket_t sock, + size_t keep_alive_max_count, time_t read_timeout_sec, + time_t read_timeout_usec, T callback) { + assert(keep_alive_max_count > 0); + + auto ret = false; + + if (keep_alive_max_count > 1) { + auto count = keep_alive_max_count; + while (count > 0 && + (is_client_request || + select_read(sock, CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND, + CPPHTTPLIB_KEEPALIVE_TIMEOUT_USECOND) > 0)) { + SocketStream strm(sock, read_timeout_sec, read_timeout_usec); + auto last_connection = count == 1; + auto connection_close = false; + + ret = callback(strm, last_connection, connection_close); + if (!ret || connection_close) { break; } + + count--; + } + } else { // keep_alive_max_count is 0 or 1 + SocketStream strm(sock, read_timeout_sec, read_timeout_usec); + auto dummy_connection_close = false; + ret = callback(strm, true, dummy_connection_close); + } + + return ret; +} + +template +inline bool process_and_close_socket(bool is_client_request, socket_t sock, + size_t keep_alive_max_count, + time_t read_timeout_sec, + time_t read_timeout_usec, T callback) { + auto ret = process_socket(is_client_request, sock, keep_alive_max_count, + read_timeout_sec, read_timeout_usec, callback); + close_socket(sock); + return ret; +} + +inline int shutdown_socket(socket_t sock) { +#ifdef _WIN32 + return shutdown(sock, SD_BOTH); +#else + return shutdown(sock, SHUT_RDWR); +#endif +} + +template +socket_t create_socket(const char *host, int port, Fn fn, + int socket_flags = 0) { +#ifdef _WIN32 +#define SO_SYNCHRONOUS_NONALERT 0x20 +#define SO_OPENTYPE 0x7008 + + int opt = SO_SYNCHRONOUS_NONALERT; + setsockopt(INVALID_SOCKET, SOL_SOCKET, SO_OPENTYPE, (char *)&opt, + sizeof(opt)); +#endif + + // Get address info + struct addrinfo hints; + struct addrinfo *result; + + memset(&hints, 0, sizeof(struct addrinfo)); + hints.ai_family = AF_UNSPEC; + hints.ai_socktype = SOCK_STREAM; + hints.ai_flags = socket_flags; + hints.ai_protocol = 0; + + auto service = std::to_string(port); + + if (getaddrinfo(host, service.c_str(), &hints, &result)) { + return INVALID_SOCKET; + } + + for (auto rp = result; rp; rp = rp->ai_next) { + // Create a socket +#ifdef _WIN32 + auto sock = WSASocketW(rp->ai_family, rp->ai_socktype, rp->ai_protocol, + nullptr, 0, WSA_FLAG_NO_HANDLE_INHERIT); + /** + * Since the WSA_FLAG_NO_HANDLE_INHERIT is only supported on Windows 7 SP1 + * and above the socket creation fails on older Windows Systems. + * + * Let's try to create a socket the old way in this case. + * + * Reference: + * https://docs.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-wsasocketa + * + * WSA_FLAG_NO_HANDLE_INHERIT: + * This flag is supported on Windows 7 with SP1, Windows Server 2008 R2 with + * SP1, and later + * + */ + if (sock == INVALID_SOCKET) { + sock = socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol); + } +#else + auto sock = socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol); +#endif + if (sock == INVALID_SOCKET) { continue; } + +#ifndef _WIN32 + if (fcntl(sock, F_SETFD, FD_CLOEXEC) == -1) { continue; } +#endif + + // Make 'reuse address' option available + int yes = 1; + setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast(&yes), + sizeof(yes)); +#ifdef SO_REUSEPORT + setsockopt(sock, SOL_SOCKET, SO_REUSEPORT, reinterpret_cast(&yes), + sizeof(yes)); +#endif + + // bind or connect + if (fn(sock, *rp)) { + freeaddrinfo(result); + return sock; + } + + close_socket(sock); + } + + freeaddrinfo(result); + return INVALID_SOCKET; +} + +inline void set_nonblocking(socket_t sock, bool nonblocking) { +#ifdef _WIN32 + auto flags = nonblocking ? 1UL : 0UL; + ioctlsocket(sock, FIONBIO, &flags); +#else + auto flags = fcntl(sock, F_GETFL, 0); + fcntl(sock, F_SETFL, + nonblocking ? (flags | O_NONBLOCK) : (flags & (~O_NONBLOCK))); +#endif +} + +inline bool is_connection_error() { +#ifdef _WIN32 + return WSAGetLastError() != WSAEWOULDBLOCK; +#else + return errno != EINPROGRESS; +#endif +} + +inline bool bind_ip_address(socket_t sock, const char *host) { + struct addrinfo hints; + struct addrinfo *result; + + memset(&hints, 0, sizeof(struct addrinfo)); + hints.ai_family = AF_UNSPEC; + hints.ai_socktype = SOCK_STREAM; + hints.ai_protocol = 0; + + if (getaddrinfo(host, "0", &hints, &result)) { return false; } + + auto ret = false; + for (auto rp = result; rp; rp = rp->ai_next) { + const auto &ai = *rp; + if (!::bind(sock, ai.ai_addr, static_cast(ai.ai_addrlen))) { + ret = true; + break; + } + } + + freeaddrinfo(result); + return ret; +} + +inline std::string if2ip(const std::string &ifn) { +#ifndef _WIN32 + struct ifaddrs *ifap; + getifaddrs(&ifap); + for (auto ifa = ifap; ifa; ifa = ifa->ifa_next) { + if (ifa->ifa_addr && ifn == ifa->ifa_name) { + if (ifa->ifa_addr->sa_family == AF_INET) { + auto sa = reinterpret_cast(ifa->ifa_addr); + char buf[INET_ADDRSTRLEN]; + if (inet_ntop(AF_INET, &sa->sin_addr, buf, INET_ADDRSTRLEN)) { + freeifaddrs(ifap); + return std::string(buf, INET_ADDRSTRLEN); + } + } + } + } + freeifaddrs(ifap); +#endif + return std::string(); +} + +inline socket_t create_client_socket(const char *host, int port, + time_t timeout_sec, + const std::string &intf) { + return create_socket( + host, port, [&](socket_t sock, struct addrinfo &ai) -> bool { + if (!intf.empty()) { + auto ip = if2ip(intf); + if (ip.empty()) { ip = intf; } + if (!bind_ip_address(sock, ip.c_str())) { return false; } + } + + set_nonblocking(sock, true); + + auto ret = + ::connect(sock, ai.ai_addr, static_cast(ai.ai_addrlen)); + if (ret < 0) { + if (is_connection_error() || + !wait_until_socket_is_ready(sock, timeout_sec, 0)) { + close_socket(sock); + return false; + } + } + + set_nonblocking(sock, false); + return true; + }); +} + +inline void get_remote_ip_and_port(const struct sockaddr_storage &addr, + socklen_t addr_len, std::string &ip, + int &port) { + if (addr.ss_family == AF_INET) { + port = ntohs(reinterpret_cast(&addr)->sin_port); + } else if (addr.ss_family == AF_INET6) { + port = + ntohs(reinterpret_cast(&addr)->sin6_port); + } + + std::array ipstr{}; + if (!getnameinfo(reinterpret_cast(&addr), addr_len, + ipstr.data(), static_cast(ipstr.size()), nullptr, + 0, NI_NUMERICHOST)) { + ip = ipstr.data(); + } +} + +inline void get_remote_ip_and_port(socket_t sock, std::string &ip, int &port) { + struct sockaddr_storage addr; + socklen_t addr_len = sizeof(addr); + + if (!getpeername(sock, reinterpret_cast(&addr), + &addr_len)) { + get_remote_ip_and_port(addr, addr_len, ip, port); + } +} + +inline const char * +find_content_type(const std::string &path, + const std::map &user_data) { + auto ext = file_extension(path); + + auto it = user_data.find(ext); + if (it != user_data.end()) { return it->second.c_str(); } + + if (ext == "txt") { + return "text/plain"; + } else if (ext == "html" || ext == "htm") { + return "text/html"; + } else if (ext == "css") { + return "text/css"; + } else if (ext == "jpeg" || ext == "jpg") { + return "image/jpg"; + } else if (ext == "png") { + return "image/png"; + } else if (ext == "gif") { + return "image/gif"; + } else if (ext == "svg") { + return "image/svg+xml"; + } else if (ext == "ico") { + return "image/x-icon"; + } else if (ext == "json") { + return "application/json"; + } else if (ext == "pdf") { + return "application/pdf"; + } else if (ext == "js") { + return "application/javascript"; + } else if (ext == "wasm") { + return "application/wasm"; + } else if (ext == "xml") { + return "application/xml"; + } else if (ext == "xhtml") { + return "application/xhtml+xml"; + } + return nullptr; +} + +inline const char *status_message(int status) { + switch (status) { + case 100: return "Continue"; + case 101: return "Switching Protocol"; + case 102: return "Processing"; + case 103: return "Early Hints"; + case 200: return "OK"; + case 201: return "Created"; + case 202: return "Accepted"; + case 203: return "Non-Authoritative Information"; + case 204: return "No Content"; + case 205: return "Reset Content"; + case 206: return "Partial Content"; + case 207: return "Multi-Status"; + case 208: return "Already Reported"; + case 226: return "IM Used"; + case 300: return "Multiple Choice"; + case 301: return "Moved Permanently"; + case 302: return "Found"; + case 303: return "See Other"; + case 304: return "Not Modified"; + case 305: return "Use Proxy"; + case 306: return "unused"; + case 307: return "Temporary Redirect"; + case 308: return "Permanent Redirect"; + case 400: return "Bad Request"; + case 401: return "Unauthorized"; + case 402: return "Payment Required"; + case 403: return "Forbidden"; + case 404: return "Not Found"; + case 405: return "Method Not Allowed"; + case 406: return "Not Acceptable"; + case 407: return "Proxy Authentication Required"; + case 408: return "Request Timeout"; + case 409: return "Conflict"; + case 410: return "Gone"; + case 411: return "Length Required"; + case 412: return "Precondition Failed"; + case 413: return "Payload Too Large"; + case 414: return "URI Too Long"; + case 415: return "Unsupported Media Type"; + case 416: return "Range Not Satisfiable"; + case 417: return "Expectation Failed"; + case 418: return "I'm a teapot"; + case 421: return "Misdirected Request"; + case 422: return "Unprocessable Entity"; + case 423: return "Locked"; + case 424: return "Failed Dependency"; + case 425: return "Too Early"; + case 426: return "Upgrade Required"; + case 428: return "Precondition Required"; + case 429: return "Too Many Requests"; + case 431: return "Request Header Fields Too Large"; + case 451: return "Unavailable For Legal Reasons"; + case 501: return "Not Implemented"; + case 502: return "Bad Gateway"; + case 503: return "Service Unavailable"; + case 504: return "Gateway Timeout"; + case 505: return "HTTP Version Not Supported"; + case 506: return "Variant Also Negotiates"; + case 507: return "Insufficient Storage"; + case 508: return "Loop Detected"; + case 510: return "Not Extended"; + case 511: return "Network Authentication Required"; + + default: + case 500: return "Internal Server Error"; + } +} + +#ifdef CPPHTTPLIB_ZLIB_SUPPORT +inline bool can_compress(const std::string &content_type) { + return !content_type.find("text/") || content_type == "image/svg+xml" || + content_type == "application/javascript" || + content_type == "application/json" || + content_type == "application/xml" || + content_type == "application/xhtml+xml"; +} + +inline bool compress(std::string &content) { + z_stream strm; + strm.zalloc = Z_NULL; + strm.zfree = Z_NULL; + strm.opaque = Z_NULL; + + auto ret = deflateInit2(&strm, Z_DEFAULT_COMPRESSION, Z_DEFLATED, 31, 8, + Z_DEFAULT_STRATEGY); + if (ret != Z_OK) { return false; } + + strm.avail_in = static_cast(content.size()); + strm.next_in = + const_cast(reinterpret_cast(content.data())); + + std::string compressed; + + std::array buff{}; + do { + strm.avail_out = buff.size(); + strm.next_out = reinterpret_cast(buff.data()); + ret = deflate(&strm, Z_FINISH); + assert(ret != Z_STREAM_ERROR); + compressed.append(buff.data(), buff.size() - strm.avail_out); + } while (strm.avail_out == 0); + + assert(ret == Z_STREAM_END); + assert(strm.avail_in == 0); + + content.swap(compressed); + + deflateEnd(&strm); + return true; +} + +class decompressor { +public: + decompressor() { + std::memset(&strm, 0, sizeof(strm)); + strm.zalloc = Z_NULL; + strm.zfree = Z_NULL; + strm.opaque = Z_NULL; + + // 15 is the value of wbits, which should be at the maximum possible value + // to ensure that any gzip stream can be decoded. The offset of 32 specifies + // that the stream type should be automatically detected either gzip or + // deflate. + is_valid_ = inflateInit2(&strm, 32 + 15) == Z_OK; + } + + ~decompressor() { inflateEnd(&strm); } + + bool is_valid() const { return is_valid_; } + + template + bool decompress(const char *data, size_t data_length, T callback) { + int ret = Z_OK; + + strm.avail_in = static_cast(data_length); + strm.next_in = const_cast(reinterpret_cast(data)); + + std::array buff{}; + do { + strm.avail_out = buff.size(); + strm.next_out = reinterpret_cast(buff.data()); + + ret = inflate(&strm, Z_NO_FLUSH); + assert(ret != Z_STREAM_ERROR); + switch (ret) { + case Z_NEED_DICT: + case Z_DATA_ERROR: + case Z_MEM_ERROR: inflateEnd(&strm); return false; + } + + if (!callback(buff.data(), buff.size() - strm.avail_out)) { + return false; + } + } while (strm.avail_out == 0); + + return ret == Z_OK || ret == Z_STREAM_END; + } + +private: + bool is_valid_; + z_stream strm; +}; +#endif + +inline bool has_header(const Headers &headers, const char *key) { + return headers.find(key) != headers.end(); +} + +inline const char *get_header_value(const Headers &headers, const char *key, + size_t id = 0, const char *def = nullptr) { + auto it = headers.find(key); + std::advance(it, static_cast(id)); + if (it != headers.end()) { return it->second.c_str(); } + return def; +} + +inline uint64_t get_header_value_uint64(const Headers &headers, const char *key, + uint64_t def = 0) { + auto it = headers.find(key); + if (it != headers.end()) { + return std::strtoull(it->second.data(), nullptr, 10); + } + return def; +} + +inline bool read_headers(Stream &strm, Headers &headers) { + const auto bufsiz = 2048; + char buf[bufsiz]; + stream_line_reader line_reader(strm, buf, bufsiz); + + for (;;) { + if (!line_reader.getline()) { return false; } + + // Check if the line ends with CRLF. + if (line_reader.end_with_crlf()) { + // Blank line indicates end of headers. + if (line_reader.size() == 2) { break; } + } else { + continue; // Skip invalid line. + } + + // Skip trailing spaces and tabs. + auto end = line_reader.ptr() + line_reader.size() - 2; + while (line_reader.ptr() < end && (end[-1] == ' ' || end[-1] == '\t')) { + end--; + } + + // Horizontal tab and ' ' are considered whitespace and are ignored when on + // the left or right side of the header value: + // - https://stackoverflow.com/questions/50179659/ + // - https://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html + static const std::regex re(R"(([^:]+):[\t ]*(.+))"); + + std::cmatch m; + if (std::regex_match(line_reader.ptr(), end, m, re)) { + auto key = std::string(m[1]); + auto val = std::string(m[2]); + headers.emplace(key, val); + } + } + + return true; +} + +inline bool read_content_with_length(Stream &strm, uint64_t len, + Progress progress, ContentReceiver out) { + char buf[CPPHTTPLIB_RECV_BUFSIZ]; + + uint64_t r = 0; + while (r < len) { + auto read_len = static_cast(len - r); + auto n = strm.read(buf, (std::min)(read_len, CPPHTTPLIB_RECV_BUFSIZ)); + if (n <= 0) { return false; } + + if (!out(buf, static_cast(n))) { return false; } + + r += static_cast(n); + + if (progress) { + if (!progress(r, len)) { return false; } + } + } + + return true; +} + +inline void skip_content_with_length(Stream &strm, uint64_t len) { + char buf[CPPHTTPLIB_RECV_BUFSIZ]; + uint64_t r = 0; + while (r < len) { + auto read_len = static_cast(len - r); + auto n = strm.read(buf, (std::min)(read_len, CPPHTTPLIB_RECV_BUFSIZ)); + if (n <= 0) { return; } + r += static_cast(n); + } +} + +inline bool read_content_without_length(Stream &strm, ContentReceiver out) { + char buf[CPPHTTPLIB_RECV_BUFSIZ]; + for (;;) { + auto n = strm.read(buf, CPPHTTPLIB_RECV_BUFSIZ); + if (n < 0) { + return false; + } else if (n == 0) { + return true; + } + if (!out(buf, static_cast(n))) { return false; } + } + + return true; +} + +inline bool read_content_chunked(Stream &strm, ContentReceiver out) { + const auto bufsiz = 16; + char buf[bufsiz]; + + stream_line_reader line_reader(strm, buf, bufsiz); + + if (!line_reader.getline()) { return false; } + + unsigned long chunk_len; + while (true) { + char *end_ptr; + + chunk_len = std::strtoul(line_reader.ptr(), &end_ptr, 16); + + if (end_ptr == line_reader.ptr()) { return false; } + if (chunk_len == ULONG_MAX) { return false; } + + if (chunk_len == 0) { break; } + + if (!read_content_with_length(strm, chunk_len, nullptr, out)) { + return false; + } + + if (!line_reader.getline()) { return false; } + + if (strcmp(line_reader.ptr(), "\r\n")) { break; } + + if (!line_reader.getline()) { return false; } + } + + if (chunk_len == 0) { + // Reader terminator after chunks + if (!line_reader.getline() || strcmp(line_reader.ptr(), "\r\n")) + return false; + } + + return true; +} + +inline bool is_chunked_transfer_encoding(const Headers &headers) { + return !strcasecmp(get_header_value(headers, "Transfer-Encoding", 0, ""), + "chunked"); +} + +template +bool read_content(Stream &strm, T &x, size_t payload_max_length, int &status, + Progress progress, ContentReceiver receiver) { + + ContentReceiver out = [&](const char *buf, size_t n) { + return receiver(buf, n); + }; + +#ifdef CPPHTTPLIB_ZLIB_SUPPORT + decompressor decompressor; + + std::string content_encoding = x.get_header_value("Content-Encoding"); + if (content_encoding.find("gzip") != std::string::npos || + content_encoding.find("deflate") != std::string::npos) { + if (!decompressor.is_valid()) { + status = 500; + return false; + } + + out = [&](const char *buf, size_t n) { + return decompressor.decompress( + buf, n, [&](const char *buf, size_t n) { return receiver(buf, n); }); + }; + } +#else + if (x.get_header_value("Content-Encoding") == "gzip") { + status = 415; + return false; + } +#endif + + auto ret = true; + auto exceed_payload_max_length = false; + + if (is_chunked_transfer_encoding(x.headers)) { + ret = read_content_chunked(strm, out); + } else if (!has_header(x.headers, "Content-Length")) { + ret = read_content_without_length(strm, out); + } else { + auto len = get_header_value_uint64(x.headers, "Content-Length", 0); + if (len > payload_max_length) { + exceed_payload_max_length = true; + skip_content_with_length(strm, len); + ret = false; + } else if (len > 0) { + ret = read_content_with_length(strm, len, progress, out); + } + } + + if (!ret) { status = exceed_payload_max_length ? 413 : 400; } + + return ret; +} + +template +inline ssize_t write_headers(Stream &strm, const T &info, + const Headers &headers) { + ssize_t write_len = 0; + for (const auto &x : info.headers) { + if (x.first == "EXCEPTION_WHAT") { continue; } + auto len = + strm.write_format("%s: %s\r\n", x.first.c_str(), x.second.c_str()); + if (len < 0) { return len; } + write_len += len; + } + for (const auto &x : headers) { + auto len = + strm.write_format("%s: %s\r\n", x.first.c_str(), x.second.c_str()); + if (len < 0) { return len; } + write_len += len; + } + auto len = strm.write("\r\n"); + if (len < 0) { return len; } + write_len += len; + return write_len; +} + +inline ssize_t write_content(Stream &strm, ContentProvider content_provider, + size_t offset, size_t length) { + size_t begin_offset = offset; + size_t end_offset = offset + length; + while (offset < end_offset) { + ssize_t written_length = 0; + + DataSink data_sink; + data_sink.write = [&](const char *d, size_t l) { + offset += l; + written_length = strm.write(d, l); + }; + data_sink.done = [&](void) { written_length = -1; }; + data_sink.is_writable = [&](void) { return strm.is_writable(); }; + + content_provider(offset, end_offset - offset, data_sink); + if (written_length < 0) { return written_length; } + } + return static_cast(offset - begin_offset); +} + +template +inline ssize_t write_content_chunked(Stream &strm, + ContentProvider content_provider, + T is_shutting_down) { + size_t offset = 0; + auto data_available = true; + ssize_t total_written_length = 0; + while (data_available && !is_shutting_down()) { + ssize_t written_length = 0; + + DataSink data_sink; + data_sink.write = [&](const char *d, size_t l) { + data_available = l > 0; + offset += l; + + // Emit chunked response header and footer for each chunk + auto chunk = from_i_to_hex(l) + "\r\n" + std::string(d, l) + "\r\n"; + written_length = strm.write(chunk); + }; + data_sink.done = [&](void) { + data_available = false; + written_length = strm.write("0\r\n\r\n"); + }; + data_sink.is_writable = [&](void) { return strm.is_writable(); }; + + content_provider(offset, 0, data_sink); + + if (written_length < 0) { return written_length; } + total_written_length += written_length; + } + return total_written_length; +} + +template +inline bool redirect(T &cli, const Request &req, Response &res, + const std::string &path) { + Request new_req = req; + new_req.path = path; + new_req.redirect_count -= 1; + + if (res.status == 303 && (req.method != "GET" && req.method != "HEAD")) { + new_req.method = "GET"; + new_req.body.clear(); + new_req.headers.clear(); + } + + Response new_res; + + auto ret = cli.send(new_req, new_res); + if (ret) { res = new_res; } + return ret; +} + +inline std::string encode_url(const std::string &s) { + std::string result; + + for (size_t i = 0; s[i]; i++) { + switch (s[i]) { + case ' ': result += "%20"; break; + case '+': result += "%2B"; break; + case '\r': result += "%0D"; break; + case '\n': result += "%0A"; break; + case '\'': result += "%27"; break; + case ',': result += "%2C"; break; + // case ':': result += "%3A"; break; // ok? probably... + case ';': result += "%3B"; break; + default: + auto c = static_cast(s[i]); + if (c >= 0x80) { + result += '%'; + char hex[4]; + auto len = snprintf(hex, sizeof(hex) - 1, "%02X", c); + assert(len == 2); + result.append(hex, static_cast(len)); + } else { + result += s[i]; + } + break; + } + } + + return result; +} + +inline std::string decode_url(const std::string &s, + bool convert_plus_to_space) { + std::string result; + + for (size_t i = 0; i < s.size(); i++) { + if (s[i] == '%' && i + 1 < s.size()) { + if (s[i + 1] == 'u') { + int val = 0; + if (from_hex_to_i(s, i + 2, 4, val)) { + // 4 digits Unicode codes + char buff[4]; + size_t len = to_utf8(val, buff); + if (len > 0) { result.append(buff, len); } + i += 5; // 'u0000' + } else { + result += s[i]; + } + } else { + int val = 0; + if (from_hex_to_i(s, i + 1, 2, val)) { + // 2 digits hex codes + result += static_cast(val); + i += 2; // '00' + } else { + result += s[i]; + } + } + } else if (convert_plus_to_space && s[i] == '+') { + result += ' '; + } else { + result += s[i]; + } + } + + return result; +} + +inline std::string params_to_query_str(const Params ¶ms) { + std::string query; + + for (auto it = params.begin(); it != params.end(); ++it) { + if (it != params.begin()) { query += "&"; } + query += it->first; + query += "="; + query += detail::encode_url(it->second); + } + + return query; +} + +inline void parse_query_text(const std::string &s, Params ¶ms) { + split(&s[0], &s[s.size()], '&', [&](const char *b, const char *e) { + std::string key; + std::string val; + split(b, e, '=', [&](const char *b2, const char *e2) { + if (key.empty()) { + key.assign(b2, e2); + } else { + val.assign(b2, e2); + } + }); + params.emplace(decode_url(key, true), decode_url(val, true)); + }); +} + +inline bool parse_multipart_boundary(const std::string &content_type, + std::string &boundary) { + auto pos = content_type.find("boundary="); + if (pos == std::string::npos) { return false; } + + boundary = content_type.substr(pos + 9); + return true; +} + +inline bool parse_range_header(const std::string &s, Ranges &ranges) { + static auto re_first_range = std::regex(R"(bytes=(\d*-\d*(?:,\s*\d*-\d*)*))"); + std::smatch m; + if (std::regex_match(s, m, re_first_range)) { + auto pos = static_cast(m.position(1)); + auto len = static_cast(m.length(1)); + bool all_valid_ranges = true; + split(&s[pos], &s[pos + len], ',', [&](const char *b, const char *e) { + if (!all_valid_ranges) return; + static auto re_another_range = std::regex(R"(\s*(\d*)-(\d*))"); + std::cmatch cm; + if (std::regex_match(b, e, cm, re_another_range)) { + ssize_t first = -1; + if (!cm.str(1).empty()) { + first = static_cast(std::stoll(cm.str(1))); + } + + ssize_t last = -1; + if (!cm.str(2).empty()) { + last = static_cast(std::stoll(cm.str(2))); + } + + if (first != -1 && last != -1 && first > last) { + all_valid_ranges = false; + return; + } + ranges.emplace_back(std::make_pair(first, last)); + } + }); + return all_valid_ranges; + } + return false; +} + +class MultipartFormDataParser { +public: + MultipartFormDataParser() = default; + + void set_boundary(std::string boundary) { boundary_ = std::move(boundary); } + + bool is_valid() const { return is_valid_; } + + template + bool parse(const char *buf, size_t n, T content_callback, U header_callback) { + static const std::regex re_content_type(R"(^Content-Type:\s*(.*?)\s*$)", + std::regex_constants::icase); + + static const std::regex re_content_disposition( + "^Content-Disposition:\\s*form-data;\\s*name=\"(.*?)\"(?:;\\s*filename=" + "\"(.*?)\")?\\s*$", + std::regex_constants::icase); + static const std::string dash_ = "--"; + static const std::string crlf_ = "\r\n"; + + buf_.append(buf, n); // TODO: performance improvement + + while (!buf_.empty()) { + switch (state_) { + case 0: { // Initial boundary + auto pattern = dash_ + boundary_ + crlf_; + if (pattern.size() > buf_.size()) { return true; } + auto pos = buf_.find(pattern); + if (pos != 0) { + is_done_ = true; + return false; + } + buf_.erase(0, pattern.size()); + off_ += pattern.size(); + state_ = 1; + break; + } + case 1: { // New entry + clear_file_info(); + state_ = 2; + break; + } + case 2: { // Headers + auto pos = buf_.find(crlf_); + while (pos != std::string::npos) { + // Empty line + if (pos == 0) { + if (!header_callback(file_)) { + is_valid_ = false; + is_done_ = false; + return false; + } + buf_.erase(0, crlf_.size()); + off_ += crlf_.size(); + state_ = 3; + break; + } + + auto header = buf_.substr(0, pos); + { + std::smatch m; + if (std::regex_match(header, m, re_content_type)) { + file_.content_type = m[1]; + } else if (std::regex_match(header, m, re_content_disposition)) { + file_.name = m[1]; + file_.filename = m[2]; + } + } + + buf_.erase(0, pos + crlf_.size()); + off_ += pos + crlf_.size(); + pos = buf_.find(crlf_); + } + break; + } + case 3: { // Body + { + auto pattern = crlf_ + dash_; + if (pattern.size() > buf_.size()) { return true; } + + auto pos = buf_.find(pattern); + if (pos == std::string::npos) { pos = buf_.size(); } + if (!content_callback(buf_.data(), pos)) { + is_valid_ = false; + is_done_ = false; + return false; + } + + off_ += pos; + buf_.erase(0, pos); + } + + { + auto pattern = crlf_ + dash_ + boundary_; + if (pattern.size() > buf_.size()) { return true; } + + auto pos = buf_.find(pattern); + if (pos != std::string::npos) { + if (!content_callback(buf_.data(), pos)) { + is_valid_ = false; + is_done_ = false; + return false; + } + + off_ += pos + pattern.size(); + buf_.erase(0, pos + pattern.size()); + state_ = 4; + } else { + if (!content_callback(buf_.data(), pattern.size())) { + is_valid_ = false; + is_done_ = false; + return false; + } + + off_ += pattern.size(); + buf_.erase(0, pattern.size()); + } + } + break; + } + case 4: { // Boundary + if (crlf_.size() > buf_.size()) { return true; } + if (buf_.find(crlf_) == 0) { + buf_.erase(0, crlf_.size()); + off_ += crlf_.size(); + state_ = 1; + } else { + auto pattern = dash_ + crlf_; + if (pattern.size() > buf_.size()) { return true; } + if (buf_.find(pattern) == 0) { + buf_.erase(0, pattern.size()); + off_ += pattern.size(); + is_valid_ = true; + state_ = 5; + } else { + is_done_ = true; + return true; + } + } + break; + } + case 5: { // Done + is_valid_ = false; + return false; + } + } + } + + return true; + } + +private: + void clear_file_info() { + file_.name.clear(); + file_.filename.clear(); + file_.content_type.clear(); + } + + std::string boundary_; + + std::string buf_; + size_t state_ = 0; + size_t is_valid_ = false; + size_t is_done_ = false; + size_t off_ = 0; + MultipartFormData file_; +}; + +inline std::string to_lower(const char *beg, const char *end) { + std::string out; + auto it = beg; + while (it != end) { + out += static_cast(::tolower(*it)); + it++; + } + return out; +} + +inline std::string make_multipart_data_boundary() { + static const char data[] = + "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"; + + std::random_device seed_gen; + std::mt19937 engine(seed_gen()); + + std::string result = "--cpp-httplib-multipart-data-"; + + for (auto i = 0; i < 16; i++) { + result += data[engine() % (sizeof(data) - 1)]; + } + + return result; +} + +inline std::pair +get_range_offset_and_length(const Request &req, size_t content_length, + size_t index) { + auto r = req.ranges[index]; + + if (r.first == -1 && r.second == -1) { + return std::make_pair(0, content_length); + } + + auto slen = static_cast(content_length); + + if (r.first == -1) { + r.first = slen - r.second; + r.second = slen - 1; + } + + if (r.second == -1) { r.second = slen - 1; } + + return std::make_pair(r.first, r.second - r.first + 1); +} + +inline std::string make_content_range_header_field(size_t offset, size_t length, + size_t content_length) { + std::string field = "bytes "; + field += std::to_string(offset); + field += "-"; + field += std::to_string(offset + length - 1); + field += "/"; + field += std::to_string(content_length); + return field; +} + +template +bool process_multipart_ranges_data(const Request &req, Response &res, + const std::string &boundary, + const std::string &content_type, + SToken stoken, CToken ctoken, + Content content) { + for (size_t i = 0; i < req.ranges.size(); i++) { + ctoken("--"); + stoken(boundary); + ctoken("\r\n"); + if (!content_type.empty()) { + ctoken("Content-Type: "); + stoken(content_type); + ctoken("\r\n"); + } + + auto offsets = get_range_offset_and_length(req, res.body.size(), i); + auto offset = offsets.first; + auto length = offsets.second; + + ctoken("Content-Range: "); + stoken(make_content_range_header_field(offset, length, res.body.size())); + ctoken("\r\n"); + ctoken("\r\n"); + if (!content(offset, length)) { return false; } + ctoken("\r\n"); + } + + ctoken("--"); + stoken(boundary); + ctoken("--\r\n"); + + return true; +} + +inline std::string make_multipart_ranges_data(const Request &req, Response &res, + const std::string &boundary, + const std::string &content_type) { + std::string data; + + process_multipart_ranges_data( + req, res, boundary, content_type, + [&](const std::string &token) { data += token; }, + [&](const char *token) { data += token; }, + [&](size_t offset, size_t length) { + data += res.body.substr(offset, length); + return true; + }); + + return data; +} + +inline size_t +get_multipart_ranges_data_length(const Request &req, Response &res, + const std::string &boundary, + const std::string &content_type) { + size_t data_length = 0; + + process_multipart_ranges_data( + req, res, boundary, content_type, + [&](const std::string &token) { data_length += token.size(); }, + [&](const char *token) { data_length += strlen(token); }, + [&](size_t /*offset*/, size_t length) { + data_length += length; + return true; + }); + + return data_length; +} + +inline bool write_multipart_ranges_data(Stream &strm, const Request &req, + Response &res, + const std::string &boundary, + const std::string &content_type) { + return process_multipart_ranges_data( + req, res, boundary, content_type, + [&](const std::string &token) { strm.write(token); }, + [&](const char *token) { strm.write(token); }, + [&](size_t offset, size_t length) { + return write_content(strm, res.content_provider, offset, length) >= 0; + }); +} + +inline std::pair +get_range_offset_and_length(const Request &req, const Response &res, + size_t index) { + auto r = req.ranges[index]; + + if (r.second == -1) { + r.second = static_cast(res.content_length) - 1; + } + + return std::make_pair(r.first, r.second - r.first + 1); +} + +inline bool expect_content(const Request &req) { + if (req.method == "POST" || req.method == "PUT" || req.method == "PATCH" || + req.method == "PRI" || req.method == "DELETE") { + return true; + } + // TODO: check if Content-Length is set + return false; +} + +inline bool has_crlf(const char *s) { + auto p = s; + while (*p) { + if (*p == '\r' || *p == '\n') { return true; } + p++; + } + return false; +} + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +template +inline std::string message_digest(const std::string &s, Init init, + Update update, Final final, + size_t digest_length) { + using namespace std; + + std::vector md(digest_length, 0); + CTX ctx; + init(&ctx); + update(&ctx, s.data(), s.size()); + final(md.data(), &ctx); + + stringstream ss; + for (auto c : md) { + ss << setfill('0') << setw(2) << hex << (unsigned int)c; + } + return ss.str(); +} + +inline std::string MD5(const std::string &s) { + return message_digest(s, MD5_Init, MD5_Update, MD5_Final, + MD5_DIGEST_LENGTH); +} + +inline std::string SHA_256(const std::string &s) { + return message_digest(s, SHA256_Init, SHA256_Update, SHA256_Final, + SHA256_DIGEST_LENGTH); +} + +inline std::string SHA_512(const std::string &s) { + return message_digest(s, SHA512_Init, SHA512_Update, SHA512_Final, + SHA512_DIGEST_LENGTH); +} +#endif + +template +inline ssize_t handle_EINTR(T fn) { + ssize_t res = false; + while (true) { + res = fn(); + if (res < 0 && errno == EINTR) { + continue; + } + break; + } + return res; +} + +#define HANDLE_EINTR(method, ...) (handle_EINTR([&]() { return method(__VA_ARGS__); })) + +#ifdef _WIN32 +class WSInit { +public: + WSInit() { + WSADATA wsaData; + WSAStartup(0x0002, &wsaData); + } + + ~WSInit() { WSACleanup(); } +}; + +static WSInit wsinit_; +#endif + +} // namespace detail + +// Header utilities +inline std::pair make_range_header(Ranges ranges) { + std::string field = "bytes="; + auto i = 0; + for (auto r : ranges) { + if (i != 0) { field += ", "; } + if (r.first != -1) { field += std::to_string(r.first); } + field += '-'; + if (r.second != -1) { field += std::to_string(r.second); } + i++; + } + return std::make_pair("Range", field); +} + +inline std::pair +make_basic_authentication_header(const std::string &username, + const std::string &password, + bool is_proxy = false) { + auto field = "Basic " + detail::base64_encode(username + ":" + password); + auto key = is_proxy ? "Proxy-Authorization" : "Authorization"; + return std::make_pair(key, field); +} + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +inline std::pair make_digest_authentication_header( + const Request &req, const std::map &auth, + size_t cnonce_count, const std::string &cnonce, const std::string &username, + const std::string &password, bool is_proxy = false) { + using namespace std; + + string nc; + { + stringstream ss; + ss << setfill('0') << setw(8) << hex << cnonce_count; + nc = ss.str(); + } + + auto qop = auth.at("qop"); + if (qop.find("auth-int") != std::string::npos) { + qop = "auth-int"; + } else { + qop = "auth"; + } + + std::string algo = "MD5"; + if (auth.find("algorithm") != auth.end()) { algo = auth.at("algorithm"); } + + string response; + { + auto H = algo == "SHA-256" + ? detail::SHA_256 + : algo == "SHA-512" ? detail::SHA_512 : detail::MD5; + + auto A1 = username + ":" + auth.at("realm") + ":" + password; + + auto A2 = req.method + ":" + req.path; + if (qop == "auth-int") { A2 += ":" + H(req.body); } + + response = H(H(A1) + ":" + auth.at("nonce") + ":" + nc + ":" + cnonce + + ":" + qop + ":" + H(A2)); + } + + auto field = "Digest username=\"hello\", realm=\"" + auth.at("realm") + + "\", nonce=\"" + auth.at("nonce") + "\", uri=\"" + req.path + + "\", algorithm=" + algo + ", qop=" + qop + ", nc=\"" + nc + + "\", cnonce=\"" + cnonce + "\", response=\"" + response + "\""; + + auto key = is_proxy ? "Proxy-Authorization" : "Authorization"; + return std::make_pair(key, field); +} +#endif + +inline bool parse_www_authenticate(const httplib::Response &res, + std::map &auth, + bool is_proxy) { + auto auth_key = is_proxy ? "Proxy-Authenticate" : "WWW-Authenticate"; + if (res.has_header(auth_key)) { + static auto re = std::regex(R"~((?:(?:,\s*)?(.+?)=(?:"(.*?)"|([^,]*))))~"); + auto s = res.get_header_value(auth_key); + auto pos = s.find(' '); + if (pos != std::string::npos) { + auto type = s.substr(0, pos); + if (type == "Basic") { + return false; + } else if (type == "Digest") { + s = s.substr(pos + 1); + auto beg = std::sregex_iterator(s.begin(), s.end(), re); + for (auto i = beg; i != std::sregex_iterator(); ++i) { + auto m = *i; + auto key = s.substr(static_cast(m.position(1)), + static_cast(m.length(1))); + auto val = m.length(2) > 0 + ? s.substr(static_cast(m.position(2)), + static_cast(m.length(2))) + : s.substr(static_cast(m.position(3)), + static_cast(m.length(3))); + auth[key] = val; + } + return true; + } + } + } + return false; +} + +// https://stackoverflow.com/questions/440133/how-do-i-create-a-random-alpha-numeric-string-in-c/440240#answer-440240 +inline std::string random_string(size_t length) { + auto randchar = []() -> char { + const char charset[] = "0123456789" + "ABCDEFGHIJKLMNOPQRSTUVWXYZ" + "abcdefghijklmnopqrstuvwxyz"; + const size_t max_index = (sizeof(charset) - 1); + return charset[static_cast(rand()) % max_index]; + }; + std::string str(length, 0); + std::generate_n(str.begin(), length, randchar); + return str; +} + +// Request implementation +inline bool Request::has_header(const char *key) const { + return detail::has_header(headers, key); +} + +inline std::string Request::get_header_value(const char *key, size_t id) const { + return detail::get_header_value(headers, key, id, ""); +} + +inline size_t Request::get_header_value_count(const char *key) const { + auto r = headers.equal_range(key); + return static_cast(std::distance(r.first, r.second)); +} + +inline void Request::set_header(const char *key, const char *val) { + if (!detail::has_crlf(key) && !detail::has_crlf(val)) { + headers.emplace(key, val); + } +} + +inline void Request::set_header(const char *key, const std::string &val) { + if (!detail::has_crlf(key) && !detail::has_crlf(val.c_str())) { + headers.emplace(key, val); + } +} + +inline bool Request::has_param(const char *key) const { + return params.find(key) != params.end(); +} + +inline std::string Request::get_param_value(const char *key, size_t id) const { + auto it = params.find(key); + std::advance(it, static_cast(id)); + if (it != params.end()) { return it->second; } + return std::string(); +} + +inline size_t Request::get_param_value_count(const char *key) const { + auto r = params.equal_range(key); + return static_cast(std::distance(r.first, r.second)); +} + +inline bool Request::is_multipart_form_data() const { + const auto &content_type = get_header_value("Content-Type"); + return !content_type.find("multipart/form-data"); +} + +inline bool Request::has_file(const char *key) const { + return files.find(key) != files.end(); +} + +inline MultipartFormData Request::get_file_value(const char *key) const { + auto it = files.find(key); + if (it != files.end()) { return it->second; } + return MultipartFormData(); +} + +// Response implementation +inline bool Response::has_header(const char *key) const { + return headers.find(key) != headers.end(); +} + +inline std::string Response::get_header_value(const char *key, + size_t id) const { + return detail::get_header_value(headers, key, id, ""); +} + +inline size_t Response::get_header_value_count(const char *key) const { + auto r = headers.equal_range(key); + return static_cast(std::distance(r.first, r.second)); +} + +inline void Response::set_header(const char *key, const char *val) { + if (!detail::has_crlf(key) && !detail::has_crlf(val)) { + headers.emplace(key, val); + } +} + +inline void Response::set_header(const char *key, const std::string &val) { + if (!detail::has_crlf(key) && !detail::has_crlf(val.c_str())) { + headers.emplace(key, val); + } +} + +inline void Response::set_redirect(const char *url, int status) { + if (!detail::has_crlf(url)) { + set_header("Location", url); + if (300 <= status && status < 400) { + this->status = status; + } else { + this->status = 302; + } + } +} + +inline void Response::set_content(const char *s, size_t n, + const char *content_type) { + body.assign(s, n); + set_header("Content-Type", content_type); +} + +inline void Response::set_content(std::string s, const char *content_type) { + body = std::move(s); + set_header("Content-Type", content_type); +} + +inline void Response::set_content_provider( + size_t in_length, + std::function provider, + std::function resource_releaser) { + assert(in_length > 0); + content_length = in_length; + content_provider = [provider](size_t offset, size_t length, DataSink &sink) { + provider(offset, length, sink); + }; + content_provider_resource_releaser = resource_releaser; +} + +inline void Response::set_chunked_content_provider( + std::function provider, + std::function resource_releaser) { + content_length = 0; + content_provider = [provider](size_t offset, size_t, DataSink &sink) { + provider(offset, sink); + }; + content_provider_resource_releaser = resource_releaser; +} + +// Rstream implementation +inline ssize_t Stream::write(const char *ptr) { + return write(ptr, strlen(ptr)); +} + +inline ssize_t Stream::write(const std::string &s) { + return write(s.data(), s.size()); +} + +template +inline ssize_t Stream::write_format(const char *fmt, const Args &... args) { + std::array buf; + +#if defined(_MSC_VER) && _MSC_VER < 1900 + auto sn = _snprintf_s(buf, bufsiz, buf.size() - 1, fmt, args...); +#else + auto sn = snprintf(buf.data(), buf.size() - 1, fmt, args...); +#endif + if (sn <= 0) { return sn; } + + auto n = static_cast(sn); + + if (n >= buf.size() - 1) { + std::vector glowable_buf(buf.size()); + + while (n >= glowable_buf.size() - 1) { + glowable_buf.resize(glowable_buf.size() * 2); +#if defined(_MSC_VER) && _MSC_VER < 1900 + n = static_cast(_snprintf_s(&glowable_buf[0], glowable_buf.size(), + glowable_buf.size() - 1, fmt, + args...)); +#else + n = static_cast( + snprintf(&glowable_buf[0], glowable_buf.size() - 1, fmt, args...)); +#endif + } + return write(&glowable_buf[0], n); + } else { + return write(buf.data(), n); + } +} + +namespace detail { + +// Socket stream implementation +inline SocketStream::SocketStream(socket_t sock, time_t read_timeout_sec, + time_t read_timeout_usec) + : sock_(sock), read_timeout_sec_(read_timeout_sec), + read_timeout_usec_(read_timeout_usec) {} + +inline SocketStream::~SocketStream() {} + +inline bool SocketStream::is_readable() const { + return select_read(sock_, read_timeout_sec_, read_timeout_usec_) > 0; +} + +inline bool SocketStream::is_writable() const { + return select_write(sock_, 0, 0) > 0; +} + +inline ssize_t SocketStream::read(char *ptr, size_t size) { + if (!is_readable()) { return -1; } + +#ifdef _WIN32 + if (size > static_cast(std::numeric_limits::max())) { + return -1; + } + return recv(sock_, ptr, static_cast(size), 0); +#else + return HANDLE_EINTR(recv, sock_, ptr, size, 0); +#endif +} + +inline ssize_t SocketStream::write(const char *ptr, size_t size) { + if (!is_writable()) { return -1; } + +#ifdef _WIN32 + if (size > static_cast(std::numeric_limits::max())) { + return -1; + } + return send(sock_, ptr, static_cast(size), 0); +#else + return send(sock_, ptr, size, 0); +#endif +} + +inline void SocketStream::get_remote_ip_and_port(std::string &ip, + int &port) const { + return detail::get_remote_ip_and_port(sock_, ip, port); +} + +// Buffer stream implementation +inline bool BufferStream::is_readable() const { return true; } + +inline bool BufferStream::is_writable() const { return true; } + +inline ssize_t BufferStream::read(char *ptr, size_t size) { +#if defined(_MSC_VER) && _MSC_VER < 1900 + auto len_read = buffer._Copy_s(ptr, size, size, position); +#else + auto len_read = buffer.copy(ptr, size, position); +#endif + position += static_cast(len_read); + return static_cast(len_read); +} + +inline ssize_t BufferStream::write(const char *ptr, size_t size) { + buffer.append(ptr, size); + return static_cast(size); +} + +inline void BufferStream::get_remote_ip_and_port(std::string & /*ip*/, + int & /*port*/) const {} + +inline const std::string &BufferStream::get_buffer() const { return buffer; } + +} // namespace detail + +// HTTP server implementation +inline Server::Server() + : keep_alive_max_count_(CPPHTTPLIB_KEEPALIVE_MAX_COUNT), + read_timeout_sec_(CPPHTTPLIB_READ_TIMEOUT_SECOND), + read_timeout_usec_(CPPHTTPLIB_READ_TIMEOUT_USECOND), + payload_max_length_(CPPHTTPLIB_PAYLOAD_MAX_LENGTH), is_running_(false), + svr_sock_(INVALID_SOCKET) { +#ifndef _WIN32 + signal(SIGPIPE, SIG_IGN); +#endif + new_task_queue = [] { return new ThreadPool(CPPHTTPLIB_THREAD_POOL_COUNT); }; +} + +inline Server::~Server() {} + +inline Server &Server::Get(const char *pattern, Handler handler) { + get_handlers_.push_back(std::make_pair(std::regex(pattern), handler)); + return *this; +} + +inline Server &Server::Post(const char *pattern, Handler handler) { + post_handlers_.push_back(std::make_pair(std::regex(pattern), handler)); + return *this; +} + +inline Server &Server::Post(const char *pattern, + HandlerWithContentReader handler) { + post_handlers_for_content_reader_.push_back( + std::make_pair(std::regex(pattern), handler)); + return *this; +} + +inline Server &Server::Put(const char *pattern, Handler handler) { + put_handlers_.push_back(std::make_pair(std::regex(pattern), handler)); + return *this; +} + +inline Server &Server::Put(const char *pattern, + HandlerWithContentReader handler) { + put_handlers_for_content_reader_.push_back( + std::make_pair(std::regex(pattern), handler)); + return *this; +} + +inline Server &Server::Patch(const char *pattern, Handler handler) { + patch_handlers_.push_back(std::make_pair(std::regex(pattern), handler)); + return *this; +} + +inline Server &Server::Patch(const char *pattern, + HandlerWithContentReader handler) { + patch_handlers_for_content_reader_.push_back( + std::make_pair(std::regex(pattern), handler)); + return *this; +} + +inline Server &Server::Delete(const char *pattern, Handler handler) { + delete_handlers_.push_back(std::make_pair(std::regex(pattern), handler)); + return *this; +} + +inline Server &Server::Delete(const char *pattern, + HandlerWithContentReader handler) { + delete_handlers_for_content_reader_.push_back( + std::make_pair(std::regex(pattern), handler)); + return *this; +} + +inline Server &Server::Options(const char *pattern, Handler handler) { + options_handlers_.push_back(std::make_pair(std::regex(pattern), handler)); + return *this; +} + +inline bool Server::set_base_dir(const char *dir, const char *mount_point) { + return set_mount_point(mount_point, dir); +} + +inline bool Server::set_mount_point(const char *mount_point, const char *dir) { + if (detail::is_dir(dir)) { + std::string mnt = mount_point ? mount_point : "/"; + if (!mnt.empty() && mnt[0] == '/') { + base_dirs_.emplace_back(mnt, dir); + return true; + } + } + return false; +} + +inline bool Server::remove_mount_point(const char *mount_point) { + for (auto it = base_dirs_.begin(); it != base_dirs_.end(); ++it) { + if (it->first == mount_point) { + base_dirs_.erase(it); + return true; + } + } + return false; +} + +inline void Server::set_file_extension_and_mimetype_mapping(const char *ext, + const char *mime) { + file_extension_and_mimetype_map_[ext] = mime; +} + +inline void Server::set_file_request_handler(Handler handler) { + file_request_handler_ = std::move(handler); +} + +inline void Server::set_error_handler(Handler handler) { + error_handler_ = std::move(handler); +} + +inline void Server::set_logger(Logger logger) { logger_ = std::move(logger); } + +inline void +Server::set_expect_100_continue_handler(Expect100ContinueHandler handler) { + expect_100_continue_handler_ = std::move(handler); +} + +inline void Server::set_keep_alive_max_count(size_t count) { + keep_alive_max_count_ = count; +} + +inline void Server::set_read_timeout(time_t sec, time_t usec) { + read_timeout_sec_ = sec; + read_timeout_usec_ = usec; +} + +inline void Server::set_payload_max_length(size_t length) { + payload_max_length_ = length; +} + +inline bool Server::bind_to_port(const char *host, int port, int socket_flags) { + if (bind_internal(host, port, socket_flags) < 0) return false; + return true; +} +inline int Server::bind_to_any_port(const char *host, int socket_flags) { + return bind_internal(host, 0, socket_flags); +} + +inline bool Server::listen_after_bind() { return listen_internal(); } + +inline bool Server::listen(const char *host, int port, int socket_flags) { + return bind_to_port(host, port, socket_flags) && listen_internal(); +} + +inline bool Server::is_running() const { return is_running_; } + +inline void Server::stop() { + if (is_running_) { + assert(svr_sock_ != INVALID_SOCKET); + std::atomic sock(svr_sock_.exchange(INVALID_SOCKET)); + detail::shutdown_socket(sock); + detail::close_socket(sock); + } +} + +inline bool Server::parse_request_line(const char *s, Request &req) { + const static std::regex re( + "(GET|HEAD|POST|PUT|DELETE|CONNECT|OPTIONS|TRACE|PATCH|PRI) " + "(([^?]+)(?:\\?(.*?))?) (HTTP/1\\.[01])\r\n"); + + std::cmatch m; + if (std::regex_match(s, m, re)) { + req.version = std::string(m[5]); + req.method = std::string(m[1]); + req.target = std::string(m[2]); + req.path = detail::decode_url(m[3], false); + + // Parse query text + auto len = std::distance(m[4].first, m[4].second); + if (len > 0) { detail::parse_query_text(m[4], req.params); } + + return true; + } + + return false; +} + +inline bool Server::write_response(Stream &strm, bool last_connection, + const Request &req, Response &res) { + assert(res.status != -1); + + if (400 <= res.status && error_handler_) { error_handler_(req, res); } + + detail::BufferStream bstrm; + + // Response line + if (!bstrm.write_format("HTTP/1.1 %d %s\r\n", res.status, + detail::status_message(res.status))) { + return false; + } + + // Headers + if (last_connection || req.get_header_value("Connection") == "close") { + res.set_header("Connection", "close"); + } + + if (!last_connection && req.get_header_value("Connection") == "Keep-Alive") { + res.set_header("Connection", "Keep-Alive"); + } + + if (!res.has_header("Content-Type") && + (!res.body.empty() || res.content_length > 0)) { + res.set_header("Content-Type", "text/plain"); + } + + if (!res.has_header("Accept-Ranges") && req.method == "HEAD") { + res.set_header("Accept-Ranges", "bytes"); + } + + std::string content_type; + std::string boundary; + + if (req.ranges.size() > 1) { + boundary = detail::make_multipart_data_boundary(); + + auto it = res.headers.find("Content-Type"); + if (it != res.headers.end()) { + content_type = it->second; + res.headers.erase(it); + } + + res.headers.emplace("Content-Type", + "multipart/byteranges; boundary=" + boundary); + } + + if (res.body.empty()) { + if (res.content_length > 0) { + size_t length = 0; + if (req.ranges.empty()) { + length = res.content_length; + } else if (req.ranges.size() == 1) { + auto offsets = + detail::get_range_offset_and_length(req, res.content_length, 0); + auto offset = offsets.first; + length = offsets.second; + auto content_range = detail::make_content_range_header_field( + offset, length, res.content_length); + res.set_header("Content-Range", content_range); + } else { + length = detail::get_multipart_ranges_data_length(req, res, boundary, + content_type); + } + res.set_header("Content-Length", std::to_string(length)); + } else { + if (res.content_provider) { + res.set_header("Transfer-Encoding", "chunked"); + } else { + res.set_header("Content-Length", "0"); + } + } + } else { + if (req.ranges.empty()) { + ; + } else if (req.ranges.size() == 1) { + auto offsets = + detail::get_range_offset_and_length(req, res.body.size(), 0); + auto offset = offsets.first; + auto length = offsets.second; + auto content_range = detail::make_content_range_header_field( + offset, length, res.body.size()); + res.set_header("Content-Range", content_range); + res.body = res.body.substr(offset, length); + } else { + res.body = + detail::make_multipart_ranges_data(req, res, boundary, content_type); + } + +#ifdef CPPHTTPLIB_ZLIB_SUPPORT + // TODO: 'Accept-Encoding' has gzip, not gzip;q=0 + const auto &encodings = req.get_header_value("Accept-Encoding"); + if (encodings.find("gzip") != std::string::npos && + detail::can_compress(res.get_header_value("Content-Type"))) { + if (detail::compress(res.body)) { + res.set_header("Content-Encoding", "gzip"); + } + } +#endif + + auto length = std::to_string(res.body.size()); + res.set_header("Content-Length", length); + } + + if (!detail::write_headers(bstrm, res, Headers())) { return false; } + + // Flush buffer + auto &data = bstrm.get_buffer(); + strm.write(data.data(), data.size()); + + // Body + if (req.method != "HEAD") { + if (!res.body.empty()) { + if (!strm.write(res.body)) { return false; } + } else if (res.content_provider) { + if (!write_content_with_provider(strm, req, res, boundary, + content_type)) { + return false; + } + } + } + + // Log + if (logger_) { logger_(req, res); } + + return true; +} + +inline bool +Server::write_content_with_provider(Stream &strm, const Request &req, + Response &res, const std::string &boundary, + const std::string &content_type) { + if (res.content_length) { + if (req.ranges.empty()) { + if (detail::write_content(strm, res.content_provider, 0, + res.content_length) < 0) { + return false; + } + } else if (req.ranges.size() == 1) { + auto offsets = + detail::get_range_offset_and_length(req, res.content_length, 0); + auto offset = offsets.first; + auto length = offsets.second; + if (detail::write_content(strm, res.content_provider, offset, length) < + 0) { + return false; + } + } else { + if (!detail::write_multipart_ranges_data(strm, req, res, boundary, + content_type)) { + return false; + } + } + } else { + auto is_shutting_down = [this]() { + return this->svr_sock_ == INVALID_SOCKET; + }; + if (detail::write_content_chunked(strm, res.content_provider, + is_shutting_down) < 0) { + return false; + } + } + return true; +} + +inline bool Server::read_content(Stream &strm, Request &req, Response &res) { + MultipartFormDataMap::iterator cur; + if (read_content_core( + strm, req, res, + // Regular + [&](const char *buf, size_t n) { + if (req.body.size() + n > req.body.max_size()) { return false; } + req.body.append(buf, n); + return true; + }, + // Multipart + [&](const MultipartFormData &file) { + cur = req.files.emplace(file.name, file); + return true; + }, + [&](const char *buf, size_t n) { + auto &content = cur->second.content; + if (content.size() + n > content.max_size()) { return false; } + content.append(buf, n); + return true; + })) { + const auto &content_type = req.get_header_value("Content-Type"); + if (!content_type.find("application/x-www-form-urlencoded")) { + detail::parse_query_text(req.body, req.params); + } + return true; + } + return false; +} + +inline bool Server::read_content_with_content_receiver( + Stream &strm, Request &req, Response &res, ContentReceiver receiver, + MultipartContentHeader multipart_header, + ContentReceiver multipart_receiver) { + return read_content_core(strm, req, res, receiver, multipart_header, + multipart_receiver); +} + +inline bool Server::read_content_core(Stream &strm, Request &req, Response &res, + ContentReceiver receiver, + MultipartContentHeader mulitpart_header, + ContentReceiver multipart_receiver) { + detail::MultipartFormDataParser multipart_form_data_parser; + ContentReceiver out; + + if (req.is_multipart_form_data()) { + const auto &content_type = req.get_header_value("Content-Type"); + std::string boundary; + if (!detail::parse_multipart_boundary(content_type, boundary)) { + res.status = 400; + return false; + } + + multipart_form_data_parser.set_boundary(std::move(boundary)); + out = [&](const char *buf, size_t n) { + return multipart_form_data_parser.parse(buf, n, multipart_receiver, + mulitpart_header); + }; + } else { + out = receiver; + } + + if (!detail::read_content(strm, req, payload_max_length_, res.status, + Progress(), out)) { + return false; + } + + if (req.is_multipart_form_data()) { + if (!multipart_form_data_parser.is_valid()) { + res.status = 400; + return false; + } + } + + return true; +} + +inline bool Server::handle_file_request(Request &req, Response &res, + bool head) { + for (const auto &kv : base_dirs_) { + const auto &mount_point = kv.first; + const auto &base_dir = kv.second; + + // Prefix match + if (!req.path.find(mount_point)) { + std::string sub_path = "/" + req.path.substr(mount_point.size()); + if (detail::is_valid_path(sub_path)) { + auto path = base_dir + sub_path; + if (path.back() == '/') { path += "index.html"; } + + if (detail::is_file(path)) { + detail::read_file(path, res.body); + auto type = + detail::find_content_type(path, file_extension_and_mimetype_map_); + if (type) { res.set_header("Content-Type", type); } + res.status = 200; + if (!head && file_request_handler_) { + file_request_handler_(req, res); + } + return true; + } + } + } + } + return false; +} + +inline socket_t Server::create_server_socket(const char *host, int port, + int socket_flags) const { + return detail::create_socket( + host, port, + [](socket_t sock, struct addrinfo &ai) -> bool { + if (::bind(sock, ai.ai_addr, static_cast(ai.ai_addrlen))) { + return false; + } + if (::listen(sock, 5)) { // Listen through 5 channels + return false; + } + return true; + }, + socket_flags); +} + +inline int Server::bind_internal(const char *host, int port, int socket_flags) { + if (!is_valid()) { return -1; } + + svr_sock_ = create_server_socket(host, port, socket_flags); + if (svr_sock_ == INVALID_SOCKET) { return -1; } + + if (port == 0) { + struct sockaddr_storage addr; + socklen_t addr_len = sizeof(addr); + if (getsockname(svr_sock_, reinterpret_cast(&addr), + &addr_len) == -1) { + return -1; + } + if (addr.ss_family == AF_INET) { + return ntohs(reinterpret_cast(&addr)->sin_port); + } else if (addr.ss_family == AF_INET6) { + return ntohs(reinterpret_cast(&addr)->sin6_port); + } else { + return -1; + } + } else { + return port; + } +} + +inline bool Server::listen_internal() { + auto ret = true; + is_running_ = true; + + { + std::unique_ptr task_queue(new_task_queue()); + + for (;;) { + if (svr_sock_ == INVALID_SOCKET) { + // The server socket was closed by 'stop' method. + break; + } + + auto val = detail::select_read(svr_sock_, 0, 100000); + + if (val == 0) { // Timeout + task_queue->on_idle(); + continue; + } + + socket_t sock = accept(svr_sock_, nullptr, nullptr); + + if (sock == INVALID_SOCKET) { + if (errno == EMFILE) { + // The per-process limit of open file descriptors has been reached. + // Try to accept new connections after a short sleep. + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + continue; + } + if (svr_sock_ != INVALID_SOCKET) { + detail::close_socket(svr_sock_); + ret = false; + } else { + ; // The server socket was closed by user. + } + break; + } + +#if __cplusplus > 201703L + task_queue->enqueue([=, this]() { process_and_close_socket(sock); }); +#else + task_queue->enqueue([=]() { process_and_close_socket(sock); }); +#endif + } + + task_queue->shutdown(); + } + + is_running_ = false; + return ret; +} + +inline bool Server::routing(Request &req, Response &res, Stream &strm) { + // File handler + bool is_head_request = req.method == "HEAD"; + if ((req.method == "GET" || is_head_request) && + handle_file_request(req, res, is_head_request)) { + return true; + } + + if (detail::expect_content(req)) { + // Content reader handler + { + ContentReader reader( + [&](ContentReceiver receiver) { + return read_content_with_content_receiver(strm, req, res, receiver, + nullptr, nullptr); + }, + [&](MultipartContentHeader header, ContentReceiver receiver) { + return read_content_with_content_receiver(strm, req, res, nullptr, + header, receiver); + }); + + if (req.method == "POST") { + if (dispatch_request_for_content_reader( + req, res, reader, post_handlers_for_content_reader_)) { + return true; + } + } else if (req.method == "PUT") { + if (dispatch_request_for_content_reader( + req, res, reader, put_handlers_for_content_reader_)) { + return true; + } + } else if (req.method == "PATCH") { + if (dispatch_request_for_content_reader( + req, res, reader, patch_handlers_for_content_reader_)) { + return true; + } + } else if (req.method == "DELETE") { + if (dispatch_request_for_content_reader( + req, res, reader, delete_handlers_for_content_reader_)) { + return true; + } + } + } + + // Read content into `req.body` + if (!read_content(strm, req, res)) { return false; } + } + + // Regular handler + if (req.method == "GET" || req.method == "HEAD") { + return dispatch_request(req, res, get_handlers_); + } else if (req.method == "POST") { + return dispatch_request(req, res, post_handlers_); + } else if (req.method == "PUT") { + return dispatch_request(req, res, put_handlers_); + } else if (req.method == "DELETE") { + return dispatch_request(req, res, delete_handlers_); + } else if (req.method == "OPTIONS") { + return dispatch_request(req, res, options_handlers_); + } else if (req.method == "PATCH") { + return dispatch_request(req, res, patch_handlers_); + } + + res.status = 400; + return false; +} + +inline bool Server::dispatch_request(Request &req, Response &res, + Handlers &handlers) { + + try { + for (const auto &x : handlers) { + const auto &pattern = x.first; + const auto &handler = x.second; + + if (std::regex_match(req.path, req.matches, pattern)) { + handler(req, res); + return true; + } + } + } catch (const std::exception &ex) { + res.status = 500; + res.set_header("EXCEPTION_WHAT", ex.what()); + } catch (...) { + res.status = 500; + res.set_header("EXCEPTION_WHAT", "UNKNOWN"); + } + return false; +} + +inline bool Server::dispatch_request_for_content_reader( + Request &req, Response &res, ContentReader content_reader, + HandlersForContentReader &handlers) { + for (const auto &x : handlers) { + const auto &pattern = x.first; + const auto &handler = x.second; + + if (std::regex_match(req.path, req.matches, pattern)) { + handler(req, res, content_reader); + return true; + } + } + return false; +} + +inline bool +Server::process_request(Stream &strm, bool last_connection, + bool &connection_close, + const std::function &setup_request) { + std::array buf{}; + + detail::stream_line_reader line_reader(strm, buf.data(), buf.size()); + + // Connection has been closed on client + if (!line_reader.getline()) { return false; } + + Request req; + Response res; + + res.version = "HTTP/1.1"; + + // Check if the request URI doesn't exceed the limit + if (line_reader.size() > CPPHTTPLIB_REQUEST_URI_MAX_LENGTH) { + Headers dummy; + detail::read_headers(strm, dummy); + res.status = 414; + return write_response(strm, last_connection, req, res); + } + + // Request line and headers + if (!parse_request_line(line_reader.ptr(), req) || + !detail::read_headers(strm, req.headers)) { + res.status = 400; + return write_response(strm, last_connection, req, res); + } + + if (req.get_header_value("Connection") == "close") { + connection_close = true; + } + + if (req.version == "HTTP/1.0" && + req.get_header_value("Connection") != "Keep-Alive") { + connection_close = true; + } + + strm.get_remote_ip_and_port(req.remote_addr, req.remote_port); + req.set_header("REMOTE_ADDR", req.remote_addr); + req.set_header("REMOTE_PORT", std::to_string(req.remote_port)); + + if (req.has_header("Range")) { + const auto &range_header_value = req.get_header_value("Range"); + if (!detail::parse_range_header(range_header_value, req.ranges)) { + // TODO: error + } + } + + if (setup_request) { setup_request(req); } + + if (req.get_header_value("Expect") == "100-continue") { + auto status = 100; + if (expect_100_continue_handler_) { + status = expect_100_continue_handler_(req, res); + } + switch (status) { + case 100: + case 417: + strm.write_format("HTTP/1.1 %d %s\r\n\r\n", status, + detail::status_message(status)); + break; + default: return write_response(strm, last_connection, req, res); + } + } + + // Rounting + if (routing(req, res, strm)) { + if (res.status == -1) { res.status = req.ranges.empty() ? 200 : 206; } + } else { + if (res.status == -1) { res.status = 404; } + } + + return write_response(strm, last_connection, req, res); +} + +inline bool Server::is_valid() const { return true; } + +inline bool Server::process_and_close_socket(socket_t sock) { + return detail::process_and_close_socket( + false, sock, keep_alive_max_count_, read_timeout_sec_, read_timeout_usec_, + [this](Stream &strm, bool last_connection, bool &connection_close) { + return process_request(strm, last_connection, connection_close, + nullptr); + }); +} + +// HTTP client implementation +inline Client::Client(const std::string &host, int port, + const std::string &client_cert_path, + const std::string &client_key_path) + : sock_(INVALID_SOCKET), host_(host), port_(port), + host_and_port_(host_ + ":" + std::to_string(port_)), + client_cert_path_(client_cert_path), client_key_path_(client_key_path) {} + +inline Client::~Client() {} + +inline bool Client::is_valid() const { return true; } + +inline socket_t Client::create_client_socket() const { + if (!proxy_host_.empty()) { + return detail::create_client_socket(proxy_host_.c_str(), proxy_port_, + timeout_sec_, interface_); + } + return detail::create_client_socket(host_.c_str(), port_, timeout_sec_, + interface_); +} + +inline bool Client::read_response_line(Stream &strm, Response &res) { + std::array buf; + + detail::stream_line_reader line_reader(strm, buf.data(), buf.size()); + + if (!line_reader.getline()) { return false; } + + const static std::regex re("(HTTP/1\\.[01]) (\\d+?) .*\r\n"); + + std::cmatch m; + if (std::regex_match(line_reader.ptr(), m, re)) { + res.version = std::string(m[1]); + res.status = std::stoi(std::string(m[2])); + } + + return true; +} + +inline bool Client::send(const Request &req, Response &res) { + sock_ = create_client_socket(); + if (sock_ == INVALID_SOCKET) { return false; } + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + if (is_ssl() && !proxy_host_.empty()) { + bool error; + if (!connect(sock_, res, error)) { return error; } + } +#endif + + return process_and_close_socket( + sock_, 1, + [&](Stream &strm, bool last_connection, bool &connection_close) { + return handle_request(strm, req, res, last_connection, + connection_close); + }); +} + +inline bool Client::send(const std::vector &requests, + std::vector &responses) { + size_t i = 0; + while (i < requests.size()) { + sock_ = create_client_socket(); + if (sock_ == INVALID_SOCKET) { return false; } + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + if (is_ssl() && !proxy_host_.empty()) { + Response res; + bool error; + if (!connect(sock_, res, error)) { return false; } + } +#endif + + if (!process_and_close_socket(sock_, requests.size() - i, + [&](Stream &strm, bool last_connection, + bool &connection_close) -> bool { + auto &req = requests[i++]; + auto res = Response(); + auto ret = handle_request(strm, req, res, + last_connection, + connection_close); + if (ret) { + responses.emplace_back(std::move(res)); + } + return ret; + })) { + return false; + } + } + + return true; +} + +inline bool Client::handle_request(Stream &strm, const Request &req, + Response &res, bool last_connection, + bool &connection_close) { + if (req.path.empty()) { return false; } + + bool ret; + + if (!is_ssl() && !proxy_host_.empty()) { + auto req2 = req; + req2.path = "http://" + host_and_port_ + req.path; + ret = process_request(strm, req2, res, last_connection, connection_close); + } else { + ret = process_request(strm, req, res, last_connection, connection_close); + } + + if (!ret) { return false; } + + if (300 < res.status && res.status < 400 && follow_location_) { + ret = redirect(req, res); + } + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + if (res.status == 401 || res.status == 407) { + auto is_proxy = res.status == 407; + const auto &username = + is_proxy ? proxy_digest_auth_username_ : digest_auth_username_; + const auto &password = + is_proxy ? proxy_digest_auth_password_ : digest_auth_password_; + + if (!username.empty() && !password.empty()) { + std::map auth; + if (parse_www_authenticate(res, auth, is_proxy)) { + Request new_req = req; + auto key = is_proxy ? "Proxy-Authorization" : "WWW-Authorization"; + new_req.headers.erase(key); + new_req.headers.insert(make_digest_authentication_header( + req, auth, 1, random_string(10), username, password, is_proxy)); + + Response new_res; + + ret = send(new_req, new_res); + if (ret) { res = new_res; } + } + } + } +#endif + + return ret; +} + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +inline bool Client::connect(socket_t sock, Response &res, bool &error) { + error = true; + Response res2; + + if (!detail::process_socket( + true, sock, 1, read_timeout_sec_, read_timeout_usec_, + [&](Stream &strm, bool /*last_connection*/, bool &connection_close) { + Request req2; + req2.method = "CONNECT"; + req2.path = host_and_port_; + return process_request(strm, req2, res2, false, connection_close); + })) { + detail::close_socket(sock); + error = false; + return false; + } + + if (res2.status == 407) { + if (!proxy_digest_auth_username_.empty() && + !proxy_digest_auth_password_.empty()) { + std::map auth; + if (parse_www_authenticate(res2, auth, true)) { + Response res3; + if (!detail::process_socket( + true, sock, 1, read_timeout_sec_, read_timeout_usec_, + [&](Stream &strm, bool /*last_connection*/, + bool &connection_close) { + Request req3; + req3.method = "CONNECT"; + req3.path = host_and_port_; + req3.headers.insert(make_digest_authentication_header( + req3, auth, 1, random_string(10), + proxy_digest_auth_username_, proxy_digest_auth_password_, + true)); + return process_request(strm, req3, res3, false, + connection_close); + })) { + detail::close_socket(sock); + error = false; + return false; + } + } + } else { + res = res2; + return false; + } + } + + return true; +} +#endif + +inline bool Client::redirect(const Request &req, Response &res) { + if (req.redirect_count == 0) { return false; } + + auto location = res.get_header_value("location"); + if (location.empty()) { return false; } + + const static std::regex re( + R"(^(?:(https?):)?(?://([^:/?#]*)(?::(\d+))?)?([^?#]*(?:\?[^#]*)?)(?:#.*)?)"); + + std::smatch m; + if (!std::regex_match(location, m, re)) { return false; } + + auto scheme = is_ssl() ? "https" : "http"; + + auto next_scheme = m[1].str(); + auto next_host = m[2].str(); + auto port_str = m[3].str(); + auto next_path = m[4].str(); + + auto next_port = port_; + if (!port_str.empty()) { + next_port = std::stoi(port_str); + } else if (!next_scheme.empty()) { + next_port = next_scheme == "https" ? 443 : 80; + } + + if (next_scheme.empty()) { next_scheme = scheme; } + if (next_host.empty()) { next_host = host_; } + if (next_path.empty()) { next_path = "/"; } + + if (next_scheme == scheme && next_host == host_ && next_port == port_) { + return detail::redirect(*this, req, res, next_path); + } else { + if (next_scheme == "https") { +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + SSLClient cli(next_host.c_str(), next_port); + cli.copy_settings(*this); + return detail::redirect(cli, req, res, next_path); +#else + return false; +#endif + } else { + Client cli(next_host.c_str(), next_port); + cli.copy_settings(*this); + return detail::redirect(cli, req, res, next_path); + } + } +} + +inline bool Client::write_request(Stream &strm, const Request &req, + bool last_connection) { + detail::BufferStream bstrm; + + // Request line + const auto &path = detail::encode_url(req.path); + + bstrm.write_format("%s %s HTTP/1.1\r\n", req.method.c_str(), path.c_str()); + + // Additonal headers + Headers headers; + if (last_connection) { headers.emplace("Connection", "close"); } + + if (!req.has_header("Host")) { + if (is_ssl()) { + if (port_ == 443) { + headers.emplace("Host", host_); + } else { + headers.emplace("Host", host_and_port_); + } + } else { + if (port_ == 80) { + headers.emplace("Host", host_); + } else { + headers.emplace("Host", host_and_port_); + } + } + } + + if (!req.has_header("Accept")) { headers.emplace("Accept", "*/*"); } + + if (!req.has_header("User-Agent")) { + headers.emplace("User-Agent", "cpp-httplib/0.5"); + } + + if (req.body.empty()) { + if (req.content_provider) { + auto length = std::to_string(req.content_length); + headers.emplace("Content-Length", length); + } else { + headers.emplace("Content-Length", "0"); + } + } else { + if (!req.has_header("Content-Type")) { + headers.emplace("Content-Type", "text/plain"); + } + + if (!req.has_header("Content-Length")) { + auto length = std::to_string(req.body.size()); + headers.emplace("Content-Length", length); + } + } + + if (!basic_auth_username_.empty() && !basic_auth_password_.empty()) { + headers.insert(make_basic_authentication_header( + basic_auth_username_, basic_auth_password_, false)); + } + + if (!proxy_basic_auth_username_.empty() && + !proxy_basic_auth_password_.empty()) { + headers.insert(make_basic_authentication_header( + proxy_basic_auth_username_, proxy_basic_auth_password_, true)); + } + + detail::write_headers(bstrm, req, headers); + + // Flush buffer + auto &data = bstrm.get_buffer(); + strm.write(data.data(), data.size()); + + // Body + if (req.body.empty()) { + if (req.content_provider) { + size_t offset = 0; + size_t end_offset = req.content_length; + + DataSink data_sink; + data_sink.write = [&](const char *d, size_t l) { + auto written_length = strm.write(d, l); + offset += static_cast(written_length); + }; + data_sink.is_writable = [&](void) { return strm.is_writable(); }; + + while (offset < end_offset) { + req.content_provider(offset, end_offset - offset, data_sink); + } + } + } else { + strm.write(req.body); + } + + return true; +} + +inline std::shared_ptr Client::send_with_content_provider( + const char *method, const char *path, const Headers &headers, + const std::string &body, size_t content_length, + ContentProvider content_provider, const char *content_type) { + Request req; + req.method = method; + req.headers = headers; + req.path = path; + + if (content_type) { req.headers.emplace("Content-Type", content_type); } + +#ifdef CPPHTTPLIB_ZLIB_SUPPORT + if (compress_) { + if (content_provider) { + size_t offset = 0; + + DataSink data_sink; + data_sink.write = [&](const char *data, size_t data_len) { + req.body.append(data, data_len); + offset += data_len; + }; + data_sink.is_writable = [&](void) { return true; }; + + while (offset < content_length) { + content_provider(offset, content_length - offset, data_sink); + } + } else { + req.body = body; + } + + if (!detail::compress(req.body)) { return nullptr; } + req.headers.emplace("Content-Encoding", "gzip"); + } else +#endif + { + if (content_provider) { + req.content_length = content_length; + req.content_provider = content_provider; + } else { + req.body = body; + } + } + + auto res = std::make_shared(); + + return send(req, *res) ? res : nullptr; +} + +inline bool Client::process_request(Stream &strm, const Request &req, + Response &res, bool last_connection, + bool &connection_close) { + // Send request + if (!write_request(strm, req, last_connection)) { return false; } + + // Receive response and headers + if (!read_response_line(strm, res) || + !detail::read_headers(strm, res.headers)) { + return false; + } + + if (res.get_header_value("Connection") == "close" || + res.version == "HTTP/1.0") { + connection_close = true; + } + + if (req.response_handler) { + if (!req.response_handler(res)) { return false; } + } + + // Body + if (req.method != "HEAD" && req.method != "CONNECT") { + auto out = + req.content_receiver + ? static_cast([&](const char *buf, size_t n) { + return req.content_receiver(buf, n); + }) + : static_cast([&](const char *buf, size_t n) { + if (res.body.size() + n > res.body.max_size()) { return false; } + res.body.append(buf, n); + return true; + }); + + int dummy_status; + if (!detail::read_content(strm, res, (std::numeric_limits::max)(), + dummy_status, req.progress, out)) { + return false; + } + } + + // Log + if (logger_) { logger_(req, res); } + + return true; +} + +inline bool Client::process_and_close_socket( + socket_t sock, size_t request_count, + std::function + callback) { + request_count = (std::min)(request_count, keep_alive_max_count_); + return detail::process_and_close_socket(true, sock, request_count, + read_timeout_sec_, read_timeout_usec_, + callback); +} + +inline bool Client::is_ssl() const { return false; } + +inline std::shared_ptr Client::Get(const char *path) { + return Get(path, Headers(), Progress()); +} + +inline std::shared_ptr Client::Get(const char *path, + Progress progress) { + return Get(path, Headers(), std::move(progress)); +} + +inline std::shared_ptr Client::Get(const char *path, + const Headers &headers) { + return Get(path, headers, Progress()); +} + +inline std::shared_ptr +Client::Get(const char *path, const Headers &headers, Progress progress) { + Request req; + req.method = "GET"; + req.path = path; + req.headers = headers; + req.progress = std::move(progress); + + auto res = std::make_shared(); + return send(req, *res) ? res : nullptr; +} + +inline std::shared_ptr Client::Get(const char *path, + ContentReceiver content_receiver) { + return Get(path, Headers(), nullptr, std::move(content_receiver), Progress()); +} + +inline std::shared_ptr Client::Get(const char *path, + ContentReceiver content_receiver, + Progress progress) { + return Get(path, Headers(), nullptr, std::move(content_receiver), + std::move(progress)); +} + +inline std::shared_ptr Client::Get(const char *path, + const Headers &headers, + ContentReceiver content_receiver) { + return Get(path, headers, nullptr, std::move(content_receiver), Progress()); +} + +inline std::shared_ptr Client::Get(const char *path, + const Headers &headers, + ContentReceiver content_receiver, + Progress progress) { + return Get(path, headers, nullptr, std::move(content_receiver), + std::move(progress)); +} + +inline std::shared_ptr Client::Get(const char *path, + const Headers &headers, + ResponseHandler response_handler, + ContentReceiver content_receiver) { + return Get(path, headers, std::move(response_handler), content_receiver, + Progress()); +} + +inline std::shared_ptr Client::Get(const char *path, + const Headers &headers, + ResponseHandler response_handler, + ContentReceiver content_receiver, + Progress progress) { + Request req; + req.method = "GET"; + req.path = path; + req.headers = headers; + req.response_handler = std::move(response_handler); + req.content_receiver = std::move(content_receiver); + req.progress = std::move(progress); + + auto res = std::make_shared(); + return send(req, *res) ? res : nullptr; +} + +inline std::shared_ptr Client::Head(const char *path) { + return Head(path, Headers()); +} + +inline std::shared_ptr Client::Head(const char *path, + const Headers &headers) { + Request req; + req.method = "HEAD"; + req.headers = headers; + req.path = path; + + auto res = std::make_shared(); + + return send(req, *res) ? res : nullptr; +} + +inline std::shared_ptr Client::Post(const char *path) { + return Post(path, std::string(), nullptr); +} + +inline std::shared_ptr Client::Post(const char *path, + const std::string &body, + const char *content_type) { + return Post(path, Headers(), body, content_type); +} + +inline std::shared_ptr Client::Post(const char *path, + const Headers &headers, + const std::string &body, + const char *content_type) { + return send_with_content_provider("POST", path, headers, body, 0, nullptr, + content_type); +} + +inline std::shared_ptr Client::Post(const char *path, + const Params ¶ms) { + return Post(path, Headers(), params); +} + +inline std::shared_ptr Client::Post(const char *path, + size_t content_length, + ContentProvider content_provider, + const char *content_type) { + return Post(path, Headers(), content_length, content_provider, content_type); +} + +inline std::shared_ptr +Client::Post(const char *path, const Headers &headers, size_t content_length, + ContentProvider content_provider, const char *content_type) { + return send_with_content_provider("POST", path, headers, std::string(), + content_length, content_provider, + content_type); +} + +inline std::shared_ptr +Client::Post(const char *path, const Headers &headers, const Params ¶ms) { + auto query = detail::params_to_query_str(params); + return Post(path, headers, query, "application/x-www-form-urlencoded"); +} + +inline std::shared_ptr +Client::Post(const char *path, const MultipartFormDataItems &items) { + return Post(path, Headers(), items); +} + +inline std::shared_ptr +Client::Post(const char *path, const Headers &headers, + const MultipartFormDataItems &items) { + auto boundary = detail::make_multipart_data_boundary(); + + std::string body; + + for (const auto &item : items) { + body += "--" + boundary + "\r\n"; + body += "Content-Disposition: form-data; name=\"" + item.name + "\""; + if (!item.filename.empty()) { + body += "; filename=\"" + item.filename + "\""; + } + body += "\r\n"; + if (!item.content_type.empty()) { + body += "Content-Type: " + item.content_type + "\r\n"; + } + body += "\r\n"; + body += item.content + "\r\n"; + } + + body += "--" + boundary + "--\r\n"; + + std::string content_type = "multipart/form-data; boundary=" + boundary; + return Post(path, headers, body, content_type.c_str()); +} + +inline std::shared_ptr Client::Put(const char *path) { + return Put(path, std::string(), nullptr); +} + +inline std::shared_ptr Client::Put(const char *path, + const std::string &body, + const char *content_type) { + return Put(path, Headers(), body, content_type); +} + +inline std::shared_ptr Client::Put(const char *path, + const Headers &headers, + const std::string &body, + const char *content_type) { + return send_with_content_provider("PUT", path, headers, body, 0, nullptr, + content_type); +} + +inline std::shared_ptr Client::Put(const char *path, + size_t content_length, + ContentProvider content_provider, + const char *content_type) { + return Put(path, Headers(), content_length, content_provider, content_type); +} + +inline std::shared_ptr +Client::Put(const char *path, const Headers &headers, size_t content_length, + ContentProvider content_provider, const char *content_type) { + return send_with_content_provider("PUT", path, headers, std::string(), + content_length, content_provider, + content_type); +} + +inline std::shared_ptr Client::Put(const char *path, + const Params ¶ms) { + return Put(path, Headers(), params); +} + +inline std::shared_ptr +Client::Put(const char *path, const Headers &headers, const Params ¶ms) { + auto query = detail::params_to_query_str(params); + return Put(path, headers, query, "application/x-www-form-urlencoded"); +} + +inline std::shared_ptr Client::Patch(const char *path, + const std::string &body, + const char *content_type) { + return Patch(path, Headers(), body, content_type); +} + +inline std::shared_ptr Client::Patch(const char *path, + const Headers &headers, + const std::string &body, + const char *content_type) { + return send_with_content_provider("PATCH", path, headers, body, 0, nullptr, + content_type); +} + +inline std::shared_ptr Client::Patch(const char *path, + size_t content_length, + ContentProvider content_provider, + const char *content_type) { + return Patch(path, Headers(), content_length, content_provider, content_type); +} + +inline std::shared_ptr +Client::Patch(const char *path, const Headers &headers, size_t content_length, + ContentProvider content_provider, const char *content_type) { + return send_with_content_provider("PATCH", path, headers, std::string(), + content_length, content_provider, + content_type); +} + +inline std::shared_ptr Client::Delete(const char *path) { + return Delete(path, Headers(), std::string(), nullptr); +} + +inline std::shared_ptr Client::Delete(const char *path, + const std::string &body, + const char *content_type) { + return Delete(path, Headers(), body, content_type); +} + +inline std::shared_ptr Client::Delete(const char *path, + const Headers &headers) { + return Delete(path, headers, std::string(), nullptr); +} + +inline std::shared_ptr Client::Delete(const char *path, + const Headers &headers, + const std::string &body, + const char *content_type) { + Request req; + req.method = "DELETE"; + req.headers = headers; + req.path = path; + + if (content_type) { req.headers.emplace("Content-Type", content_type); } + req.body = body; + + auto res = std::make_shared(); + + return send(req, *res) ? res : nullptr; +} + +inline std::shared_ptr Client::Options(const char *path) { + return Options(path, Headers()); +} + +inline std::shared_ptr Client::Options(const char *path, + const Headers &headers) { + Request req; + req.method = "OPTIONS"; + req.path = path; + req.headers = headers; + + auto res = std::make_shared(); + + return send(req, *res) ? res : nullptr; +} + +inline void Client::stop() { + if (sock_ != INVALID_SOCKET) { + std::atomic sock(sock_.exchange(INVALID_SOCKET)); + detail::shutdown_socket(sock); + detail::close_socket(sock); + } +} + +inline void Client::set_timeout_sec(time_t timeout_sec) { + timeout_sec_ = timeout_sec; +} + +inline void Client::set_read_timeout(time_t sec, time_t usec) { + read_timeout_sec_ = sec; + read_timeout_usec_ = usec; +} + +inline void Client::set_keep_alive_max_count(size_t count) { + keep_alive_max_count_ = count; +} + +inline void Client::set_basic_auth(const char *username, const char *password) { + basic_auth_username_ = username; + basic_auth_password_ = password; +} + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +inline void Client::set_digest_auth(const char *username, + const char *password) { + digest_auth_username_ = username; + digest_auth_password_ = password; +} +#endif + +inline void Client::set_follow_location(bool on) { follow_location_ = on; } + +inline void Client::set_compress(bool on) { compress_ = on; } + +inline void Client::set_interface(const char *intf) { interface_ = intf; } + +inline void Client::set_proxy(const char *host, int port) { + proxy_host_ = host; + proxy_port_ = port; +} + +inline void Client::set_proxy_basic_auth(const char *username, + const char *password) { + proxy_basic_auth_username_ = username; + proxy_basic_auth_password_ = password; +} + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +inline void Client::set_proxy_digest_auth(const char *username, + const char *password) { + proxy_digest_auth_username_ = username; + proxy_digest_auth_password_ = password; +} +#endif + +inline void Client::set_logger(Logger logger) { logger_ = std::move(logger); } + +/* + * SSL Implementation + */ +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +namespace detail { + +template +inline bool process_and_close_socket_ssl( + bool is_client_request, socket_t sock, size_t keep_alive_max_count, + time_t read_timeout_sec, time_t read_timeout_usec, SSL_CTX *ctx, + std::mutex &ctx_mutex, U SSL_connect_or_accept, V setup, T callback) { + assert(keep_alive_max_count > 0); + + SSL *ssl = nullptr; + { + std::lock_guard guard(ctx_mutex); + ssl = SSL_new(ctx); + } + + if (!ssl) { + close_socket(sock); + return false; + } + + auto bio = BIO_new_socket(static_cast(sock), BIO_NOCLOSE); + SSL_set_bio(ssl, bio, bio); + + if (!setup(ssl)) { + SSL_shutdown(ssl); + { + std::lock_guard guard(ctx_mutex); + SSL_free(ssl); + } + + close_socket(sock); + return false; + } + + auto ret = false; + + if (SSL_connect_or_accept(ssl) == 1) { + if (keep_alive_max_count > 1) { + auto count = keep_alive_max_count; + while (count > 0 && + (is_client_request || + select_read(sock, CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND, + CPPHTTPLIB_KEEPALIVE_TIMEOUT_USECOND) > 0)) { + SSLSocketStream strm(sock, ssl, read_timeout_sec, read_timeout_usec); + auto last_connection = count == 1; + auto connection_close = false; + + ret = callback(ssl, strm, last_connection, connection_close); + if (!ret || connection_close) { break; } + + count--; + } + } else { + SSLSocketStream strm(sock, ssl, read_timeout_sec, read_timeout_usec); + auto dummy_connection_close = false; + ret = callback(ssl, strm, true, dummy_connection_close); + } + } + + if (ret) { + SSL_shutdown(ssl); // shutdown only if not already closed by remote + } + { + std::lock_guard guard(ctx_mutex); + SSL_free(ssl); + } + + close_socket(sock); + + return ret; +} + +#if OPENSSL_VERSION_NUMBER < 0x10100000L +static std::shared_ptr> openSSL_locks_; + +class SSLThreadLocks { +public: + SSLThreadLocks() { + openSSL_locks_ = + std::make_shared>(CRYPTO_num_locks()); + CRYPTO_set_locking_callback(locking_callback); + } + + ~SSLThreadLocks() { CRYPTO_set_locking_callback(nullptr); } + +private: + static void locking_callback(int mode, int type, const char * /*file*/, + int /*line*/) { + auto &lk = (*openSSL_locks_)[static_cast(type)]; + if (mode & CRYPTO_LOCK) { + lk.lock(); + } else { + lk.unlock(); + } + } +}; + +#endif + +class SSLInit { +public: + SSLInit() { +#if OPENSSL_VERSION_NUMBER < 0x1010001fL + SSL_load_error_strings(); + SSL_library_init(); +#else + OPENSSL_init_ssl( + OPENSSL_INIT_LOAD_SSL_STRINGS | OPENSSL_INIT_LOAD_CRYPTO_STRINGS, NULL); +#endif + } + + ~SSLInit() { +#if OPENSSL_VERSION_NUMBER < 0x1010001fL + ERR_free_strings(); +#endif + } + +private: +#if OPENSSL_VERSION_NUMBER < 0x10100000L + SSLThreadLocks thread_init_; +#endif +}; + +// SSL socket stream implementation +inline SSLSocketStream::SSLSocketStream(socket_t sock, SSL *ssl, + time_t read_timeout_sec, + time_t read_timeout_usec) + : sock_(sock), ssl_(ssl), read_timeout_sec_(read_timeout_sec), + read_timeout_usec_(read_timeout_usec) {} + +inline SSLSocketStream::~SSLSocketStream() {} + +inline bool SSLSocketStream::is_readable() const { + return detail::select_read(sock_, read_timeout_sec_, read_timeout_usec_) > 0; +} + +inline bool SSLSocketStream::is_writable() const { + return detail::select_write(sock_, 0, 0) > 0; +} + +inline ssize_t SSLSocketStream::read(char *ptr, size_t size) { + if (SSL_pending(ssl_) > 0 || + select_read(sock_, read_timeout_sec_, read_timeout_usec_) > 0) { + return SSL_read(ssl_, ptr, static_cast(size)); + } + return -1; +} + +inline ssize_t SSLSocketStream::write(const char *ptr, size_t size) { + if (is_writable()) { return SSL_write(ssl_, ptr, static_cast(size)); } + return -1; +} + +inline void SSLSocketStream::get_remote_ip_and_port(std::string &ip, + int &port) const { + detail::get_remote_ip_and_port(sock_, ip, port); +} + +static SSLInit sslinit_; + +} // namespace detail + +// SSL HTTP server implementation +inline SSLServer::SSLServer(const char *cert_path, const char *private_key_path, + const char *client_ca_cert_file_path, + const char *client_ca_cert_dir_path) { + ctx_ = SSL_CTX_new(SSLv23_server_method()); + + if (ctx_) { + SSL_CTX_set_options(ctx_, + SSL_OP_ALL | SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3 | + SSL_OP_NO_COMPRESSION | + SSL_OP_NO_SESSION_RESUMPTION_ON_RENEGOTIATION); + + // auto ecdh = EC_KEY_new_by_curve_name(NID_X9_62_prime256v1); + // SSL_CTX_set_tmp_ecdh(ctx_, ecdh); + // EC_KEY_free(ecdh); + + if (SSL_CTX_use_certificate_chain_file(ctx_, cert_path) != 1 || + SSL_CTX_use_PrivateKey_file(ctx_, private_key_path, SSL_FILETYPE_PEM) != + 1) { + SSL_CTX_free(ctx_); + ctx_ = nullptr; + } else if (client_ca_cert_file_path || client_ca_cert_dir_path) { + // if (client_ca_cert_file_path) { + // auto list = SSL_load_client_CA_file(client_ca_cert_file_path); + // SSL_CTX_set_client_CA_list(ctx_, list); + // } + + SSL_CTX_load_verify_locations(ctx_, client_ca_cert_file_path, + client_ca_cert_dir_path); + + SSL_CTX_set_verify( + ctx_, + SSL_VERIFY_PEER | + SSL_VERIFY_FAIL_IF_NO_PEER_CERT, // SSL_VERIFY_CLIENT_ONCE, + nullptr); + } + } +} + +inline SSLServer::SSLServer(X509 *cert, EVP_PKEY *private_key, + X509_STORE *client_ca_cert_store) { + ctx_ = SSL_CTX_new(SSLv23_server_method()); + + if (ctx_) { + SSL_CTX_set_options(ctx_, + SSL_OP_ALL | SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3 | + SSL_OP_NO_COMPRESSION | + SSL_OP_NO_SESSION_RESUMPTION_ON_RENEGOTIATION); + + if (SSL_CTX_use_certificate(ctx_, cert) != 1 || + SSL_CTX_use_PrivateKey(ctx_, private_key) != 1) { + SSL_CTX_free(ctx_); + ctx_ = nullptr; + } else if (client_ca_cert_store) { + + SSL_CTX_set_cert_store(ctx_, client_ca_cert_store); + + SSL_CTX_set_verify( + ctx_, + SSL_VERIFY_PEER | + SSL_VERIFY_FAIL_IF_NO_PEER_CERT, // SSL_VERIFY_CLIENT_ONCE, + nullptr); + } + } +} + +inline SSLServer::~SSLServer() { + if (ctx_) { SSL_CTX_free(ctx_); } +} + +inline bool SSLServer::is_valid() const { return ctx_; } + +inline bool SSLServer::process_and_close_socket(socket_t sock) { + return detail::process_and_close_socket_ssl( + false, sock, keep_alive_max_count_, read_timeout_sec_, read_timeout_usec_, + ctx_, ctx_mutex_, SSL_accept, [](SSL * /*ssl*/) { return true; }, + [this](SSL *ssl, Stream &strm, bool last_connection, + bool &connection_close) { + return process_request(strm, last_connection, connection_close, + [&](Request &req) { req.ssl = ssl; }); + }); +} + +// SSL HTTP client implementation +inline SSLClient::SSLClient(const std::string &host, int port, + const std::string &client_cert_path, + const std::string &client_key_path) + : Client(host, port, client_cert_path, client_key_path) { + ctx_ = SSL_CTX_new(SSLv23_client_method()); + + detail::split(&host_[0], &host_[host_.size()], '.', + [&](const char *b, const char *e) { + host_components_.emplace_back(std::string(b, e)); + }); + if (!client_cert_path.empty() && !client_key_path.empty()) { + if (SSL_CTX_use_certificate_file(ctx_, client_cert_path.c_str(), + SSL_FILETYPE_PEM) != 1 || + SSL_CTX_use_PrivateKey_file(ctx_, client_key_path.c_str(), + SSL_FILETYPE_PEM) != 1) { + SSL_CTX_free(ctx_); + ctx_ = nullptr; + } + } +} + +inline SSLClient::SSLClient(const std::string &host, int port, + X509 *client_cert, EVP_PKEY *client_key) + : Client(host, port) { + ctx_ = SSL_CTX_new(SSLv23_client_method()); + + detail::split(&host_[0], &host_[host_.size()], '.', + [&](const char *b, const char *e) { + host_components_.emplace_back(std::string(b, e)); + }); + if (client_cert != nullptr && client_key != nullptr) { + if (SSL_CTX_use_certificate(ctx_, client_cert) != 1 || + SSL_CTX_use_PrivateKey(ctx_, client_key) != 1) { + SSL_CTX_free(ctx_); + ctx_ = nullptr; + } + } +} + +inline SSLClient::~SSLClient() { + if (ctx_) { SSL_CTX_free(ctx_); } +} + +inline bool SSLClient::is_valid() const { return ctx_; } + +inline void SSLClient::set_ca_cert_path(const char *ca_cert_file_path, + const char *ca_cert_dir_path) { + if (ca_cert_file_path) { ca_cert_file_path_ = ca_cert_file_path; } + if (ca_cert_dir_path) { ca_cert_dir_path_ = ca_cert_dir_path; } +} + +inline void SSLClient::set_ca_cert_store(X509_STORE *ca_cert_store) { + if (ca_cert_store) { ca_cert_store_ = ca_cert_store; } +} + +inline void SSLClient::enable_server_certificate_verification(bool enabled) { + server_certificate_verification_ = enabled; +} + +inline long SSLClient::get_openssl_verify_result() const { + return verify_result_; +} + +inline SSL_CTX *SSLClient::ssl_context() const { return ctx_; } + +inline bool SSLClient::process_and_close_socket( + socket_t sock, size_t request_count, + std::function + callback) { + + request_count = std::min(request_count, keep_alive_max_count_); + + return is_valid() && + detail::process_and_close_socket_ssl( + true, sock, request_count, read_timeout_sec_, read_timeout_usec_, + ctx_, ctx_mutex_, + [&](SSL *ssl) { + if (ca_cert_file_path_.empty() && ca_cert_store_ == nullptr) { + SSL_CTX_set_verify(ctx_, SSL_VERIFY_NONE, nullptr); + } else if (!ca_cert_file_path_.empty()) { + if (!SSL_CTX_load_verify_locations( + ctx_, ca_cert_file_path_.c_str(), nullptr)) { + return false; + } + SSL_CTX_set_verify(ctx_, SSL_VERIFY_PEER, nullptr); + } else if (ca_cert_store_ != nullptr) { + if (SSL_CTX_get_cert_store(ctx_) != ca_cert_store_) { + SSL_CTX_set_cert_store(ctx_, ca_cert_store_); + } + SSL_CTX_set_verify(ctx_, SSL_VERIFY_PEER, nullptr); + } + + if (SSL_connect(ssl) != 1) { return false; } + + if (server_certificate_verification_) { + verify_result_ = SSL_get_verify_result(ssl); + + if (verify_result_ != X509_V_OK) { return false; } + + auto server_cert = SSL_get_peer_certificate(ssl); + + if (server_cert == nullptr) { return false; } + + if (!verify_host(server_cert)) { + X509_free(server_cert); + return false; + } + X509_free(server_cert); + } + + return true; + }, + [&](SSL *ssl) { + SSL_set_tlsext_host_name(ssl, host_.c_str()); + return true; + }, + [&](SSL * /*ssl*/, Stream &strm, bool last_connection, + bool &connection_close) { + return callback(strm, last_connection, connection_close); + }); +} + +inline bool SSLClient::is_ssl() const { return true; } + +inline bool SSLClient::verify_host(X509 *server_cert) const { + /* Quote from RFC2818 section 3.1 "Server Identity" + + If a subjectAltName extension of type dNSName is present, that MUST + be used as the identity. Otherwise, the (most specific) Common Name + field in the Subject field of the certificate MUST be used. Although + the use of the Common Name is existing practice, it is deprecated and + Certification Authorities are encouraged to use the dNSName instead. + + Matching is performed using the matching rules specified by + [RFC2459]. If more than one identity of a given type is present in + the certificate (e.g., more than one dNSName name, a match in any one + of the set is considered acceptable.) Names may contain the wildcard + character * which is considered to match any single domain name + component or component fragment. E.g., *.a.com matches foo.a.com but + not bar.foo.a.com. f*.com matches foo.com but not bar.com. + + In some cases, the URI is specified as an IP address rather than a + hostname. In this case, the iPAddress subjectAltName must be present + in the certificate and must exactly match the IP in the URI. + + */ + return verify_host_with_subject_alt_name(server_cert) || + verify_host_with_common_name(server_cert); +} + +inline bool +SSLClient::verify_host_with_subject_alt_name(X509 *server_cert) const { + auto ret = false; + + auto type = GEN_DNS; + + struct in6_addr addr6; + struct in_addr addr; + size_t addr_len = 0; + +#ifndef __MINGW32__ + if (inet_pton(AF_INET6, host_.c_str(), &addr6)) { + type = GEN_IPADD; + addr_len = sizeof(struct in6_addr); + } else if (inet_pton(AF_INET, host_.c_str(), &addr)) { + type = GEN_IPADD; + addr_len = sizeof(struct in_addr); + } +#endif + + auto alt_names = static_cast( + X509_get_ext_d2i(server_cert, NID_subject_alt_name, nullptr, nullptr)); + + if (alt_names) { + auto dsn_matched = false; + auto ip_mached = false; + + auto count = sk_GENERAL_NAME_num(alt_names); + + for (auto i = 0; i < count && !dsn_matched; i++) { + auto val = sk_GENERAL_NAME_value(alt_names, i); + if (val->type == type) { + auto name = (const char *)ASN1_STRING_get0_data(val->d.ia5); + auto name_len = (size_t)ASN1_STRING_length(val->d.ia5); + + if (strlen(name) == name_len) { + switch (type) { + case GEN_DNS: dsn_matched = check_host_name(name, name_len); break; + + case GEN_IPADD: + if (!memcmp(&addr6, name, addr_len) || + !memcmp(&addr, name, addr_len)) { + ip_mached = true; + } + break; + } + } + } + } + + if (dsn_matched || ip_mached) { ret = true; } + } + + GENERAL_NAMES_free((STACK_OF(GENERAL_NAME) *)alt_names); + + return ret; +} + +inline bool SSLClient::verify_host_with_common_name(X509 *server_cert) const { + const auto subject_name = X509_get_subject_name(server_cert); + + if (subject_name != nullptr) { + char name[BUFSIZ]; + auto name_len = X509_NAME_get_text_by_NID(subject_name, NID_commonName, + name, sizeof(name)); + + if (name_len != -1) { + return check_host_name(name, static_cast(name_len)); + } + } + + return false; +} + +inline bool SSLClient::check_host_name(const char *pattern, + size_t pattern_len) const { + if (host_.size() == pattern_len && host_ == pattern) { return true; } + + // Wildcard match + // https://bugs.launchpad.net/ubuntu/+source/firefox-3.0/+bug/376484 + std::vector pattern_components; + detail::split(&pattern[0], &pattern[pattern_len], '.', + [&](const char *b, const char *e) { + pattern_components.emplace_back(std::string(b, e)); + }); + + if (host_components_.size() != pattern_components.size()) { return false; } + + auto itr = pattern_components.begin(); + for (const auto &h : host_components_) { + auto &p = *itr; + if (p != h && p != "*") { + auto partial_match = (p.size() > 0 && p[p.size() - 1] == '*' && + !p.compare(0, p.size() - 1, h)); + if (!partial_match) { return false; } + } + ++itr; + } + + return true; +} +#endif + +namespace url { + +struct Options { + // TODO: support more options... + bool follow_location = false; + std::string client_cert_path; + std::string client_key_path; + + std::string ca_cert_file_path; + std::string ca_cert_dir_path; + bool server_certificate_verification = false; +}; + +inline std::shared_ptr Get(const char *url, Options &options) { + const static std::regex re( + R"(^(https?)://([^:/?#]+)(?::(\d+))?([^?#]*(?:\?[^#]*)?)(?:#.*)?)"); + + std::cmatch m; + if (!std::regex_match(url, m, re)) { return nullptr; } + + auto next_scheme = m[1].str(); + auto next_host = m[2].str(); + auto port_str = m[3].str(); + auto next_path = m[4].str(); + + auto next_port = !port_str.empty() ? std::stoi(port_str) + : (next_scheme == "https" ? 443 : 80); + + if (next_path.empty()) { next_path = "/"; } + + if (next_scheme == "https") { +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + SSLClient cli(next_host.c_str(), next_port, options.client_cert_path, + options.client_key_path); + cli.set_follow_location(options.follow_location); + cli.set_ca_cert_path(options.ca_cert_file_path.c_str(), + options.ca_cert_dir_path.c_str()); + cli.enable_server_certificate_verification( + options.server_certificate_verification); + return cli.Get(next_path.c_str()); +#else + return nullptr; +#endif + } else { + Client cli(next_host.c_str(), next_port, options.client_cert_path, + options.client_key_path); + cli.set_follow_location(options.follow_location); + return cli.Get(next_path.c_str()); + } +} + +inline std::shared_ptr Get(const char *url) { + Options options; + return Get(url, options); +} + +} // namespace url + +namespace detail { + +#undef HANDLE_EINTR + +} // namespace detail + +// ---------------------------------------------------------------------------- + +} // namespace httplib + +#endif // CPPHTTPLIB_HTTPLIB_H + diff --git a/src/VersionChecker.cpp b/src/VersionChecker.cpp index 64d2d0e4d..6a0bb0827 100644 --- a/src/VersionChecker.cpp +++ b/src/VersionChecker.cpp @@ -1,187 +1,24 @@ -// based off of -// async_client.cpp -// ~~~~~~~~~~~~~~~~ -// -// Copyright (c) 2003-2012 Christopher M. Kohlhoff (chris at kohlhoff dot com) -// -// Distributed under the Boost Software License, Version 1.0. (See accompanying -// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) -// - #include "VersionChecker.hpp" #include "SalmonConfig.hpp" - -using boost::asio::ip::tcp; - -VersionChecker::VersionChecker(boost::asio::io_service& io_service, - const std::string& server, - const std::string& path) - : resolver_(io_service), socket_(io_service), deadline_(io_service) { - // Form the request. We specify the "Connection: close" header so that the - // server will close the socket after transmitting the response. This will - // allow us to treat all data up until the EOF as the content. - std::ostream request_stream(&request_); - request_stream << "GET " << path << " HTTP/1.0\r\n"; - request_stream << "Host: " << server << "\r\n"; - request_stream << "Accept: */*\r\n"; - request_stream << "Connection: close\r\n\r\n"; - - - deadline_.expires_from_now(boost::posix_time::seconds(1)); - deadline_.async_wait(boost::bind(&VersionChecker::cancel_upgrade_check, this, - boost::asio::placeholders::error)); - - // Start an asynchronous resolve to translate the server and service names - // into a list of endpoints. - tcp::resolver::query query(server, "http"); - resolver_.async_resolve(query, - boost::bind(&VersionChecker::handle_resolve, this, - boost::asio::placeholders::error, - boost::asio::placeholders::iterator)); -} - -std::string VersionChecker::message() { return messageStream_.str(); } - -void VersionChecker::cancel_upgrade_check( - const boost::system::error_code& err) { - if (err != boost::asio::error::operation_aborted) { - deadline_.cancel(); - messageStream_ - << "Could not resolve upgrade information in the alotted time.\n"; - messageStream_ << "Check for upgrades manually at " - "https://combine-lab.github.io/salmon\n"; - socket_.close(); - } -} - -void VersionChecker::handle_resolve(const boost::system::error_code& err, - tcp::resolver::iterator endpoint_iterator) { - if (!err) { - // Attempt a connection to each endpoint in the list until we - // successfully establish a connection. - boost::asio::async_connect(socket_, endpoint_iterator, - boost::bind(&VersionChecker::handle_connect, - this, - boost::asio::placeholders::error)); - } else { - deadline_.cancel(); - cancel_upgrade_check(err); - } -} - -void VersionChecker::handle_connect(const boost::system::error_code& err) { - if (!err) { - // The connection was successful. Send the request. - boost::asio::async_write(socket_, request_, - boost::bind(&VersionChecker::handle_write_request, - this, - boost::asio::placeholders::error)); - } else { - deadline_.cancel(); - cancel_upgrade_check(err); - } -} - -void VersionChecker::handle_write_request( - const boost::system::error_code& err) { - if (!err) { - // Read the response status line. The response_ streambuf will - // automatically grow to accommodate the entire line. The growth may be - // limited by passing a maximum size to the streambuf constructor. - boost::asio::async_read_until( - socket_, response_, "\r\n", - boost::bind(&VersionChecker::handle_read_status_line, this, - boost::asio::placeholders::error)); - } else { - deadline_.cancel(); - cancel_upgrade_check(err); - } -} - -void VersionChecker::handle_read_status_line( - const boost::system::error_code& err) { - if (!err) { - // Check that response is OK. - std::istream response_stream(&response_); - std::string http_version; - response_stream >> http_version; - unsigned int status_code; - response_stream >> status_code; - std::string status_message; - std::getline(response_stream, status_message); - if (!response_stream || http_version.substr(0, 5) != "HTTP/") { - deadline_.cancel(); - cancel_upgrade_check(err); - return; - } - if (status_code != 200) { - deadline_.cancel(); - cancel_upgrade_check(err); - return; - } - - // Read the response headers, which are terminated by a blank line. - boost::asio::async_read_until( - socket_, response_, "\r\n\r\n", - boost::bind(&VersionChecker::handle_read_headers, this, - boost::asio::placeholders::error)); - } else { - deadline_.cancel(); - cancel_upgrade_check(err); - } -} - -void VersionChecker::handle_read_headers(const boost::system::error_code& err) { - if (!err) { - deadline_.cancel(); - // Process the response headers. - std::istream response_stream(&response_); - std::string header; - while (std::getline(response_stream, header) && header != "\r") { - } - - // Write whatever content we already have to output. - if (response_.size() > 0) - messageStream_ << &response_; - - // Start reading remaining data until EOF. - boost::asio::async_read( - socket_, response_, boost::asio::transfer_at_least(1), - boost::bind(&VersionChecker::handle_read_content, this, - boost::asio::placeholders::error)); - } else { - deadline_.cancel(); - cancel_upgrade_check(err); - } -} - -void VersionChecker::handle_read_content(const boost::system::error_code& err) { - if (!err) { - // Write all of the data that has been read so far. - messageStream_ << &response_; - - // Continue reading remaining data until EOF. - boost::asio::async_read( - socket_, response_, boost::asio::transfer_at_least(1), - boost::bind(&VersionChecker::handle_read_content, this, - boost::asio::placeholders::error)); - } else if (err != boost::asio::error::eof) { - deadline_.cancel(); - cancel_upgrade_check(err); - } -} +#include "httplib.hpp" std::string getVersionMessage() { - std::string baseSite{"combine-lab.github.io"}; - std::string path{"/salmon/version_info/"}; - path += salmon::version; - std::stringstream ss; try { - boost::asio::io_service io_service; - VersionChecker c(io_service, baseSite, path); - io_service.run(); - ss << "Version Info: " << c.message(); + httplib::Client cli("combine-lab.github.io", 80); + std::string path{"/salmon/version_info/"}; + path += salmon::version; + cli.set_timeout_sec(2); // timeouts in 2 seconds + auto res = cli.Get(path.c_str()); + if (res) { // non-null response + if (res->status == 200) { // response OK + ss << "Version Info: " << res->body; + } else { // response something else + ss << "Version Server Response: " << httplib::detail::status_message(res->status) << "\n"; + } + } else { // null response + ss << "Version Info Exception: server did not respond before timeout\n"; + } } catch (std::exception& e) { ss << "Version Info Exception: " << e.what() << "\n"; } From 98b8b4b528d8f8bbb7332b906d07d69a67e84428 Mon Sep 17 00:00:00 2001 From: Avi Srivastava Date: Thu, 30 Apr 2020 14:08:57 -0400 Subject: [PATCH 08/52] naive mode writing --- include/SingleCellProtocols.hpp | 2 +- src/CollapsedCellOptimizer.cpp | 28 ++++++++++++++++++++++++++-- 2 files changed, 27 insertions(+), 3 deletions(-) diff --git a/include/SingleCellProtocols.hpp b/include/SingleCellProtocols.hpp index 53cba31ce..cb46ac37f 100644 --- a/include/SingleCellProtocols.hpp +++ b/include/SingleCellProtocols.hpp @@ -46,7 +46,7 @@ namespace alevin{ }; struct CITESeq : Rule{ - CITESeq(): Rule(16, 12, BarcodeEnd::FIVE, 4294967295){ + CITESeq(): Rule(16, 10, BarcodeEnd::FIVE, 4294967295){ featureLength = 15; featureStart = 10; } diff --git a/src/CollapsedCellOptimizer.cpp b/src/CollapsedCellOptimizer.cpp index bbdefc781..29c1d7a06 100644 --- a/src/CollapsedCellOptimizer.cpp +++ b/src/CollapsedCellOptimizer.cpp @@ -696,12 +696,16 @@ void optimizeCell(std::vector& trueBarcodes, } else { // doing per eqclass level naive deduplication for (size_t eqId=0; eqId umis; + size_t numFeats = txpGroups[eqId].size(); + if (numFeats > 1) { continue; }; + spp::sparse_hash_set umis; for(auto& it: umiGroups[eqId]) { umis.insert( it.first ); } + totalCount += umis.size(); + geneAlphas[ txpGroups[eqId][0] ] += umis.size(); // filling in the eqclass level deduplicated counts if (verbose) { @@ -709,8 +713,28 @@ void optimizeCell(std::vector& trueBarcodes, } } + std::string emptyString = ""; + // write the abundance for the cell + bool isWriteOk = gzw.writeSparseAbundances( trueBarcodeStr, + emptyString, + emptyString, + 0, + geneAlphas, + tiers, + false, + false ); + + + if( not isWriteOk ){ + jointlog->error("Gzip Writer failed \n" + "Please Report this on github."); + jointlog->flush(); + std::exit(74); + } + // maintaining count for total number of predicted UMI salmon::utils::incLoop(totalDedupCounts, totalCount); + totalExpGeneCounts += totalExpGenes; } if (verbose) { @@ -1072,7 +1096,7 @@ bool CollapsedCellOptimizer::optimize(EqMapT& fullEqMap, std::copy(geneNames.begin(), geneNames.end(), giterator); gFile.close(); - if( not hasWhitelist and not usingHashMode){ + if( not aopt.naiveEqclass and not hasWhitelist and not usingHashMode){ aopt.jointLog->info("Clearing EqMap; Might take some time."); fullEqMap.clear(); From 0622380f0ebeb15eadd3a738dcb1ecd9fc22a921 Mon Sep 17 00:00:00 2001 From: Rob Patro Date: Fri, 1 May 2020 16:35:39 -0400 Subject: [PATCH 09/52] some cleanup --- include/AlignmentLibrary.hpp | 3 +++ include/httplib.hpp | 5 +++-- src/Salmon.cpp | 2 +- src/VersionChecker.cpp | 6 +++++- 4 files changed, 12 insertions(+), 4 deletions(-) diff --git a/include/AlignmentLibrary.hpp b/include/AlignmentLibrary.hpp index 218e7d895..ec738b1c6 100644 --- a/include/AlignmentLibrary.hpp +++ b/include/AlignmentLibrary.hpp @@ -111,6 +111,9 @@ template class AlignmentLibrary { // The transcript file existed, so load up the transcripts double alpha = 0.005; + // we know how many we will have, so reserve the space for + // them. + transcripts_.reserve(header->nref); for (decltype(header->nref) i = 0; i < header->nref; ++i) { transcripts_.emplace_back(i, header->ref[i].name, header->ref[i].len, alpha); diff --git a/include/httplib.hpp b/include/httplib.hpp index 03ef11a03..af4532c50 100644 --- a/include/httplib.hpp +++ b/include/httplib.hpp @@ -1405,7 +1405,7 @@ socket_t create_socket(const char *host, int port, Fn fn, // Get address info struct addrinfo hints; - struct addrinfo *result; + struct addrinfo *result = nullptr; memset(&hints, 0, sizeof(struct addrinfo)); hints.ai_family = AF_UNSPEC; @@ -1416,6 +1416,7 @@ socket_t create_socket(const char *host, int port, Fn fn, auto service = std::to_string(port); if (getaddrinfo(host, service.c_str(), &hints, &result)) { + if (result != nullptr) { freeaddrinfo(result); } return INVALID_SOCKET; } @@ -3757,7 +3758,7 @@ inline Client::Client(const std::string &host, int port, host_and_port_(host_ + ":" + std::to_string(port_)), client_cert_path_(client_cert_path), client_key_path_(client_key_path) {} -inline Client::~Client() {} +inline Client::~Client() { stop(); } inline bool Client::is_valid() const { return true; } diff --git a/src/Salmon.cpp b/src/Salmon.cpp index a2ff3ed1f..a7b6d3980 100644 --- a/src/Salmon.cpp +++ b/src/Salmon.cpp @@ -319,4 +319,4 @@ int main(int argc, char* argv[]) { } return 0; -} +} \ No newline at end of file diff --git a/src/VersionChecker.cpp b/src/VersionChecker.cpp index 6a0bb0827..4cc5634a4 100644 --- a/src/VersionChecker.cpp +++ b/src/VersionChecker.cpp @@ -5,7 +5,11 @@ std::string getVersionMessage() { std::stringstream ss; try { - httplib::Client cli("combine-lab.github.io", 80); + // NOTE: getaddrinfo / freeaddrinfo will cause a "memory leak" + // once per address, program pair. This is a known issue + // https://lists.debian.org/debian-glibc/2016/03/msg00243.html. + // If valgrind leads you here, best not to worry about it. + httplib::Client cli("combine-lab.github.io"); std::string path{"/salmon/version_info/"}; path += salmon::version; cli.set_timeout_sec(2); // timeouts in 2 seconds From c63ca84c465e90c7af56058c51340bc8b5ec8e26 Mon Sep 17 00:00:00 2001 From: Rob Patro Date: Sat, 2 May 2020 01:06:30 -0400 Subject: [PATCH 10/52] replaced jellyfish kmer with pufferfish kmer _almost_ everywhere --- include/SBModel.hpp | 14 +++++++----- src/SBModel.cpp | 32 +++++++++++++++++--------- src/SalmonAlevin.cpp | 6 ++--- src/SalmonQuantify.cpp | 23 +++++++++++-------- src/SalmonQuantifyAlignments.cpp | 24 ++++++++++++-------- src/SalmonUtils.cpp | 39 +++++++++++++++++++++----------- 6 files changed, 86 insertions(+), 52 deletions(-) diff --git a/include/SBModel.hpp b/include/SBModel.hpp index 8ee10c1bf..50ba9af5b 100644 --- a/include/SBModel.hpp +++ b/include/SBModel.hpp @@ -4,13 +4,14 @@ #include #include "UtilityFunctions.hpp" -#include "jellyfish/mer_dna.hpp" +//#include "jellyfish/mer_dna.hpp" //#include "rapmap/Kmer.hpp" +#include "pufferfish/Kmer.hpp" #include #include -using Mer = jellyfish::mer_dna_ns::mer_base_static; -//using Mer = combinelib::kmers::Kmer<32,2>; +//using Mer = jellyfish::mer_dna_ns::mer_base_static; +using SBMer = combinelib::kmers::Kmer<32,4>; class SBModel { public: @@ -31,13 +32,13 @@ class SBModel { } bool addSequence(const char* seqIn, bool revCmp, double weight = 1.0); - bool addSequence(const Mer& mer, double weight); + bool addSequence(const SBMer& sbmer, double weight); Eigen::MatrixXd& counts(); Eigen::MatrixXd& marginals(); double evaluateLog(const char* seqIn); - double evaluateLog(const Mer& mer); + double evaluateLog(const SBMer& sbmer); bool normalize(); @@ -88,7 +89,8 @@ class SBModel { Eigen::MatrixXd _probs; Eigen::MatrixXd _marginals; - Mer _mer; + //Mer _mer; + SBMer _sbmer; std::vector _order; std::vector _shifts; std::vector _widths; diff --git a/src/SBModel.cpp b/src/SBModel.cpp index 87ee205bc..f635c6ee8 100644 --- a/src/SBModel.cpp +++ b/src/SBModel.cpp @@ -1,4 +1,5 @@ #include "SBModel.hpp" +#include "jellyfish/mer_dna.hpp" #include #include @@ -61,7 +62,8 @@ SBModel::SBModel() : _trained(false) { } // Set k equal to the size of the contexts we'll parse. - _mer.k(_contextLength); + //_mer.k(_contextLength); + _sbmer.k(_contextLength); // To hold all probabilities the matrix must be 4^{max_order + 1} by // context-length @@ -113,21 +115,26 @@ bool SBModel::writeBinary(boost::iostreams::filtering_ostream& out) const { double SBModel::evaluateLog(const char* seqIn) { double p = 0; - Mer mer; - mer.from_chars(seqIn); + //Mer mer; + //mer.from_chars(seqIn); + SBMer sbmer; + sbmer.fromChars(seqIn); for (int32_t i = 0; i < _contextLength; ++i) { - uint64_t idx = mer.get_bits(_shifts[i], _widths[i]); + //uint64_t idx = mer.get_bits(_shifts[i], _widths[i]); + uint64_t idx = sbmer.get_bits(_shifts[i], _widths[i]); p += _probs(idx, i); } return p; } -double SBModel::evaluateLog(const Mer& mer) { +double SBModel::evaluateLog(const SBMer& sbmer) { double p = 0; for (int32_t i = 0; i < _contextLength; ++i) { - uint64_t idx = mer.get_bits(_shifts[i], _widths[i]); + //uint64_t idx = mer.get_bits(_shifts[i], _widths[i]); + uint64_t idx = sbmer.get_bits(_shifts[i], _widths[i]); + p += _probs(idx, i); } return p; @@ -183,16 +190,19 @@ void SBModel::dumpConditionalProbabilities(std::ostream& os) { } bool SBModel::addSequence(const char* seqIn, bool revCmp, double weight) { - _mer.from_chars(seqIn); + //_mer.from_chars(seqIn); + bool ok = _sbmer.fromChars(seqIn); if (revCmp) { - _mer.reverse_complement(); + //_mer.reverse_complement(); + _sbmer.rc(); } - return addSequence(_mer, weight); + return addSequence(_sbmer, weight); } -bool SBModel::addSequence(const Mer& mer, double weight) { +bool SBModel::addSequence(const SBMer& sbmer, double weight) { for (int32_t i = 0; i < _contextLength; ++i) { - uint64_t idx = mer.get_bits(_shifts[i], _widths[i]); + //uint64_t idx = mer.get_bits(_shifts[i], _widths[i]); + uint64_t idx = sbmer.get_bits(_shifts[i], _widths[i]); _probs(idx, i) += weight; } return true; diff --git a/src/SalmonAlevin.cpp b/src/SalmonAlevin.cpp index 88a2b3bd0..f7cb1de83 100644 --- a/src/SalmonAlevin.cpp +++ b/src/SalmonAlevin.cpp @@ -49,7 +49,7 @@ // Jellyfish 2 include -#include "jellyfish/mer_dna.hpp" +// #include "jellyfish/mer_dna.hpp" // Boost Includes #include @@ -418,8 +418,8 @@ void processReadsQuasi( observedBiasParams .seqBiasModelRC; // readExp.readBias(salmon::utils::Direction::REVERSE_COMPLEMENT); // k-mers for sequence bias context - Mer leftMer; - Mer rightMer; + // Mer leftMer; + // Mer rightMer; //auto expectedLibType = rl.format(); diff --git a/src/SalmonQuantify.cpp b/src/SalmonQuantify.cpp index 16fe0f252..19eb2b4bb 100644 --- a/src/SalmonQuantify.cpp +++ b/src/SalmonQuantify.cpp @@ -791,8 +791,11 @@ void processReads( .seqBiasModelRC; // readExp.readBias(salmon::utils::Direction::REVERSE_COMPLEMENT); // k-mers for sequence bias context - Mer leftMer; - Mer rightMer; + //Mer leftMer; + //Mer rightMer; + SBMer leftMer; + SBMer rightMer; + uint64_t firstTimestepOfRound = fmCalc.getCurrentTimestep(); size_t minK = qidx->k(); @@ -1300,14 +1303,14 @@ void processReads( int32_t fwPos = (h.fwd) ? startPos1 : startPos2; int32_t rcPos = (h.fwd) ? startPos2 : startPos1; if (fwPos < rcPos) { - leftMer.from_chars(txpStart + startPos1 - + leftMer.fromChars(txpStart + startPos1 - readBias1.contextBefore(read1RC)); - rightMer.from_chars(txpStart + startPos2 - + rightMer.fromChars(txpStart + startPos2 - readBias2.contextBefore(read2RC)); if (read1RC) { - leftMer.reverse_complement(); + leftMer.rc(); } else { - rightMer.reverse_complement(); + rightMer.rc(); } success = readBias1.addSequence(leftMer, 1.0); @@ -1514,8 +1517,8 @@ void processReads( auto& readBiasFW = observedBiasParams.seqBiasModelFW; auto& readBiasRC = observedBiasParams.seqBiasModelRC; - Mer context; - + //Mer context; + SBMer context; uint64_t firstTimestepOfRound = fmCalc.getCurrentTimestep(); size_t minK = qidx->k(); @@ -1762,10 +1765,10 @@ void processReads( // read start sequences. if (startPos >= readBias.contextBefore(!h.fwd) and startPos + readBias.contextAfter(!h.fwd) < static_cast(t.RefLength)) { - context.from_chars(txpStart + startPos - + context.fromChars(txpStart + startPos - readBias.contextBefore(!h.fwd)); if (!h.fwd) { - context.reverse_complement(); + context.rc(); } success = readBias.addSequence(context, 1.0); } diff --git a/src/SalmonQuantifyAlignments.cpp b/src/SalmonQuantifyAlignments.cpp index f9950317b..e01bf7f11 100644 --- a/src/SalmonQuantifyAlignments.cpp +++ b/src/SalmonQuantifyAlignments.cpp @@ -174,9 +174,13 @@ void processMiniBatch(AlignmentLibraryT& alnLib, auto orphanProb = salmonOpts.discardOrphansAln ? LOG_0 : LOG_EPSILON; // k-mers for sequence bias - Mer leftMer; - Mer rightMer; - Mer context; + //Mer leftMer; + //Mer rightMer; + //Mer context; + SBMer leftMer; + SBMer rightMer; + SBMer context; + auto& refs = alnLib.transcripts(); auto& clusterForest = alnLib.clusterForest(); @@ -656,14 +660,15 @@ void processMiniBatch(AlignmentLibraryT& alnLib, int32_t fwPos = (fwd1) ? startPos1 : startPos2; int32_t rcPos = (fwd1) ? startPos2 : startPos1; if (fwPos < rcPos) { - leftMer.from_chars(txpStart + startPos1 - + leftMer.fromChars(txpStart + startPos1 - readBias1.contextBefore(read1RC)); - rightMer.from_chars(txpStart + startPos2 - + rightMer.fromChars(txpStart + startPos2 - readBias2.contextBefore(read2RC)); + if (read1RC) { - leftMer.reverse_complement(); + leftMer.rc(); } else { - rightMer.reverse_complement(); + rightMer.rc(); } success = readBias1.addSequence(leftMer, 1.0); @@ -692,10 +697,11 @@ void processMiniBatch(AlignmentLibraryT& alnLib, if (startPos1 >= readBias1.contextBefore(!fwd1) and startPos1 + readBias1.contextAfter(!fwd1) < static_cast(transcript.RefLength)) { - context.from_chars(txpStart + startPos1 - + context.fromChars(txpStart + startPos1 - readBias1.contextBefore(!fwd1)); + if (!fwd1) { - context.reverse_complement(); + context.rc(); } success = readBias1.addSequence(context, 1.0); } diff --git a/src/SalmonUtils.cpp b/src/SalmonUtils.cpp index 27eec4578..c861c346e 100644 --- a/src/SalmonUtils.cpp +++ b/src/SalmonUtils.cpp @@ -37,7 +37,7 @@ #include "gff.h" #include "FastxParser.hpp" -#include "jellyfish/mer_dna.hpp" +//#include "jellyfish/mer_dna.hpp" #include "GenomicFeature.hpp" #include "SGSmooth.hpp" @@ -2656,10 +2656,14 @@ int contextSize = outsideContext + insideContext; revComplement(tseq, refLen, rcSeq); const char* rseq = rcSeq.c_str(); - Mer fwmer; - fwmer.from_chars(tseq); - Mer rcmer; - rcmer.from_chars(rseq); + // Mer fwmer; + // Mer rcmer; + // fwmer.from_chars(tseq); + //rcmer.from_chars(rseq); + SBMer fwmer; + SBMer rcmer; + fwmer.fromChars(tseq); + rcmer.fromChars(rseq); int32_t contextLength{expectSeqFW.getContextLength()}; if (gcBiasCorrect and seqBiasCorrect) { @@ -2692,8 +2696,10 @@ int contextSize = outsideContext + insideContext; } // shift the context one nucleotide to the right - fwmer.shift_left(tseq[fragStartPos + contextLength]); - rcmer.shift_left(rseq[fragStartPos + contextLength]); + //fwmer.shift_left(tseq[fragStartPos + contextLength]); + //rcmer.shift_left(rseq[fragStartPos + contextLength]); + fwmer.append(tseq[fragStartPos + contextLength]); + rcmer.append(rseq[fragStartPos + contextLength]); } // end: Seq-specific bias // fragment-GC bias @@ -2929,10 +2935,15 @@ int contextSize = outsideContext + insideContext; // and seqFactorsRC will contain the sequence-specific bias for each // position on the 3' strand. if (seqBiasCorrect) { - Mer mer; - Mer rcmer; - mer.from_chars(tseq); - rcmer.from_chars(rseq); + //Mer mer; + //Mer rcmer; + //mer.from_chars(tseq); + //rcmer.from_chars(rseq); + SBMer mer; + SBMer rcmer; + mer.fromChars(tseq); + rcmer.fromChars(rseq); + int32_t contextLength{exp5.getContextLength()}; for (int32_t fragStart = 0; fragStart < refLen - K; ++fragStart) { @@ -2948,8 +2959,10 @@ int contextSize = outsideContext + insideContext; exp3.evaluateLog(rcmer)); } // shift the context one nucleotide to the right - mer.shift_left(tseq[fragStart + contextLength]); - rcmer.shift_left(rseq[fragStart + contextLength]); + //mer.shift_left(tseq[fragStart + contextLength]); + //rcmer.shift_left(rseq[fragStart + contextLength]); + mer.append(tseq[fragStart + contextLength]); + rcmer.append(rseq[fragStart + contextLength]); } // We need these in 5' -> 3' order, so reverse them seqFactorsRC.reverseInPlace(); From 3b73ca1edd1aae2f02bef0838bde5a24b4ff1343 Mon Sep 17 00:00:00 2001 From: Rob Patro Date: Sat, 2 May 2020 12:47:01 -0400 Subject: [PATCH 11/52] fetch pufferfish back from develop --- scripts/fetchPufferfish.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scripts/fetchPufferfish.sh b/scripts/fetchPufferfish.sh index d49d77190..9f6cc880b 100755 --- a/scripts/fetchPufferfish.sh +++ b/scripts/fetchPufferfish.sh @@ -22,8 +22,8 @@ if [ -d ${INSTALL_DIR}/src/pufferfish ] ; then rm -fr ${INSTALL_DIR}/src/pufferfish fi -SVER=salmon-v1.2.1 -#SVER=develop +#SVER=salmon-v1.2.1 +SVER=develop EXPECTED_SHA256=da51713e54cf426524a2a1da7de2273cea7bf1f4089abbce22fcaa8f59e493cc From 7a69d7129c349ba9cffad41bbcb0793c3db3e0d2 Mon Sep 17 00:00:00 2001 From: Rob Patro Date: Sun, 3 May 2020 00:35:32 -0400 Subject: [PATCH 12/52] Bump oneTBB v. fix bug in #514 This bumps the oneTBB version that is downloaded to 2020_U2, and sets the min required version to 2019 (actually, we need >= 2019_U4, but it's hard to check that). This also fixes the bug mentoined in #514, where using `--posBias` alone (in selective alignment mode) would lead to a segfault. It turns out we have an optimization to not load the ascii representation of transcripts into the transcript objects (just keep it in the index) if the user requests no bias correction or _only_ positional bias. However, the code that performs the bias correction could then attempt to reverse complement the non-existent sequence. When either --seqBias or --gcBias were enabled with --posBias, the sequence was loaded and so this bug wasn't triggered. Now, we added an explicit check so that we don't attempt to touch sequence that we don't need (and that doesn't exist), if we are in using only positional bias. Thanks @jasonvrogers for reporting this one! --- CMakeLists.txt | 22 ++++++++++++---------- include/Transcript.hpp | 2 ++ src/BuildSalmonIndex.cpp | 1 - src/CollapsedEMOptimizer.cpp | 6 ++++-- src/CollapsedGibbsSampler.cpp | 8 ++++++-- src/SalmonAlevin.cpp | 1 - src/SalmonUtils.cpp | 34 ++++++++++++++++++++++------------ 7 files changed, 46 insertions(+), 28 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 477359afd..71b0fb780 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -596,10 +596,15 @@ if (NOT CEREAL_FOUND) endif() ## Try and find TBB first -find_package(TBB 2018.0 COMPONENTS tbb tbbmalloc tbbmalloc_proxy) +find_package(TBB 2019.0 COMPONENTS tbb tbbmalloc tbbmalloc_proxy) +## NOTE: we actually require at least 2019 U4 or greater +## since we are using tbb::global_control. However, they +## seem not to have tagged minor version numbers in their +## source. Check before release if we can bump to the 2020 +## version (requires having tbb 2020 for OSX). if (${TBB_FOUND}) - if (${TBB_VERSION} VERSION_GREATER_EQUAL 2018.0) + if (${TBB_VERSION} VERSION_GREATER_EQUAL 2019.0) message("FOUND SUITABLE TBB VERSION : ${TBB_VERSION}") set(TBB_TARGET_EXISTED TRUE) else() @@ -627,7 +632,7 @@ endif() message("Build system will fetch and build Intel Threading Building Blocks") message("==================================================================") # These are useful for the custom install step we'll do later -set(TBB_SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/external/oneTBB-2019_U8) +set(TBB_SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/external/oneTBB-2020.2) set(TBB_INSTALL_DIR ${CMAKE_CURRENT_SOURCE_DIR}/external/install) if("${TBB_COMPILER}" STREQUAL "gcc") @@ -640,10 +645,10 @@ set(TBB_CXXFLAGS "${TBB_CXXFLAGS} ${CXXSTDFLAG}") externalproject_add(libtbb DOWNLOAD_DIR ${CMAKE_CURRENT_SOURCE_DIR}/external - DOWNLOAD_COMMAND curl -k -L https://github.com/intel/tbb/archive/2019_U8.tar.gz -o tbb-2019_U8.tgz && - ${SHASUM} 6b540118cbc79f9cbc06a35033c18156c21b84ab7b6cf56d773b168ad2b68566 tbb-2019_U8.tgz && - tar -xzvf tbb-2019_U8.tgz - SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/external/oneTBB-2019_U8 + DOWNLOAD_COMMAND curl -k -L https://github.com/oneapi-src/oneTBB/archive/v2020.2.tar.gz -o tbb-2020_U2.tgz && + ${SHASUM} 4804320e1e6cbe3a5421997b52199e3c1a3829b2ecb6489641da4b8e32faf500 tbb-2020_U2.tgz && + tar -xzvf tbb-2020_U2.tgz + SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/external/oneTBB-2020.2 INSTALL_DIR ${CMAKE_CURRENT_SOURCE_DIR}/external/install PATCH_COMMAND "${TBB_PATCH_STEP}" CONFIGURE_COMMAND "" @@ -652,9 +657,6 @@ externalproject_add(libtbb BUILD_IN_SOURCE 1 ) - - - set(RECONFIG_FLAGS ${RECONFIG_FLAGS} -DTBB_WILL_RECONFIGURE=FALSE -DTBB_RECONFIGURE=TRUE) externalproject_add_step(libtbb reconfigure COMMAND ${CMAKE_COMMAND} ${CMAKE_CURRENT_SOURCE_DIR} ${RECONFIG_FLAGS} diff --git a/include/Transcript.hpp b/include/Transcript.hpp index 4edf6a79d..5f240a90e 100644 --- a/include/Transcript.hpp +++ b/include/Transcript.hpp @@ -480,6 +480,8 @@ class Transcript { } } + bool have_sequence() const { if (Sequence_) { return true; } else { return false; } } + const char* Sequence() const { return Sequence_.get(); } uint8_t* SAMSequence() const { return const_cast(SAMSequence_.data()); } diff --git a/src/BuildSalmonIndex.cpp b/src/BuildSalmonIndex.cpp index 8e97e69ad..d3e525adc 100644 --- a/src/BuildSalmonIndex.cpp +++ b/src/BuildSalmonIndex.cpp @@ -29,7 +29,6 @@ #include "tbb/parallel_for.h" #include "tbb/parallel_for_each.h" #include "tbb/parallel_sort.h" -#include "tbb/task_scheduler_init.h" #include "GenomicFeature.hpp" #include "SalmonIndex.hpp" diff --git a/src/CollapsedEMOptimizer.cpp b/src/CollapsedEMOptimizer.cpp index f3a142c73..cbc28e870 100644 --- a/src/CollapsedEMOptimizer.cpp +++ b/src/CollapsedEMOptimizer.cpp @@ -8,7 +8,8 @@ #include "tbb/parallel_for_each.h" #include "tbb/parallel_reduce.h" #include "tbb/partitioner.h" -#include "tbb/task_scheduler_init.h" +// <-- deprecated in TBB --> #include "tbb/task_scheduler_init.h" +#include "tbb/global_control.h" //#include "fastapprox.h" #include @@ -721,7 +722,8 @@ template bool CollapsedEMOptimizer::optimize(ExpT& readExp, SalmonOpts& sopt, double relDiffTolerance, uint32_t maxIter) { - tbb::task_scheduler_init tbbScheduler(sopt.numThreads); + // <-- deprecated in TBB --> tbb::task_scheduler_init tbbScheduler(sopt.numThreads); + tbb::global_control c(tbb::global_control::max_allowed_parallelism, sopt.numThreads); std::vector& transcripts = readExp.transcripts(); std::vector available(transcripts.size(), false); diff --git a/src/CollapsedGibbsSampler.cpp b/src/CollapsedGibbsSampler.cpp index 60f30f6ad..8ca0b66d0 100644 --- a/src/CollapsedGibbsSampler.cpp +++ b/src/CollapsedGibbsSampler.cpp @@ -12,7 +12,8 @@ #include "tbb/parallel_for_each.h" #include "tbb/parallel_reduce.h" #include "tbb/partitioner.h" -#include "tbb/task_scheduler_init.h" +// <-- deprecated in TBB --> #include "tbb/task_scheduler_init.h" +#include "tbb/global_control.h" //#include "fastapprox.h" #include @@ -308,7 +309,10 @@ bool CollapsedGibbsSampler::sample( namespace bfs = boost::filesystem; auto& jointLog = sopt.jointLog; - tbb::task_scheduler_init tbbScheduler(sopt.numThreads); + + // <-- deprecated in TBB --> tbb::task_scheduler_init tbbScheduler(sopt.numThreads); + tbb::global_control c(tbb::global_control::max_allowed_parallelism, sopt.numThreads); + std::vector& transcripts = readExp.transcripts(); // Fill in the effective length vector diff --git a/src/SalmonAlevin.cpp b/src/SalmonAlevin.cpp index f7cb1de83..7d5d6f5de 100644 --- a/src/SalmonAlevin.cpp +++ b/src/SalmonAlevin.cpp @@ -70,7 +70,6 @@ #include "tbb/parallel_for_each.h" #include "tbb/parallel_reduce.h" #include "tbb/partitioner.h" -#include "tbb/task_scheduler_init.h" // logger includes #include "spdlog/spdlog.h" diff --git a/src/SalmonUtils.cpp b/src/SalmonUtils.cpp index c861c346e..cd54d0f66 100644 --- a/src/SalmonUtils.cpp +++ b/src/SalmonUtils.cpp @@ -2652,18 +2652,21 @@ int contextSize = outsideContext + insideContext; windowLensTP.setZero(); // This transcript's sequence - const char* tseq = txp.Sequence(); - revComplement(tseq, refLen, rcSeq); - const char* rseq = rcSeq.c_str(); - - // Mer fwmer; - // Mer rcmer; - // fwmer.from_chars(tseq); - //rcmer.from_chars(rseq); + bool have_seq = txp.have_sequence(); SBMer fwmer; SBMer rcmer; - fwmer.fromChars(tseq); - rcmer.fromChars(rseq); + + const char* tseq = have_seq ? txp.Sequence() : nullptr; + // only do this if we have the sequence, we may not if just pos. bias. + if (have_seq) { + revComplement(tseq, refLen, rcSeq); + fwmer.fromChars(tseq); + rcmer.fromChars(rcSeq); + } + // may be empty, but we shouldn't actually + // use it unless it is meaningful. + const char* rseq = rcSeq.c_str(); + int32_t contextLength{expectSeqFW.getContextLength()}; if (gcBiasCorrect and seqBiasCorrect) { @@ -2893,8 +2896,15 @@ int contextSize = outsideContext + insideContext; std::vector posFactorsRC(refLen, 1.0); // This transcript's sequence - const char* tseq = txp.Sequence(); - revComplement(tseq, refLen, rcSeq); + bool have_seq = txp.have_sequence(); + const char* tseq = have_seq ? txp.Sequence() : nullptr; + // only do this if we have the sequence, we may not if just pos. + // bias. + if (have_seq) { + revComplement(tseq, refLen, rcSeq); + } + // may be empty, but we shouldn't actually + // use it unless it is meaningful. const char* rseq = rcSeq.c_str(); int32_t fl = locFLDLow; From b327df273fbc7e41f78cae3dbc0a6e0698472052 Mon Sep 17 00:00:00 2001 From: Rob Patro Date: Tue, 5 May 2020 00:46:26 -0400 Subject: [PATCH 13/52] replace deprecated tbb functionality. This commit removes references to the tbb/atomic header and tbb/task_scheduler_init header. The replacement for tbb::atomic is std::atomic, which is trivial for built-in types, but required some work for doubles. @k3yavi & @hiraksarkar : please look over this commit specifically for replacements of tbb::atomic with std::atomic as it relates to alevin (SalmonAlevin, CollapsedCellOptimizer, etc.), but also please look over in general. The biggest change / pita is how std::vector> requires nothing that uses the copy constructor. --- include/AtomicMatrix.hpp | 95 ++++++++----------- include/BAMQueue.hpp | 1 - include/CollapsedCellOptimizer.hpp | 9 +- include/CollapsedEMOptimizer.hpp | 6 +- include/CollapsedGibbsSampler.hpp | 4 - include/EMUtils.hpp | 2 +- include/FragmentLengthDistribution.hpp | 7 +- include/FragmentStartPositionDistribution.hpp | 8 +- include/SalmonUtils.hpp | 18 ++-- include/Sampler.hpp | 1 - include/Transcript.hpp | 9 +- src/AlignmentModel.cpp | 15 +-- src/CollapsedCellOptimizer.cpp | 8 +- src/CollapsedEMOptimizer.cpp | 12 +-- src/EMUtils.cpp | 8 +- src/FragmentLengthDistribution.cpp | 48 ++++------ src/FragmentStartPositionDistribution.cpp | 13 +-- src/SalmonQuantify.cpp | 1 - src/SalmonQuantifyAlignments.cpp | 1 - src/SalmonUtils.cpp | 56 ++--------- 20 files changed, 119 insertions(+), 203 deletions(-) diff --git a/include/AtomicMatrix.hpp b/include/AtomicMatrix.hpp index ed998ab00..0212da8de 100644 --- a/include/AtomicMatrix.hpp +++ b/include/AtomicMatrix.hpp @@ -1,42 +1,53 @@ #ifndef ATOMIC_MATRIX #define ATOMIC_MATRIX -#include "tbb/atomic.h" #include "tbb/concurrent_vector.h" #include "SalmonMath.hpp" +#include "SalmonUtils.hpp" #include template class AtomicMatrix { public: + AtomicMatrix() { + nRow_ = 0; + nCol_ = 0; + alpha_ = salmon::math::LOG_0; + logSpace_ = true; + } + AtomicMatrix(size_t nRow, size_t nCol, T alpha, bool logSpace = true) - : storage_(nRow * nCol, logSpace ? std::log(alpha) : alpha), - rowsums_(nRow, logSpace ? std::log(nCol * alpha) : nCol * alpha), - nRow_(nRow), nCol_(nCol), alpha_(alpha), logSpace_(logSpace) {} + : nRow_(nRow), nCol_(nCol), alpha_(alpha), logSpace_(logSpace) { + + decltype(storage_) storage_tmp(nRow * nCol); + std::swap(storage_, storage_tmp); + T e = logSpace ? std::log(alpha) : alpha; + std::fill(storage_.begin(), storage_.end(), e); + + decltype(rowsums_) rowsums_tmp(nRow); + std::swap(rowsums_, rowsums_tmp); + T ers = logSpace ? std::log(nCol * alpha) : nCol * alpha; + std::fill(rowsums_.begin(), rowsums_.end(), ers); + } + + AtomicMatrix& operator=(AtomicMatrix&& o) { + std::swap(storage_, o.storage_); + std::swap(rowsums_, o.rowsums_); + nRow_ = o.nRow_; + nCol_ = o.nCol_; + alpha_ = o.alpha_; + logSpace_ = o.logSpace_; + return *this; + } void incrementUnnormalized(size_t rowInd, size_t colInd, T amt) { using salmon::math::logAdd; size_t k = rowInd * nCol_ + colInd; if (logSpace_) { - T oldVal = storage_[k]; - T retVal = oldVal; - T newVal = logAdd(oldVal, amt); - do { - oldVal = retVal; - newVal = logAdd(oldVal, amt); - retVal = storage_[k].compare_and_swap(newVal, oldVal); - } while (retVal != oldVal); - + salmon::utils::incLoopLog(storage_[k], amt); } else { - T oldVal = storage_[k]; - T retVal = oldVal; - T newVal = oldVal + amt; - do { - oldVal = retVal; - newVal = oldVal + amt; - retVal = storage_[k].compare_and_swap(newVal, oldVal); - } while (retVal != oldVal); + salmon::utils::incLoop(storage_[k], amt); } } @@ -55,41 +66,11 @@ template class AtomicMatrix { using salmon::math::logAdd; size_t k = rowInd * nCol_ + colInd; if (logSpace_) { - T oldVal = storage_[k]; - T retVal = oldVal; - T newVal = logAdd(oldVal, amt); - do { - oldVal = retVal; - newVal = logAdd(oldVal, amt); - retVal = storage_[k].compare_and_swap(newVal, oldVal); - } while (retVal != oldVal); - - oldVal = rowsums_[rowInd]; - retVal = oldVal; - newVal = logAdd(oldVal, amt); - do { - oldVal = retVal; - newVal = logAdd(oldVal, amt); - retVal = rowsums_[rowInd].compare_and_swap(newVal, oldVal); - } while (retVal != oldVal); + salmon::utils::incLoopLog(storage_[k], amt); + salmon::utils::incLoopLog(rowsums_[rowInd], amt); } else { - T oldVal = storage_[k]; - T retVal = oldVal; - T newVal = oldVal + amt; - do { - oldVal = retVal; - newVal = oldVal + amt; - retVal = storage_[k].compare_and_swap(newVal, oldVal); - } while (retVal != oldVal); - - oldVal = rowsums_[rowInd]; - retVal = oldVal; - newVal = oldVal + amt; - do { - oldVal = retVal; - newVal = oldVal + amt; - retVal = rowsums_[rowInd].compare_and_swap(newVal, oldVal); - } while (retVal != oldVal); + salmon::utils::incLoop(storage_[k], amt); + salmon::utils::incLoop(rowsums_[rowInd], amt); } } @@ -106,8 +87,8 @@ template class AtomicMatrix { size_t nCol() const { return nCol_; } private: - std::vector> storage_; - std::vector> rowsums_; + std::vector> storage_; + std::vector> rowsums_; size_t nRow_, nCol_; T alpha_; bool logSpace_; diff --git a/include/BAMQueue.hpp b/include/BAMQueue.hpp index 9576b234e..d5df05c7d 100644 --- a/include/BAMQueue.hpp +++ b/include/BAMQueue.hpp @@ -10,7 +10,6 @@ #include #include #include -#include #include #include diff --git a/include/CollapsedCellOptimizer.hpp b/include/CollapsedCellOptimizer.hpp index 6ca46bb24..2aeaefd46 100644 --- a/include/CollapsedCellOptimizer.hpp +++ b/include/CollapsedCellOptimizer.hpp @@ -6,9 +6,6 @@ #include #include -#include "tbb/atomic.h" -#include "tbb/task_scheduler_init.h" - #include #include "ReadExperiment.hpp" @@ -45,7 +42,7 @@ struct CellState { class CollapsedCellOptimizer { public: - using VecType = std::vector>; + using VecType = std::vector>; using SerialVecType = std::vector; CollapsedCellOptimizer(); @@ -77,8 +74,8 @@ void optimizeCell(std::vector& trueBarcodes, std::vector& umiCount, std::vector& skippedCB, bool verbose, GZipWriter& gzw, bool noEM, bool useVBEM, - bool quiet, tbb::atomic& totalDedupCounts, - tbb::atomic& totalExpGeneCounts, double priorWeight, + bool quiet, std::atomic& totalDedupCounts, + std::atomic& totalExpGeneCounts, double priorWeight, spp::sparse_hash_map& txpToGeneMap, uint32_t numGenes, uint32_t umiLength, uint32_t numBootstraps, uint32_t numGibbsSamples, diff --git a/include/CollapsedEMOptimizer.hpp b/include/CollapsedEMOptimizer.hpp index f2b406e8c..397ae4bc4 100644 --- a/include/CollapsedEMOptimizer.hpp +++ b/include/CollapsedEMOptimizer.hpp @@ -1,12 +1,10 @@ #ifndef COLLAPSED_EM_OPTIMIZER_HPP #define COLLAPSED_EM_OPTIMIZER_HPP +#include #include #include -#include "tbb/atomic.h" -#include "tbb/task_scheduler_init.h" - #include "ReadExperiment.hpp" #include "SalmonOpts.hpp" @@ -17,7 +15,7 @@ class BootstrapWriter; class CollapsedEMOptimizer { public: - using VecType = std::vector>; + using VecType = std::vector>; using SerialVecType = std::vector; CollapsedEMOptimizer(); diff --git a/include/CollapsedGibbsSampler.hpp b/include/CollapsedGibbsSampler.hpp index db8e8cba5..c5196dfde 100644 --- a/include/CollapsedGibbsSampler.hpp +++ b/include/CollapsedGibbsSampler.hpp @@ -3,10 +3,6 @@ #include #include - -#include "tbb/atomic.h" -#include "tbb/task_scheduler_init.h" - #include "SalmonOpts.hpp" #include "Eigen/Dense" diff --git a/include/EMUtils.hpp b/include/EMUtils.hpp index 3d2d543f4..0a4eada73 100644 --- a/include/EMUtils.hpp +++ b/include/EMUtils.hpp @@ -2,7 +2,7 @@ #define EM_UTILS_HPP #include -#include "tbb/atomic.h" +#include #include "Transcript.hpp" template diff --git a/include/FragmentLengthDistribution.hpp b/include/FragmentLengthDistribution.hpp index e9ec6ff53..84fbbd59e 100644 --- a/include/FragmentLengthDistribution.hpp +++ b/include/FragmentLengthDistribution.hpp @@ -10,7 +10,6 @@ #ifndef FRAGMENT_LENGTH_DISTRIBUTION #define FRAGMENT_LENGTH_DISTRIBUTION -#include "tbb/atomic.h" #include #include #include @@ -33,7 +32,7 @@ class FragmentLengthDistribution { /** * A private vector that stores the observed (logged) mass for each length. */ - std::vector> hist_; + std::vector> hist_; /** * A private vector that stores the observed (logged) mass for each length. @@ -47,12 +46,12 @@ class FragmentLengthDistribution { /** * A private double that stores the total observed (logged) mass. */ - tbb::atomic totMass_; + std::atomic totMass_; /** * A private double that stores the (logged) sum of the product of observed * lengths and masses for quick mean calculations. */ - tbb::atomic sum_; + std::atomic sum_; /** * A private int that stores the minimum observed length. */ diff --git a/include/FragmentStartPositionDistribution.hpp b/include/FragmentStartPositionDistribution.hpp index c400096f8..4d9dd6f1f 100644 --- a/include/FragmentStartPositionDistribution.hpp +++ b/include/FragmentStartPositionDistribution.hpp @@ -9,7 +9,7 @@ #ifndef FRAGMENT_START_POSITION_DISTRIBUTION #define FRAGMENT_START_POSITION_DISTRIBUTION -#include "tbb/atomic.h" +// #include "tbb/atomic.h" #include #include #include @@ -26,12 +26,12 @@ class FragmentStartPositionDistribution { /** * A private vector that stores the observed (logged) mass for each length. */ - std::vector> pmf_; - std::vector> cmf_; + std::vector> pmf_; + std::vector> cmf_; /** * A private double that stores the total observed (logged) mass. */ - tbb::atomic totMass_; + std::atomic totMass_; /** * A private double that stores the (logged) sum of the product of observed * lengths and masses for quick mean calculations. diff --git a/include/SalmonUtils.hpp b/include/SalmonUtils.hpp index 6ed0c54cd..3a1cd8282 100644 --- a/include/SalmonUtils.hpp +++ b/include/SalmonUtils.hpp @@ -159,15 +159,12 @@ updateEffectiveLengths(SalmonOpts& sopt, ReadExpT& readExp, * val + inc (*in log-space*). Update occurs in a loop in case other * threads update in the meantime. */ -inline void incLoopLog(tbb::atomic& val, double inc) { +inline void incLoopLog(std::atomic& val, double inc) { double oldMass = val.load(); - double returnedMass = oldMass; - double newMass{salmon::math::LOG_0}; + double newMass; do { - oldMass = returnedMass; newMass = salmon::math::logAdd(oldMass, inc); - returnedMass = val.compare_and_swap(newMass, oldMass); - } while (returnedMass != oldMass); + } while (! val.compare_exchange_strong(oldMass, newMass)); } /* @@ -180,6 +177,7 @@ inline void incLoop(double& val, double inc) { val += inc; } * val + inc. Update occurs in a loop in case other * threads update in the meantime. */ +/* inline void incLoop(tbb::atomic& val, double inc) { double oldMass = val.load(); double returnedMass = oldMass; @@ -189,6 +187,14 @@ inline void incLoop(tbb::atomic& val, double inc) { newMass = oldMass + inc; returnedMass = val.compare_and_swap(newMass, oldMass); } while (returnedMass != oldMass); +}*/ + +inline void incLoop(std::atomic& val, double inc) { + double oldMass = val.load(); + double newMass; + do { + newMass = oldMass + inc; + } while (!val.compare_exchange_strong(oldMass, newMass)); } std::string getCurrentTimeAsString(); diff --git a/include/Sampler.hpp b/include/Sampler.hpp index 1bbf88f53..2518d7bc2 100644 --- a/include/Sampler.hpp +++ b/include/Sampler.hpp @@ -20,7 +20,6 @@ extern "C" { #include #include #include -#include #include #include #include diff --git a/include/Transcript.hpp b/include/Transcript.hpp index 5f240a90e..ac02aa5b9 100644 --- a/include/Transcript.hpp +++ b/include/Transcript.hpp @@ -7,7 +7,6 @@ #include "SalmonStringUtils.hpp" #include "SalmonUtils.hpp" #include "SequenceBiasModel.hpp" -#include "tbb/atomic.h" #include "stx/string_view.hpp" #include "IOUtils.hpp" #include @@ -677,10 +676,10 @@ class Transcript { std::atomic uniqueCount_; std::atomic totalCount_; double priorMass_; - tbb::atomic mass_; - tbb::atomic sharedCount_; - tbb::atomic cachedEffectiveLength_; - tbb::atomic avgMassBias_; + std::atomic mass_; + std::atomic sharedCount_; + std::atomic cachedEffectiveLength_; + std::atomic avgMassBias_; uint32_t lengthClassIndex_; double logPerBasePrior_; // In a paired-end protocol, a transcript has diff --git a/src/AlignmentModel.cpp b/src/AlignmentModel.cpp index 23dab30e3..5377e423f 100644 --- a/src/AlignmentModel.cpp +++ b/src/AlignmentModel.cpp @@ -13,14 +13,15 @@ #include "UnpairedRead.hpp" AlignmentModel::AlignmentModel(double alpha, uint32_t readBins) - : transitionProbsLeft_(readBins, - AtomicMatrix(numAlignmentStates(), - numAlignmentStates(), alpha)), - transitionProbsRight_(readBins, - AtomicMatrix(numAlignmentStates(), - numAlignmentStates(), alpha)), + : transitionProbsLeft_(readBins), transitionProbsRight_(readBins), isEnabled_(true), readBins_(readBins), burnedIn_(false) { - + + for (size_t i = 0; i < readBins; ++i) { + transitionProbsLeft_[i] = std::move(AtomicMatrix( + numAlignmentStates(), numAlignmentStates(), alpha)); + transitionProbsRight_[i] = std::move(AtomicMatrix( + numAlignmentStates(), numAlignmentStates(), alpha)); + } } bool AlignmentModel::burnedIn() { return burnedIn_; } diff --git a/src/CollapsedCellOptimizer.cpp b/src/CollapsedCellOptimizer.cpp index 3c8c2e2ef..4ded743e3 100644 --- a/src/CollapsedCellOptimizer.cpp +++ b/src/CollapsedCellOptimizer.cpp @@ -556,8 +556,8 @@ void optimizeCell(std::vector& trueBarcodes, std::vector& umiCount, std::vector& skippedCB, bool verbose, GZipWriter& gzw, bool noEM, bool useVBEM, - bool quiet, tbb::atomic& totalDedupCounts, - tbb::atomic& totalExpGeneCounts, double priorWeight, + bool quiet, std::atomic& totalDedupCounts, + std::atomic& totalExpGeneCounts, double priorWeight, spp::sparse_hash_map& txpToGeneMap, uint32_t numGenes, uint32_t umiLength, uint32_t numBootstraps, uint32_t numGibbsSamples, @@ -1119,8 +1119,8 @@ bool CollapsedCellOptimizer::optimize(EqMapT& fullEqMap, double priorWeight {1.0}; std::atomic bcount{0}; - tbb::atomic totalDedupCounts{0.0}; - tbb::atomic totalExpGeneCounts{0}; + std::atomic totalDedupCounts{0.0}; + std::atomic totalExpGeneCounts{0}; std::atomic totalBiEdgesCounts{0}; std::atomic totalUniEdgesCounts{0}; diff --git a/src/CollapsedEMOptimizer.cpp b/src/CollapsedEMOptimizer.cpp index cbc28e870..46fc121f1 100644 --- a/src/CollapsedEMOptimizer.cpp +++ b/src/CollapsedEMOptimizer.cpp @@ -42,7 +42,7 @@ constexpr double minWeight = std::numeric_limits::min(); // A bit more conservative of a minimum as an argument to the digamma function. constexpr double digammaMin = 1e-10; -double normalize(std::vector>& vec) { +double normalize(std::vector>& vec) { double sum{0.0}; for (auto& v : vec) { sum += v; @@ -738,8 +738,8 @@ bool CollapsedEMOptimizer::optimize(ExpT& readExp, SalmonOpts& sopt, using VecT = CollapsedEMOptimizer::VecType; // With atomics - VecType alphas(transcripts.size(), 0.0); - VecType alphasPrime(transcripts.size(), 0.0); + VecType alphas(transcripts.size()); + VecType alphasPrime(transcripts.size()); VecType expTheta(transcripts.size()); Eigen::VectorXd effLens(transcripts.size()); @@ -800,7 +800,7 @@ bool CollapsedEMOptimizer::optimize(ExpT& readExp, SalmonOpts& sopt, // over to the alphas if (sopt.initUniform) { for (size_t i = 0; i < alphas.size(); ++i) { - alphas[i] = alphasPrime[i]; + alphas[i].store(alphasPrime[i].load()); alphasPrime[i] = 1.0; } } else { // otherwise, initalize with a linear combination of the true and @@ -942,8 +942,8 @@ bool CollapsedEMOptimizer::optimize(ExpT& readExp, SalmonOpts& sopt, converged = false; } } - alphas[i] = alphasPrime[i]; - alphasPrime[i] = 0.0; + alphas[i].store(alphasPrime[i].load()); + alphasPrime[i].store(0.0); } /* -- v0.8.x diff --git a/src/EMUtils.cpp b/src/EMUtils.cpp index 09e8c608a..30b42e824 100644 --- a/src/EMUtils.cpp +++ b/src/EMUtils.cpp @@ -74,14 +74,14 @@ void EMUpdate_>(std::vector>& txpGroup std::vector& alphaOut); template -void EMUpdate_>>(std::vector>& txpGroupLabels, +void EMUpdate_>>(std::vector>& txpGroupLabels, std::vector>& txpGroupCombinedWeights, const std::vector& txpGroupCounts, - const std::vector>& alphaIn, - std::vector>& alphaOut); + const std::vector>& alphaIn, + std::vector>& alphaOut); template double truncateCountVector>(std::vector& alphas, double cutoff); template -double truncateCountVector>>(std::vector>& alphas, double cutoff); +double truncateCountVector>>(std::vector>& alphas, double cutoff); diff --git a/src/FragmentLengthDistribution.cpp b/src/FragmentLengthDistribution.cpp index 216e442a3..ff4c3dd09 100644 --- a/src/FragmentLengthDistribution.cpp +++ b/src/FragmentLengthDistribution.cpp @@ -9,6 +9,7 @@ #include "FragmentLengthDistribution.hpp" #include "SalmonMath.hpp" +#include "SalmonUtils.hpp" #include #include #include @@ -22,7 +23,7 @@ using namespace std; FragmentLengthDistribution::FragmentLengthDistribution( double alpha, size_t max_val, double prior_mu, double prior_sigma, size_t kernel_n, double kernel_p, size_t bin_size) - : hist_(max_val / bin_size + 1), cachedCMF_(hist_.size()), + : /*hist_(max_val / bin_size + 1),*/ cachedCMF_(hist_.size()), haveCachedCMF_(false), totMass_(salmon::math::LOG_0), sum_(salmon::math::LOG_0), min_(max_val / bin_size), binSize_(bin_size) { @@ -38,6 +39,9 @@ FragmentLengthDistribution::FragmentLengthDistribution( boost::math::normal norm(prior_mu / bin_size, prior_sigma / (bin_size * bin_size)); + std::vector> hist_tmp(max_val / bin_size + 1); + std::swap(hist_, hist_tmp); + for (size_t i = 0; i <= max_val; ++i) { double norm_mass = boost::math::cdf(norm, i + 0.5) - boost::math::cdf(norm, i - 0.5); @@ -45,16 +49,17 @@ FragmentLengthDistribution::FragmentLengthDistribution( if (norm_mass != 0) { mass = tot + log(norm_mass); } - hist_[i].compare_and_swap(mass, hist_[i]); - sum_.compare_and_swap(logAdd(sum_, log((double)i) + mass), sum_); - totMass_.compare_and_swap(logAdd(totMass_, mass), totMass_); + hist_[i].store(mass); + sum_.store(logAdd(sum_, log((double)i) + mass)); + totMass_.store(logAdd(totMass_, mass)); } } else { - hist_ = - vector>(max_val + 1, tot - log((double)max_val)); - hist_[0].compare_and_swap(salmon::math::LOG_0, hist_[0]); - sum_.compare_and_swap( - hist_[1] + log((double)(max_val * (max_val + 1))) - log(2.), sum_); + std::vector> hist_tmp(max_val + 1); + std::swap(hist_, hist_tmp); + std::fill(hist_.begin(), hist_.end(), tot - log((double)max_val)); + + hist_[0].store(salmon::math::LOG_0); + sum_.store(hist_[1] + log((double)(max_val * (max_val + 1))) - log(2.)); totMass_ = tot; } @@ -96,28 +101,9 @@ void FragmentLengthDistribution::addVal(size_t len, double mass) { for (size_t i = 0; i < kernel_.size(); i++) { if (offset > 0 && offset < hist_.size()) { double kMass = mass + kernel_[i]; - double oldVal = hist_[offset]; - double retVal = oldVal; - double newVal = 0.0; - do { - oldVal = retVal; - newVal = logAdd(oldVal, kMass); - retVal = hist_[offset].compare_and_swap(newVal, oldVal); - } while (retVal != oldVal); - - retVal = sum_; - do { - oldVal = retVal; - newVal = logAdd(oldVal, log(static_cast(offset)) + kMass); - retVal = sum_.compare_and_swap(newVal, oldVal); - } while (retVal != oldVal); - - retVal = totMass_; - do { - oldVal = retVal; - newVal = logAdd(oldVal, kMass); - retVal = totMass_.compare_and_swap(newVal, oldVal); - } while (retVal != oldVal); + salmon::utils::incLoopLog(hist_[offset], kMass); + salmon::utils::incLoopLog(sum_, std::log(static_cast(offset)) + kMass); + salmon::utils::incLoopLog(totMass_, kMass); } offset++; } diff --git a/src/FragmentStartPositionDistribution.cpp b/src/FragmentStartPositionDistribution.cpp index 8467c659b..f89e667e7 100644 --- a/src/FragmentStartPositionDistribution.cpp +++ b/src/FragmentStartPositionDistribution.cpp @@ -34,15 +34,12 @@ FragmentStartPositionDistribution::FragmentStartPositionDistribution( } } -inline void logAddMass(tbb::atomic& bin, double newMass) { - double oldVal = bin; - double retVal = oldVal; - double newVal = 0.0; - do { - oldVal = retVal; +inline void logAddMass(std::atomic& bin, double newMass) { + double oldVal = bin.load(); + double newVal; + do{ newVal = salmon::math::logAdd(oldVal, newMass); - retVal = bin.compare_and_swap(newVal, oldVal); - } while (retVal != oldVal); + } while (!bin.compare_exchange_strong(oldVal, newVal)); } void FragmentStartPositionDistribution::addVal(int32_t hitPos, uint32_t txpLen, diff --git a/src/SalmonQuantify.cpp b/src/SalmonQuantify.cpp index 19eb2b4bb..3d77d37c7 100644 --- a/src/SalmonQuantify.cpp +++ b/src/SalmonQuantify.cpp @@ -69,7 +69,6 @@ #include "tbb/parallel_for_each.h" #include "tbb/parallel_reduce.h" #include "tbb/partitioner.h" -#include "tbb/task_scheduler_init.h" // logger includes #include "spdlog/spdlog.h" diff --git a/src/SalmonQuantifyAlignments.cpp b/src/SalmonQuantifyAlignments.cpp index e01bf7f11..e6d4e5c02 100644 --- a/src/SalmonQuantifyAlignments.cpp +++ b/src/SalmonQuantifyAlignments.cpp @@ -16,7 +16,6 @@ extern "C" { #include #include #include -#include #include #include #include diff --git a/src/SalmonUtils.cpp b/src/SalmonUtils.cpp index cd54d0f66..291dbb8a8 100644 --- a/src/SalmonUtils.cpp +++ b/src/SalmonUtils.cpp @@ -3332,59 +3332,19 @@ template void salmon::utils::normalizeAlphas>( template void salmon::utils::normalizeAlphas>( const SalmonOpts& sopt, BulkAlignLibT& alnLib); -// explicit instantiations for effective length updates --- -/* -template Eigen::VectorXd -salmon::utils::updateEffectiveLengths>, - ReadExperiment>( - SalmonOpts& sopt, ReadExperiment& readExp, Eigen::VectorXd& effLensIn, - std::vector>& alphas, bool finalRound); - -template Eigen::VectorXd -salmon::utils::updateEffectiveLengths, ReadExperiment>( - SalmonOpts& sopt, ReadExperiment& readExp, Eigen::VectorXd& effLensIn, - std::vector& alphas, bool finalRound); - -template Eigen::VectorXd -salmon::utils::updateEffectiveLengths>, - AlignmentLibrary>( - SalmonOpts& sopt, AlignmentLibrary& readExp, - Eigen::VectorXd& effLensIn, std::vector>& alphas, - bool finalRound); - -template Eigen::VectorXd -salmon::utils::updateEffectiveLengths, - AlignmentLibrary>( - SalmonOpts& sopt, AlignmentLibrary& readExp, - Eigen::VectorXd& effLensIn, std::vector& alphas, bool finalRound); - -template Eigen::VectorXd -salmon::utils::updateEffectiveLengths>, - AlignmentLibrary>( - SalmonOpts& sopt, AlignmentLibrary& readExp, - Eigen::VectorXd& effLensIn, std::vector>& alphas, - bool finalRound); - -template Eigen::VectorXd -salmon::utils::updateEffectiveLengths, - AlignmentLibrary>( - SalmonOpts& sopt, AlignmentLibrary& readExp, - Eigen::VectorXd& effLensIn, std::vector& alphas, bool finalRound); -*/ - // explicit instantiations for effective length updates --- template Eigen::VectorXd -salmon::utils::updateEffectiveLengths>, +salmon::utils::updateEffectiveLengths>, BulkExpT>( SalmonOpts& sopt, BulkExpT& readExp, Eigen::VectorXd& effLensIn, - std::vector>& alphas, std::vector& available, + std::vector>& alphas, std::vector& available, bool finalRound); template Eigen::VectorXd -salmon::utils::updateEffectiveLengths>, +salmon::utils::updateEffectiveLengths>, SCExpT>( SalmonOpts& sopt, SCExpT& readExp, Eigen::VectorXd& effLensIn, - std::vector>& alphas, std::vector& available, + std::vector>& alphas, std::vector& available, bool finalRound); template Eigen::VectorXd @@ -3398,10 +3358,10 @@ salmon::utils::updateEffectiveLengths, SCExpT>( std::vector& alphas, std::vector& available, bool finalRound); template Eigen::VectorXd -salmon::utils::updateEffectiveLengths>, +salmon::utils::updateEffectiveLengths>, BulkAlignLibT>( SalmonOpts& sopt, BulkAlignLibT& readExp, - Eigen::VectorXd& effLensIn, std::vector>& alphas, + Eigen::VectorXd& effLensIn, std::vector>& alphas, std::vector& available, bool finalRound); template Eigen::VectorXd @@ -3412,10 +3372,10 @@ salmon::utils::updateEffectiveLengths, std::vector& available, bool finalRound); template Eigen::VectorXd -salmon::utils::updateEffectiveLengths>, +salmon::utils::updateEffectiveLengths>, BulkAlignLibT>( SalmonOpts& sopt, BulkAlignLibT& readExp, - Eigen::VectorXd& effLensIn, std::vector>& alphas, + Eigen::VectorXd& effLensIn, std::vector>& alphas, std::vector& available, bool finalRound); template Eigen::VectorXd From d97efe26adf7d95a3bd53311a2a3f378e7063b2e Mon Sep 17 00:00:00 2001 From: Rob Patro Date: Fri, 8 May 2020 11:24:56 -0400 Subject: [PATCH 14/52] write fld mean and std to aux_info.json --- include/DistributionUtils.hpp | 12 +++++++++++- include/SalmonUtils.hpp | 11 ----------- src/DistributionUtils.cpp | 14 +++++++++++--- src/GZipWriter.cpp | 10 +++++++--- 4 files changed, 29 insertions(+), 18 deletions(-) diff --git a/include/DistributionUtils.hpp b/include/DistributionUtils.hpp index 99d266d3e..99d684941 100644 --- a/include/DistributionUtils.hpp +++ b/include/DistributionUtils.hpp @@ -12,13 +12,23 @@ class Transcript; namespace distribution_utils { enum class DistributionSpace : uint8_t { LOG = 0, LINEAR = 1 }; +class DistSummary { +public: +DistSummary(double mean_in, double sd_in, uint32_t support) : + mean(mean_in), sd(sd_in), samples(support,0) {} + +double mean; +double sd; +std::vector samples; +}; + /** * Draw samples from the provided fragment length distribution. * \param fld A pointer to the FragmentLengthDistribution from which * samples will be drawn. * \param numSamples The number of samples to draw. */ -std::vector samplesFromLogPMF(FragmentLengthDistribution* fld, +DistSummary samplesFromLogPMF(FragmentLengthDistribution* fld, int32_t numSamples); /** diff --git a/include/SalmonUtils.hpp b/include/SalmonUtils.hpp index 3a1cd8282..04a452532 100644 --- a/include/SalmonUtils.hpp +++ b/include/SalmonUtils.hpp @@ -177,17 +177,6 @@ inline void incLoop(double& val, double inc) { val += inc; } * val + inc. Update occurs in a loop in case other * threads update in the meantime. */ -/* -inline void incLoop(tbb::atomic& val, double inc) { - double oldMass = val.load(); - double returnedMass = oldMass; - double newMass{oldMass + inc}; - do { - oldMass = returnedMass; - newMass = oldMass + inc; - returnedMass = val.compare_and_swap(newMass, oldMass); - } while (returnedMass != oldMass); -}*/ inline void incLoop(std::atomic& val, double inc) { double oldMass = val.load(); diff --git a/src/DistributionUtils.cpp b/src/DistributionUtils.cpp index 538c4a0ff..77457b580 100644 --- a/src/DistributionUtils.cpp +++ b/src/DistributionUtils.cpp @@ -54,7 +54,7 @@ void computeSmoothedEffectiveLengths(size_t maxLength, } } -std::vector samplesFromLogPMF(FragmentLengthDistribution* fld, +DistSummary samplesFromLogPMF(FragmentLengthDistribution* fld, int32_t numSamples) { std::vector logPMF; size_t minVal; @@ -70,21 +70,29 @@ std::vector samplesFromLogPMF(FragmentLengthDistribution* fld, } // Create the non-logged pmf + double mean = salmon::math::LOG_0; + double var = 0.0; std::vector pmf(maxVal + 1, 0.0); for (size_t i = minVal; i < maxVal; ++i) { + mean = (i > 0) ? (salmon::math::logAdd(mean, std::log(i)+logPMF[i-minVal])) : mean; pmf[i] = std::exp(logPMF[i - minVal]); + var += pmf[i] * (i*i); } + mean = std::exp(mean); + var -= mean*mean; + double sd = std::sqrt(var); // generate samples std::random_device rd; std::mt19937 gen(rd()); std::discrete_distribution dist(pmf.begin(), pmf.end()); + DistSummary ds(mean, sd, pmf.size()); std::vector samples(pmf.size()); for (int32_t i = 0; i < numSamples; ++i) { - ++samples[dist(gen)]; + ++(ds.samples[dist(gen)]); } - return samples; + return ds; } diff --git a/src/GZipWriter.cpp b/src/GZipWriter.cpp index 85c604437..461e3365e 100644 --- a/src/GZipWriter.cpp +++ b/src/GZipWriter.cpp @@ -407,6 +407,8 @@ bool GZipWriter::writeEmptyMeta(const SalmonOpts& opts, const ExpT& experiment, oa(cereal::make_nvp("library_types", libStrings)); oa(cereal::make_nvp("frag_dist_length", 0)); + oa(cereal::make_nvp("frag_length_mean", 0.0)); + oa(cereal::make_nvp("frag_length_sd", 0.0)); oa(cereal::make_nvp("seq_bias_correct", false)); oa(cereal::make_nvp("gc_bias_correct", false)); oa(cereal::make_nvp("num_bias_bins", 0)); @@ -571,9 +573,9 @@ bool GZipWriter::writeMeta(const SalmonOpts& opts, const ExpT& experiment, const bfs::path fldPath = auxDir / "fld.gz"; int32_t numFLDSamples{10000}; - auto fldSamples = distribution_utils::samplesFromLogPMF( + auto fldSummary = distribution_utils::samplesFromLogPMF( experiment.fragmentLengthDistribution(), numFLDSamples); - writeVectorToFile(fldPath, fldSamples); + writeVectorToFile(fldPath, fldSummary.samples); bfs::path normBiasPath = auxDir / "expected_bias.gz"; writeVectorToFile(normBiasPath, experiment.expectedSeqBias()); @@ -774,7 +776,9 @@ bool GZipWriter::writeMeta(const SalmonOpts& opts, const ExpT& experiment, const oa(cereal::make_nvp("num_libraries", libStrings.size())); oa(cereal::make_nvp("library_types", libStrings)); - oa(cereal::make_nvp("frag_dist_length", fldSamples.size())); + oa(cereal::make_nvp("frag_dist_length", fldSummary.samples.size())); + oa(cereal::make_nvp("frag_length_mean", fldSummary.mean)); + oa(cereal::make_nvp("frag_length_sd", fldSummary.sd)); oa(cereal::make_nvp("seq_bias_correct", opts.biasCorrect)); oa(cereal::make_nvp("gc_bias_correct", opts.gcBiasCorrect)); oa(cereal::make_nvp("num_bias_bins", bcounts.size())); From 081291449add44b71ee17716adc38a4d2ff35bfd Mon Sep 17 00:00:00 2001 From: Rob Patro Date: Mon, 1 Jun 2020 23:22:25 -0400 Subject: [PATCH 15/52] build with homopolymer optimization --- src/SalmonQuantify.cpp | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/SalmonQuantify.cpp b/src/SalmonQuantify.cpp index 3d77d37c7..3a2708c86 100644 --- a/src/SalmonQuantify.cpp +++ b/src/SalmonQuantify.cpp @@ -1219,6 +1219,13 @@ void processReads( numDecoyFrags += bestHitDecoy ? 1 : 0; ++numFragsDropped; jointAlignmentGroup.clearAlignments(); + // TODO: Create alignment objects for the decoys so that we can write + // decoy alignments to file + /** + if (bestHitDecoy) { + salmon::mapping_utils::filterAndCollectAlignmentsDecoy(...); + } + **/ } } else if (isPaired and noDovetail) { salmonOpts.jointLog->critical("This code path is not yet implemented!"); From d6abefd699d8c4e34ade5f01ea7e4686d28c05ff Mon Sep 17 00:00:00 2001 From: Avi Srivastava Date: Wed, 3 Jun 2020 15:25:20 -0400 Subject: [PATCH 16/52] fixing custom citeseq --- src/Alevin.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Alevin.cpp b/src/Alevin.cpp index 7cdd19900..998965f96 100644 --- a/src/Alevin.cpp +++ b/src/Alevin.cpp @@ -1003,7 +1003,7 @@ salmon-based processing of single-cell RNA-seq data. if (celseq) validate_num_protocols += 1; if (celseq2) validate_num_protocols += 1; if (quartzseq2) validate_num_protocols += 1; - if (custom) validate_num_protocols += 1; + if (custom and !noTgMap) validate_num_protocols += 1; if ( validate_num_protocols != 1 ) { fmt::print(stderr, "ERROR: Please specify one and only one scRNA protocol;"); From c995ab609a782a07bbc2f4a346aa222750fbff99 Mon Sep 17 00:00:00 2001 From: Avi Srivastava Date: Wed, 3 Jun 2020 16:01:32 -0400 Subject: [PATCH 17/52] flusing the log --- src/Alevin.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/src/Alevin.cpp b/src/Alevin.cpp index 998965f96..2fb8d4235 100644 --- a/src/Alevin.cpp +++ b/src/Alevin.cpp @@ -823,6 +823,7 @@ void initiatePipeline(AlevinOpts& aopt, std::vector readFiles){ bool isOptionsOk = aut::processAlevinOpts(aopt, sopt, noTgMap, vm); if (!isOptionsOk){ + aopt.jointLog->flush(); exit(1); } From 49e1294f7925fea5f36b0ffc006d29c6caae9fe1 Mon Sep 17 00:00:00 2001 From: Rob Patro Date: Sat, 6 Jun 2020 00:23:06 -0400 Subject: [PATCH 18/52] keep fld fixed within the processing of each fragment --- src/SalmonQuantify.cpp | 29 ++++++++++++++++++++++++----- src/SalmonQuantifyAlignments.cpp | 26 +++++++++++++++++++++++--- 2 files changed, 47 insertions(+), 8 deletions(-) diff --git a/src/SalmonQuantify.cpp b/src/SalmonQuantify.cpp index 3a2708c86..1daf4ca49 100644 --- a/src/SalmonQuantify.cpp +++ b/src/SalmonQuantify.cpp @@ -255,6 +255,16 @@ void processMiniBatch(ReadExperimentT& readExp, ForgettingMassCalculator& fmCalc logCMFCache.refresh(numAssignedFragments.load(), burnedIn.load()); } + // A cache to avoid fld updates _within_ the set of alignments of a fragment + size_t maxCacheLen{salmonOpts.fragLenDistMax}; + // cache for the pmf and cmf + std::vector pmfCache(maxCacheLen+1, salmon::math::LOG_0); + std::vector cmfCache(maxCacheLen+1, salmon::math::LOG_0); + // "generation" counters to avoid using stale values + // the generation gets updated for each new fragment + std::vector pmfGen(maxCacheLen+1, 0); + std::vector cmfGen(maxCacheLen+1, 0); + uint64_t currGen{0}; int i{0}; { // Iterate over each group of alignments (a group consists of all alignments @@ -262,6 +272,7 @@ void processMiniBatch(ReadExperimentT& readExp, ForgettingMassCalculator& fmCalc // for a single read). Distribute the read's mass to the transcripts // where it potentially aligns. for (auto& alnGroup : batchHits) { + ++currGen; // If we had no alignments for this read, then skip it if (alnGroup.size() == 0) { continue; @@ -308,11 +319,10 @@ void processMiniBatch(ReadExperimentT& readExp, ForgettingMassCalculator& fmCalc uint32_t prevTxpID{0}; hasCompatibleMapping = false; + useAuxParams = ((localNumAssignedFragments + numAssignedFragments) >= + salmonOpts.numPreBurninFrags); // For each alignment of this read for (auto& aln : alnGroup.alignments()) { - - useAuxParams = ((localNumAssignedFragments + numAssignedFragments) >= - salmonOpts.numPreBurninFrags); bool considerCondProb{burnedIn or useAuxParams}; auto transcriptID = aln.transcriptID(); @@ -367,11 +377,20 @@ void processMiniBatch(ReadExperimentT& readExp, ForgettingMassCalculator& fmCalc if (flen > 0.0 and useFragLengthDist and considerCondProb) { size_t fl = flen; - double lenProb = fragLengthDist.pmf(fl); + size_t klen = (fl > maxCacheLen) ? maxCacheLen : fl; + double lenProb = (pmfGen[klen] < currGen) ? fragLengthDist.pmf(fl) : pmfCache[klen]; + pmfCache[klen] = lenProb; + pmfGen[klen] = currGen; + if (burnedIn) { /* condition fragment length prob on txp length */ + size_t rlen = (static_cast(refLength) > maxCacheLen) ? maxCacheLen : static_cast(refLength); double refLengthCM = - fragLengthDist.cmf(static_cast(refLength)); + (cmfGen[rlen] < currGen) ? + fragLengthDist.cmf(static_cast(rlen)) : cmfCache[rlen]; + cmfCache[rlen] = refLengthCM; + cmfGen[klen] = currGen; + bool computeMass = fl < refLength and !salmon::math::isLog0(refLengthCM); logFragProb = (computeMass) ? (lenProb - refLengthCM) diff --git a/src/SalmonQuantifyAlignments.cpp b/src/SalmonQuantifyAlignments.cpp index e6d4e5c02..763e1ccfa 100644 --- a/src/SalmonQuantifyAlignments.cpp +++ b/src/SalmonQuantifyAlignments.cpp @@ -211,6 +211,17 @@ void processMiniBatch(AlignmentLibraryT& alnLib, distribution_utils::LogCMFCache logCMFCache(&fragLengthDist, singleEndLib); + // A cache to avoid fld updates _within_ the set of alignments of a fragment + size_t maxCacheLen{salmonOpts.fragLenDistMax}; + // cache for the pmf and cmf + std::vector pmfCache(maxCacheLen+1, salmon::math::LOG_0); + std::vector cmfCache(maxCacheLen+1, salmon::math::LOG_0); + // "generation" counters to avoid using stale values + // the generation gets updated for each new fragment + std::vector pmfGen(maxCacheLen+1, 0); + std::vector cmfGen(maxCacheLen+1, 0); + uint64_t currGen{0}; + std::chrono::microseconds sleepTime(1); MiniBatchInfo>* miniBatch = nullptr; bool updateCounts = initialRound; @@ -317,7 +328,7 @@ void processMiniBatch(AlignmentLibraryT& alnLib, // alignments reported for a single read). Distribute the read's mass // proportionally dependent on the current for (auto& alnGroup : alignmentGroups) { - + ++currGen; // EQCLASS std::vector txpIDs; std::vector auxProbs; @@ -396,11 +407,20 @@ void processMiniBatch(AlignmentLibraryT& alnLib, if (flen > 0.0 and aln->isPaired() and useFragLengthDist and considerCondProb) { size_t fl = flen; - double lenProb = fragLengthDist.pmf(fl); + size_t klen = (fl > maxCacheLen) ? maxCacheLen : fl; + double lenProb = (pmfGen[klen] < currGen) ? fragLengthDist.pmf(fl) : pmfCache[klen]; + pmfCache[klen] = lenProb; + pmfGen[klen] = currGen; + if (burnedIn) { /* condition fragment length prob on txp length */ + size_t rlen = (static_cast(refLength) > maxCacheLen) ? maxCacheLen : static_cast(refLength); double refLengthCM = - fragLengthDist.cmf(static_cast(refLength)); + (cmfGen[rlen] < currGen) ? + fragLengthDist.cmf(static_cast(rlen)) : cmfCache[rlen]; + cmfCache[rlen] = refLengthCM; + cmfGen[klen] = currGen; + bool computeMass = fl < refLength and !salmon::math::isLog0(refLengthCM); logFragProb = (computeMass) ? (lenProb - refLengthCM) From 4edf8a8f985ba8250de34986dc239705aeed0736 Mon Sep 17 00:00:00 2001 From: Rob Patro Date: Sat, 6 Jun 2020 13:48:52 -0400 Subject: [PATCH 19/52] abstract indexed versioned cache --- include/DistributionUtils.hpp | 39 +++++++++++++++++++++++++++++ src/SalmonQuantify.cpp | 40 +++++++++++++++--------------- src/SalmonQuantifyAlignments.cpp | 42 +++++++++++++++----------------- 3 files changed, 78 insertions(+), 43 deletions(-) diff --git a/include/DistributionUtils.hpp b/include/DistributionUtils.hpp index 99d684941..fc2f70abb 100644 --- a/include/DistributionUtils.hpp +++ b/include/DistributionUtils.hpp @@ -96,6 +96,45 @@ class LogCMFCache { std::vector cachedCMF_; }; +template +class VersionedValue { + public: + T val{T()}; + uint64_t gen{0}; +}; + +// A simple cache where values are retreived by their index +// and the cached items are versioned +template +class IndexedVersionedCache { + public: + + IndexedVersionedCache(size_t max_index) : + cache_(max_index+1), max_index_(max_index), current_gen_(0) {} + + inline void increment_generation() { ++current_gen_; } + + inline bool get_value(size_t index, T& v) { + size_t idx = (index > max_index_) ? max_index_ : index; + const VersionedValue& vv = cache_[idx]; + bool is_stale = vv.gen < current_gen_; + v = vv.val; + return is_stale; + } + + inline void update_value(size_t index, T v) { + size_t idx = (index > max_index_) ? max_index_ : index; + VersionedValue& vv = cache_[idx]; + vv.val = v; + vv.gen = current_gen_; + } + + private: + std::vector> cache_; + size_t max_index_{0}; + uint64_t current_gen_{0}; +}; + } // namespace distribution_utils #endif // __DISTRIBUTION_UTILS__ diff --git a/src/SalmonQuantify.cpp b/src/SalmonQuantify.cpp index 1daf4ca49..313c1d371 100644 --- a/src/SalmonQuantify.cpp +++ b/src/SalmonQuantify.cpp @@ -255,16 +255,11 @@ void processMiniBatch(ReadExperimentT& readExp, ForgettingMassCalculator& fmCalc logCMFCache.refresh(numAssignedFragments.load(), burnedIn.load()); } - // A cache to avoid fld updates _within_ the set of alignments of a fragment - size_t maxCacheLen{salmonOpts.fragLenDistMax}; - // cache for the pmf and cmf - std::vector pmfCache(maxCacheLen+1, salmon::math::LOG_0); - std::vector cmfCache(maxCacheLen+1, salmon::math::LOG_0); - // "generation" counters to avoid using stale values - // the generation gets updated for each new fragment - std::vector pmfGen(maxCacheLen+1, 0); - std::vector cmfGen(maxCacheLen+1, 0); - uint64_t currGen{0}; + const size_t maxCacheLen{salmonOpts.fragLenDistMax}; + // Caches to avoid fld updates _within_ the set of alignments of a fragment + distribution_utils::IndexedVersionedCache pmfCache(maxCacheLen); + distribution_utils::IndexedVersionedCache cmfCache(maxCacheLen); + int i{0}; { // Iterate over each group of alignments (a group consists of all alignments @@ -272,7 +267,9 @@ void processMiniBatch(ReadExperimentT& readExp, ForgettingMassCalculator& fmCalc // for a single read). Distribute the read's mass to the transcripts // where it potentially aligns. for (auto& alnGroup : batchHits) { - ++currGen; + pmfCache.increment_generation(); + cmfCache.increment_generation(); + // If we had no alignments for this read, then skip it if (alnGroup.size() == 0) { continue; @@ -377,19 +374,20 @@ void processMiniBatch(ReadExperimentT& readExp, ForgettingMassCalculator& fmCalc if (flen > 0.0 and useFragLengthDist and considerCondProb) { size_t fl = flen; - size_t klen = (fl > maxCacheLen) ? maxCacheLen : fl; - double lenProb = (pmfGen[klen] < currGen) ? fragLengthDist.pmf(fl) : pmfCache[klen]; - pmfCache[klen] = lenProb; - pmfGen[klen] = currGen; + double lenProb; + if (!pmfCache.get_value(fl, lenProb)) { + lenProb = fragLengthDist.pmf(fl); + pmfCache.update_value(fl, lenProb); + } if (burnedIn) { /* condition fragment length prob on txp length */ - size_t rlen = (static_cast(refLength) > maxCacheLen) ? maxCacheLen : static_cast(refLength); - double refLengthCM = - (cmfGen[rlen] < currGen) ? - fragLengthDist.cmf(static_cast(rlen)) : cmfCache[rlen]; - cmfCache[rlen] = refLengthCM; - cmfGen[klen] = currGen; + size_t rlen = static_cast(refLength); + double refLengthCM; + if (!cmfCache.get_value(rlen, refLengthCM)) { + refLengthCM = fragLengthDist.cmf(rlen); + cmfCache.update_value(rlen, refLengthCM); + } bool computeMass = fl < refLength and !salmon::math::isLog0(refLengthCM); diff --git a/src/SalmonQuantifyAlignments.cpp b/src/SalmonQuantifyAlignments.cpp index 763e1ccfa..ea8a12d8a 100644 --- a/src/SalmonQuantifyAlignments.cpp +++ b/src/SalmonQuantifyAlignments.cpp @@ -211,16 +211,10 @@ void processMiniBatch(AlignmentLibraryT& alnLib, distribution_utils::LogCMFCache logCMFCache(&fragLengthDist, singleEndLib); - // A cache to avoid fld updates _within_ the set of alignments of a fragment - size_t maxCacheLen{salmonOpts.fragLenDistMax}; - // cache for the pmf and cmf - std::vector pmfCache(maxCacheLen+1, salmon::math::LOG_0); - std::vector cmfCache(maxCacheLen+1, salmon::math::LOG_0); - // "generation" counters to avoid using stale values - // the generation gets updated for each new fragment - std::vector pmfGen(maxCacheLen+1, 0); - std::vector cmfGen(maxCacheLen+1, 0); - uint64_t currGen{0}; + const size_t maxCacheLen{salmonOpts.fragLenDistMax}; + // Caches to avoid fld updates _within_ the set of alignments of a fragment + distribution_utils::IndexedVersionedCache pmfCache(maxCacheLen); + distribution_utils::IndexedVersionedCache cmfCache(maxCacheLen); std::chrono::microseconds sleepTime(1); MiniBatchInfo>* miniBatch = nullptr; @@ -328,7 +322,9 @@ void processMiniBatch(AlignmentLibraryT& alnLib, // alignments reported for a single read). Distribute the read's mass // proportionally dependent on the current for (auto& alnGroup : alignmentGroups) { - ++currGen; + pmfCache.increment_generation(); + cmfCache.increment_generation(); + // EQCLASS std::vector txpIDs; std::vector auxProbs; @@ -406,20 +402,22 @@ void processMiniBatch(AlignmentLibraryT& alnLib, if (flen > 0.0 and aln->isPaired() and useFragLengthDist and considerCondProb) { + size_t fl = flen; - size_t klen = (fl > maxCacheLen) ? maxCacheLen : fl; - double lenProb = (pmfGen[klen] < currGen) ? fragLengthDist.pmf(fl) : pmfCache[klen]; - pmfCache[klen] = lenProb; - pmfGen[klen] = currGen; - + double lenProb; + if (!pmfCache.get_value(fl, lenProb)) { + lenProb = fragLengthDist.pmf(fl); + pmfCache.update_value(fl, lenProb); + } + if (burnedIn) { /* condition fragment length prob on txp length */ - size_t rlen = (static_cast(refLength) > maxCacheLen) ? maxCacheLen : static_cast(refLength); - double refLengthCM = - (cmfGen[rlen] < currGen) ? - fragLengthDist.cmf(static_cast(rlen)) : cmfCache[rlen]; - cmfCache[rlen] = refLengthCM; - cmfGen[klen] = currGen; + size_t rlen = static_cast(refLength); + double refLengthCM; + if (!cmfCache.get_value(rlen, refLengthCM)) { + refLengthCM = fragLengthDist.cmf(rlen); + cmfCache.update_value(rlen, refLengthCM); + } bool computeMass = fl < refLength and !salmon::math::isLog0(refLengthCM); From b123deb575c8a4f54150a89a3f24e2bc713931b4 Mon Sep 17 00:00:00 2001 From: Rob Patro Date: Sat, 6 Jun 2020 14:34:04 -0400 Subject: [PATCH 20/52] simplify code --- include/DistributionUtils.hpp | 12 ++++++++++++ src/SalmonQuantify.cpp | 14 ++++---------- src/SalmonQuantifyAlignments.cpp | 16 +++++----------- 3 files changed, 21 insertions(+), 21 deletions(-) diff --git a/include/DistributionUtils.hpp b/include/DistributionUtils.hpp index fc2f70abb..a75947258 100644 --- a/include/DistributionUtils.hpp +++ b/include/DistributionUtils.hpp @@ -129,6 +129,18 @@ class IndexedVersionedCache { vv.gen = current_gen_; } + template + inline T get_or_update(size_t index, F& gen_value) { + size_t idx = (index > max_index_) ? max_index_ : index; + VersionedValue& vv = cache_[idx]; + bool is_stale = vv.gen < current_gen_; + if (is_stale) { + vv.val = gen_value(idx); + vv.gen = current_gen_; + } + return vv.val; + } + private: std::vector> cache_; size_t max_index_{0}; diff --git a/src/SalmonQuantify.cpp b/src/SalmonQuantify.cpp index 313c1d371..017c9d56b 100644 --- a/src/SalmonQuantify.cpp +++ b/src/SalmonQuantify.cpp @@ -257,6 +257,8 @@ void processMiniBatch(ReadExperimentT& readExp, ForgettingMassCalculator& fmCalc const size_t maxCacheLen{salmonOpts.fragLenDistMax}; // Caches to avoid fld updates _within_ the set of alignments of a fragment + auto fetchPMF = [&fragLengthDist](size_t l) -> double { return fragLengthDist.pmf(l); }; + auto fetchCMF = [&fragLengthDist](size_t l) -> double { return fragLengthDist.cmf(l); }; distribution_utils::IndexedVersionedCache pmfCache(maxCacheLen); distribution_utils::IndexedVersionedCache cmfCache(maxCacheLen); @@ -374,20 +376,12 @@ void processMiniBatch(ReadExperimentT& readExp, ForgettingMassCalculator& fmCalc if (flen > 0.0 and useFragLengthDist and considerCondProb) { size_t fl = flen; - double lenProb; - if (!pmfCache.get_value(fl, lenProb)) { - lenProb = fragLengthDist.pmf(fl); - pmfCache.update_value(fl, lenProb); - } + double lenProb = pmfCache.get_or_update(fl, fetchPMF); if (burnedIn) { /* condition fragment length prob on txp length */ size_t rlen = static_cast(refLength); - double refLengthCM; - if (!cmfCache.get_value(rlen, refLengthCM)) { - refLengthCM = fragLengthDist.cmf(rlen); - cmfCache.update_value(rlen, refLengthCM); - } + double refLengthCM = cmfCache.get_or_update(fl, fetchCMF); bool computeMass = fl < refLength and !salmon::math::isLog0(refLengthCM); diff --git a/src/SalmonQuantifyAlignments.cpp b/src/SalmonQuantifyAlignments.cpp index ea8a12d8a..3df955dab 100644 --- a/src/SalmonQuantifyAlignments.cpp +++ b/src/SalmonQuantifyAlignments.cpp @@ -213,6 +213,8 @@ void processMiniBatch(AlignmentLibraryT& alnLib, const size_t maxCacheLen{salmonOpts.fragLenDistMax}; // Caches to avoid fld updates _within_ the set of alignments of a fragment + auto fetchPMF = [&fragLengthDist](size_t l) -> double { return fragLengthDist.pmf(l); }; + auto fetchCMF = [&fragLengthDist](size_t l) -> double { return fragLengthDist.cmf(l); }; distribution_utils::IndexedVersionedCache pmfCache(maxCacheLen); distribution_utils::IndexedVersionedCache cmfCache(maxCacheLen); @@ -404,20 +406,12 @@ void processMiniBatch(AlignmentLibraryT& alnLib, considerCondProb) { size_t fl = flen; - double lenProb; - if (!pmfCache.get_value(fl, lenProb)) { - lenProb = fragLengthDist.pmf(fl); - pmfCache.update_value(fl, lenProb); - } - + double lenProb = pmfCache.get_or_update(fl, fetchPMF); + if (burnedIn) { /* condition fragment length prob on txp length */ size_t rlen = static_cast(refLength); - double refLengthCM; - if (!cmfCache.get_value(rlen, refLengthCM)) { - refLengthCM = fragLengthDist.cmf(rlen); - cmfCache.update_value(rlen, refLengthCM); - } + double refLengthCM = cmfCache.get_or_update(fl, fetchCMF); bool computeMass = fl < refLength and !salmon::math::isLog0(refLengthCM); From 958825007134f7e74e9be101fbb213c518fb4b9e Mon Sep 17 00:00:00 2001 From: Rob Patro Date: Sat, 6 Jun 2020 20:53:11 -0400 Subject: [PATCH 21/52] remove temporary bool --- include/DistributionUtils.hpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/include/DistributionUtils.hpp b/include/DistributionUtils.hpp index a75947258..ee0a02e5f 100644 --- a/include/DistributionUtils.hpp +++ b/include/DistributionUtils.hpp @@ -133,11 +133,12 @@ class IndexedVersionedCache { inline T get_or_update(size_t index, F& gen_value) { size_t idx = (index > max_index_) ? max_index_ : index; VersionedValue& vv = cache_[idx]; - bool is_stale = vv.gen < current_gen_; - if (is_stale) { + // if the current value is stale, compute a new one and cache it + if (vv.gen < current_gen_) { vv.val = gen_value(idx); vv.gen = current_gen_; } + // return the (possibly newly) cached value return vv.val; } From c8143119884c4d445d17bd08b9724edbd6445bb0 Mon Sep 17 00:00:00 2001 From: Kevin Rue-Albrecht Date: Tue, 9 Jun 2020 15:32:15 +0100 Subject: [PATCH 22/52] typos if salmon alevin options --- src/ProgramOptionsGenerator.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/ProgramOptionsGenerator.cpp b/src/ProgramOptionsGenerator.cpp index 169febc75..b3daf82c8 100644 --- a/src/ProgramOptionsGenerator.cpp +++ b/src/ProgramOptionsGenerator.cpp @@ -409,7 +409,7 @@ namespace salmon { ( "end",po::value(), "Cell-Barcodes end (5 or 3) location in the read sequence from where barcode has to" - "be extracted. (end, umiLength, barcodeLength)" + " be extracted. (end, umiLength, barcodeLength)" " should all be provided if using this option") ( "umiLength",po::value(), @@ -417,7 +417,7 @@ namespace salmon { " should all be provided if using this option") ( "barcodeLength",po::value(), - "umi length Parameter for unknown protocol. (end, umiLength, barcodeLength)" + "barcode length Parameter for unknown protocol. (end, umiLength, barcodeLength)" " should all be provided if using this option") ( "noem",po::bool_switch()->default_value(alevin::defaults::noEM), From a3a82e6b489047ea0e6226865cbba31784b214c1 Mon Sep 17 00:00:00 2001 From: Rob Patro Date: Tue, 9 Jun 2020 13:34:43 -0400 Subject: [PATCH 23/52] fixed seed --- src/FASTAParser.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/FASTAParser.cpp b/src/FASTAParser.cpp index 338198ab5..926e66efb 100644 --- a/src/FASTAParser.cpp +++ b/src/FASTAParser.cpp @@ -42,8 +42,8 @@ void FASTAParser::populateTargets(std::vector& refs, constexpr char bases[] = {'A', 'C', 'G', 'T'}; // Create a random uniform distribution - std::random_device rd; - std::default_random_engine eng(rd()); + constexpr const uint64_t randseed{271828}; + std::default_random_engine eng(randseed); std::uniform_int_distribution<> dis(0, 3); uint64_t numNucleotidesReplaced{0}; From 90fe7ccc1667f3f18d66834d65cbfdc107f6647b Mon Sep 17 00:00:00 2001 From: Avi Srivastava Date: Wed, 10 Jun 2020 19:15:38 -0400 Subject: [PATCH 24/52] correcting variance bias --- src/CollapsedCellOptimizer.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/src/CollapsedCellOptimizer.cpp b/src/CollapsedCellOptimizer.cpp index 4ded743e3..a7844c1da 100644 --- a/src/CollapsedCellOptimizer.cpp +++ b/src/CollapsedCellOptimizer.cpp @@ -542,6 +542,7 @@ bool runBootstraps(size_t numGenes, double meanAlpha = mean[i] / numBootstraps; geneAlphas[i] = meanAlpha; variance[i] = (squareMean[i]/numBootstraps) - (meanAlpha*meanAlpha); + variance[i] *= (numBootstraps / static_cast(numBootstraps - 1)); } return true; From 617070b5821b4bacff590f2f3600efd8c2d52295 Mon Sep 17 00:00:00 2001 From: Rob Patro Date: Thu, 11 Jun 2020 15:48:35 -0400 Subject: [PATCH 25/52] allow writing out of decoy alignments with --- include/SalmonMappingUtils.hpp | 188 +++++++++++++++++++++++++++------ src/SalmonAlevin.cpp | 62 +++++------ src/SalmonQuantify.cpp | 146 ++++++++++++------------- 3 files changed, 260 insertions(+), 136 deletions(-) diff --git a/include/SalmonMappingUtils.hpp b/include/SalmonMappingUtils.hpp index e1d9b3d66..82aa93e82 100644 --- a/include/SalmonMappingUtils.hpp +++ b/include/SalmonMappingUtils.hpp @@ -56,23 +56,77 @@ #include "pufferfish/ksw2pp/KSW2Aligner.hpp" #include "pufferfish/metro/metrohash64.h" #include "pufferfish/SelectiveAlignmentUtils.hpp" +#include "pufferfish/chobo/small_vector.hpp" +#include "parallel_hashmap/phmap.h" namespace salmon { namespace mapping_utils { using MateStatus = pufferfish::util::MateStatus; + constexpr const int32_t invalid_score_ = std::numeric_limits::min(); + constexpr const int32_t invalid_index_ = std::numeric_limits::min(); + constexpr const size_t static_vec_size = 32; + + class MappingScoreInfo { + public: + MappingScoreInfo() + : bestScore(invalid_score_), secondBestScore(invalid_score_), + bestDecoyScore(invalid_score_), decoyThresh(1.0), collect_decoy_info_(false) {} + + MappingScoreInfo(double decoyThreshIn) : MappingScoreInfo() { + decoyThresh = decoyThreshIn; + } + + void collect_decoys(bool do_collect) { collect_decoy_info_ = do_collect; } + bool collect_decoys() const { return collect_decoy_info_; } + + // clear everything but the decoy threshold + void clear(size_t num_hits) { + bestScore = invalid_score_; + secondBestScore = invalid_score_; + bestDecoyScore = invalid_score_; + scores_.clear(); scores_.resize(num_hits, invalid_score_); + bestScorePerTranscript_.clear(); + perm_.clear(); + // NOTE: we do _not_ reset decoyThresh here + if (collect_decoy_info_) { + bool revert_to_static = best_decoy_hits.size() > static_vec_size; + best_decoy_hits.clear(); + if (revert_to_static) { best_decoy_hits.revert_to_static(); } + } + } + + bool haveOnlyDecoyMappings() const { + // if the best non-decoy mapping has score less than decoyThresh * + // bestDecoyScore and if the bestDecoyScore is a valid value, then we + // have no valid non-decoy mappings. + return (bestScore < static_cast(decoyThresh * bestDecoyScore)) and + (bestDecoyScore > std::numeric_limits::min()); + } + + inline bool update_decoy_mappings(int32_t hitScore, size_t idx, uint32_t tid) { + const bool better_score = hitScore > bestDecoyScore; + if (hitScore > bestDecoyScore) { + bestDecoyScore = hitScore; + if (collect_decoy_info_) { + best_decoy_hits.clear(); + best_decoy_hits.push_back(std::make_pair(static_cast(idx), static_cast(tid))); + } + } else if (collect_decoy_info_ and (hitScore == bestDecoyScore)){ + best_decoy_hits.push_back(std::make_pair(static_cast(idx), static_cast(tid))); + } + return better_score; + } - struct MappingScoreInfo { int32_t bestScore; int32_t secondBestScore; int32_t bestDecoyScore; double decoyThresh; - bool haveOnlyDecoyMappings() const { - // if the best non-decoy mapping has score less than decoyThresh * bestDecoyScore - // and if the bestDecoyScore is a valid value, then we have no valid non-decoy mappings. - return (bestScore < static_cast(decoyThresh * bestDecoyScore)) and - (bestDecoyScore > std::numeric_limits::min()); - } + chobo::small_vector> best_decoy_hits; + bool collect_decoy_info_; + std::vector scores_; + phmap::flat_hash_map> bestScorePerTranscript_; + std::vector> perm_; }; template @@ -136,15 +190,8 @@ inline bool initMapperSettings(SalmonOpts& salmonOpts, MemCollector& mem inline void updateRefMappings(uint32_t tid, int32_t hitScore, bool isCompat, size_t idx, const std::vector& transcripts, int32_t invalidScore, - salmon::mapping_utils::MappingScoreInfo& msi, - /* - int32_t& bestScore, - int32_t& secondBestScore, - int32_t& bestDecoyScore, - */ - std::vector& scores, - phmap::flat_hash_map>& bestScorePerTranscript, - std::vector>& perm) { + salmon::mapping_utils::MappingScoreInfo& msi) { + auto& scores = msi.scores_; scores[idx] = hitScore; auto& t = transcripts[tid]; bool isDecoy = t.isDecoy(); @@ -153,10 +200,13 @@ inline void updateRefMappings(uint32_t tid, int32_t hitScore, bool isCompat, siz //if (hitScore < decoyCutoff or (hitScore == invalidScore)) { } if (isDecoy) { - // if this is a decoy and its score is at least as good - // as the bestDecoyScore, then update that and go to the next mapping - msi.bestDecoyScore = std::max(hitScore, msi.bestDecoyScore); - return; + // NOTE: decide here if we need to process any of this if the + // current score is < the best (non-decoy) score. I think not. + + // if this is a decoy and its score is better than the best decoy score + bool did_update = msi.update_decoy_mappings(hitScore, idx, tid); + (void)did_update; + return; } else if (hitScore < decoyCutoff or (hitScore == invalidScore)) { // if the current score is to a valid target but doesn't // exceed the necessary decoy threshold, then skip it. @@ -164,6 +214,8 @@ inline void updateRefMappings(uint32_t tid, int32_t hitScore, bool isCompat, siz } // otherwise, we have a "high-scoring" hit to a non-decoy + auto& perm = msi.perm_; + auto& bestScorePerTranscript = msi.bestScorePerTranscript_; // removing duplicate hits from a read to the same transcript auto it = bestScorePerTranscript.find(tid); if (it == bestScorePerTranscript.end()) { @@ -187,15 +239,12 @@ inline void updateRefMappings(uint32_t tid, int32_t hitScore, bool isCompat, siz msi.secondBestScore = msi.bestScore; msi.bestScore = hitScore; } - //bestScore = (hitScore > bestScore) ? hitScore : bestScore; perm.push_back(std::make_pair(idx, tid)); } inline void filterAndCollectAlignments( std::vector& jointHits, - const std::vector& scores, - std::vector>& perm, uint32_t readLen, uint32_t mateLen, bool singleEnd, @@ -203,13 +252,7 @@ inline void filterAndCollectAlignments( bool hardFilter, double scoreExp, double minAlnProb, - // intentionally passing by value below --- come back - // and make sure it's necessary - salmon::mapping_utils::MappingScoreInfo msi, - /*int32_t bestScore, - int32_t secondBestScore, - int32_t bestDecoyScore, - */ + salmon::mapping_utils::MappingScoreInfo& msi, std::vector& jointAlignments) { auto invalidScore = std::numeric_limits::min(); @@ -218,6 +261,8 @@ inline void filterAndCollectAlignments( int32_t decoyThreshold = static_cast(msi.decoyThresh * msi.bestDecoyScore); //auto filterScore = (bestDecoyScore < secondBestScore) ? secondBestScore : bestDecoyScore; + auto& scores = msi.scores_; + auto& perm = msi.perm_; // throw away any pairs for which we should not produce valid alignments : // ====== @@ -300,6 +345,89 @@ inline void filterAndCollectAlignments( // done moving our alinged / score jointMEMs over to QuasiAlignment objects } + +inline void filterAndCollectAlignmentsDecoy( + std::vector& jointHits, + uint32_t readLen, + uint32_t mateLen, + bool singleEnd, + bool tryAlign, + bool hardFilter, + double scoreExp, + double minAlnProb, + salmon::mapping_utils::MappingScoreInfo& msi, + std::vector& jointAlignments) { +// NOTE: this function should only be called in the case that we have valid decoy mappings to report. +// Currently, this happens only when there are *no valid non-decoy* mappings. +// Further, this function will only add equally *best* decoy mappings to the output jointAlignments object +// regardless of the the status of hardFilter (i.e. no sub-optimal decoy mappings will be reported). +(void) hardFilter; +(void) minAlnProb; +double estAlnProb = 1.0; //std::exp(-scoreExp * 0.0); +for (auto& idxTxp : msi.best_decoy_hits) { + int32_t ctr = idxTxp.first; + int32_t tid = idxTxp.second; + auto& jointHit = jointHits[ctr]; + + if (singleEnd or jointHit.isOrphan()) { + readLen = jointHit.isLeftAvailable() ? readLen : mateLen; + jointAlignments.emplace_back( + tid, // reference id + jointHit.orphanClust()->getTrFirstHitPos(), // reference pos + jointHit.orphanClust()->isFw, // fwd direction + readLen, // read length + jointHit.orphanClust()->cigar, // cigar string + jointHit.fragmentLen, // fragment length + false); + auto& qaln = jointAlignments.back(); + // NOTE : score should not be filled in from a double + qaln.score = !tryAlign + ? static_cast(jointHit.orphanClust()->coverage) + : jointHit.alignmentScore; + qaln.estAlnProb(estAlnProb); + // NOTE : wth is numHits? + qaln.numHits = + static_cast(jointHits.size()); // orphanClust()->coverage; + qaln.mateStatus = jointHit.mateStatus; + if (singleEnd) { + qaln.mateLen = readLen; + qaln.mateCigar.clear(); + qaln.matePos = 0; + qaln.mateIsFwd = true; + qaln.mateScore = 0; + qaln.mateStatus = MateStatus::SINGLE_END; + } + } else { + jointAlignments.emplace_back( + tid, // reference id + jointHit.leftClust->getTrFirstHitPos(), // reference pos + jointHit.leftClust->isFw, // fwd direction + readLen, // read length + jointHit.leftClust->cigar, // cigar string + jointHit.fragmentLen, // fragment length + true); // properly paired + // Fill in the mate info + auto& qaln = jointAlignments.back(); + qaln.mateLen = mateLen; + qaln.mateCigar = jointHit.rightClust->cigar; + qaln.matePos = + static_cast(jointHit.rightClust->getTrFirstHitPos()); + qaln.mateIsFwd = jointHit.rightClust->isFw; + qaln.mateStatus = MateStatus::PAIRED_END_PAIRED; + // NOTE : wth is numHits? + qaln.numHits = static_cast(jointHits.size()); + // NOTE : score should not be filled in from a double + qaln.score = !tryAlign ? static_cast(jointHit.leftClust->coverage) + : jointHit.alignmentScore; + qaln.estAlnProb(estAlnProb); + qaln.mateScore = !tryAlign + ? static_cast(jointHit.rightClust->coverage) + : jointHit.mateAlignmentScore; + } +} // end for over best decoy hits +} + + } // namespace mapping_utils } // namespace salmon diff --git a/src/SalmonAlevin.cpp b/src/SalmonAlevin.cpp index 7d5d6f5de..d615a2627 100644 --- a/src/SalmonAlevin.cpp +++ b/src/SalmonAlevin.cpp @@ -449,7 +449,6 @@ void processReadsQuasi( std::vector jointHits; PairedAlignmentFormatter formatter(qidx); pufferfish::util::QueryCache qc; - phmap::flat_hash_map> bestScorePerTranscript; bool mimicStrictBT2 = salmonOpts.mimicStrictBT2; bool mimicBT2 = salmonOpts.mimicBT2; @@ -467,9 +466,6 @@ void processReadsQuasi( ////////////////////// // NOTE: validation mapping based new parameters std::string rc1; rc1.reserve(300); - // will hold the permutation to use to put the transcripts in order - std::vector> perm; - //std::vector alnCache; alnCache.reserve(15); AlnCacheMap alnCache; alnCache.reserve(16); /* @@ -485,6 +481,12 @@ void processReadsQuasi( size_t numMappingsDropped{0}; size_t numDecoyFrags{0}; const double decoyThreshold = salmonOpts.decoyThreshold; + + salmon::mapping_utils::MappingScoreInfo msi(decoyThreshold); + // we only collect detailed decoy information if we will be + // writing output to SAM. + msi.collect_decoys(writeQuasimappings); + std::string readSubSeq; ////////////////////// @@ -514,7 +516,6 @@ void processReadsQuasi( jointHitGroup.clearAlignments(); auto& jointAlignments= jointHitGroup.alignments(); - perm.clear(); hits.clear(); jointHits.clear(); memCollector.clear(); @@ -650,20 +651,8 @@ void processReadsQuasi( // adding validate mapping code if (tryAlign and !jointHits.empty()) { puffaligner.clear(); - bestScorePerTranscript.clear(); - - //auto* r1 = readSubSeq.data(); - //auto l1 = static_cast(readSubSeq.length()); - - // the best scores start out as invalid - /* - int32_t bestScore = invalidScore; - int32_t secondBestScore = invalidScore; - int32_t bestDecoyScore = invalidScore; - */ - salmon::mapping_utils::MappingScoreInfo msi = {invalidScore, invalidScore, invalidScore, decoyThreshold}; + msi.clear(jointHits.size()); - std::vector scores(jointHits.size(), invalidScore); size_t idx{0}; bool isMultimapping = (jointHits.size() > 1); @@ -685,18 +674,13 @@ void processReadsQuasi( (!jointHit.orphanClust()->isFw and (expectedLibraryFormat.strandedness == ReadStrandedness::SA)); salmon::mapping_utils::updateRefMappings(tid, hitScore, isCompat, idx, transcripts, invalidScore, - msi, - //bestScore, secondBestScore, bestDecoyScore, - scores, bestScorePerTranscript, perm); + msi); ++idx; } - //bool bestHitDecoy = (msi.bestScore < msi.bestDecoyScore); bool bestHitDecoy = msi.haveOnlyDecoyMappings(); if (msi.bestScore > invalidScore and !bestHitDecoy) { salmon::mapping_utils::filterAndCollectAlignments(jointHits, - scores, - perm, readSubSeq.length(), readSubSeq.length(), false, // true for single-end false otherwise @@ -705,11 +689,6 @@ void processReadsQuasi( salmonOpts.scoreExp, salmonOpts.minAlnProb, msi, - /* - bestScore, - secondBestScore, - bestDecoyScore, - */ jointAlignments); if (!jointAlignments.empty()) { mapType = salmon::utils::MappingType::SINGLE_MAPPED; @@ -717,17 +696,32 @@ void processReadsQuasi( } else { numDecoyFrags += bestHitDecoy ? 1 : 0; ++numDropped; - jointHitGroup.clearAlignments(); mapType = (bestHitDecoy) ? salmon::utils::MappingType::DECOY : salmon::utils::MappingType::UNMAPPED; + if (bestHitDecoy) { + salmon::mapping_utils::filterAndCollectAlignments( + jointHits, readSubSeq.length(), + readSubSeq.length(), + false, // true for single-end false otherwise + tryAlign, hardFilter, salmonOpts.scoreExp, + salmonOpts.minAlnProb, msi, + jointAlignments); + } else { + jointHitGroup.clearAlignments(); + } } } //end-if validate mapping if (writeQuasimappings) { writeAlignmentsToStream(rp, formatter, jointAlignments, sstream, true, true); - /* - rapmap::utils::writeAlignmentsToStream(rp, formatter, - hctr, jointHits, sstream); - */ + } + + // We've kept decoy aignments around to this point so that we can + // potentially write these alignments to the SAM file. However, if + // we got to this point and only have decoy mappings, then clear the + // mappings here because none of the procesing below is relevant for + // decoys. + if (mapType == salmon::utils::MappingType::DECOY) { + jointHitGroup.clearAlignments(); } if (writeUnmapped and mapType != salmon::utils::MappingType::SINGLE_MAPPED) { diff --git a/src/SalmonQuantify.cpp b/src/SalmonQuantify.cpp index 017c9d56b..9392a1b52 100644 --- a/src/SalmonQuantify.cpp +++ b/src/SalmonQuantify.cpp @@ -823,6 +823,7 @@ void processReads( //******* Setting up pufferfish mapping constexpr const int32_t invalidScore = std::numeric_limits::min(); + constexpr const int32_t invalidIndex = std::numeric_limits::min(); MemCollector memCollector(qidx); ksw2pp::KSW2Aligner aligner; pufferfish::util::AlignmentConfig aconf; @@ -842,7 +843,6 @@ void processReads( bool noDovetail = !salmonOpts.allowDovetail; bool useChainingHeuristic = !salmonOpts.disableChainingHeuristic; size_t numOrphansRescued{0}; - phmap::flat_hash_map> bestScorePerTranscript; uint64_t firstDecoyIndex = qidx->firstDecoyIndex(); //******* @@ -864,15 +864,16 @@ void processReads( } */ - // will hold the permutation to use to put the transcripts in order - std::vector> perm; - size_t numMappingsDropped{0}; size_t numFragsDropped{0}; size_t numDecoyFrags{0}; const double decoyThreshold = salmonOpts.decoyThreshold; uint32_t readLen{0}, mateLen{0}, totLen{0}; + salmon::mapping_utils::MappingScoreInfo msi(decoyThreshold); + // we only collect detailed decoy information if we will be + // writing output to SAM. + msi.collect_decoys(writeQuasimappings); auto rg = parser->getReadGroup(); while (parser->refill(rg)) { @@ -907,7 +908,6 @@ void processReads( ++hctr.numReads; - perm.clear(); jointHits.clear(); leftHits.clear(); rightHits.clear(); @@ -1025,7 +1025,6 @@ void processReads( } } - // TODO: PF_INTEGRATION // NOTE : Under our new definition of orphans, alignments of read ends // can be orphans even if the other read end aligns to the same reference. // It only matters that the alignments were not paired. Thus, it is possible @@ -1077,17 +1076,8 @@ void processReads( if (tryAlign and !jointHits.empty()) { // clear the aligner for this read puffaligner.clear(); - bestScorePerTranscript.clear(); - - // the best scores start out as invalid - /* - int32_t bestScore = invalidScore; - int32_t secondBestScore = invalidScore; - int32_t bestDecoyScore = invalidScore; - */ - - salmon::mapping_utils::MappingScoreInfo msi = {invalidScore, invalidScore, invalidScore, decoyThreshold}; - std::vector scores(jointHits.size(), invalidScore); + msi.clear(jointHits.size()); + size_t idx{0}; bool isMultimapping = (jointHits.size() > 1); @@ -1179,17 +1169,13 @@ void processReads( **/ // end of alternative compat - salmon::mapping_utils::updateRefMappings(tid, hitScore, isCompat, idx, transcripts, invalidScore, msi, - //bestScore, secondBestScore, bestDecoyScore, - scores, bestScorePerTranscript, perm); + salmon::mapping_utils::updateRefMappings(tid, hitScore, isCompat, idx, transcripts, invalidScore, msi); ++idx; } bool bestHitDecoy = msi.haveOnlyDecoyMappings(); if (msi.bestScore > invalidScore and !bestHitDecoy) { salmon::mapping_utils::filterAndCollectAlignments(jointHits, - scores, - perm, readLen, mateLen, false, // true for single-end false otherwise @@ -1198,11 +1184,6 @@ void processReads( salmonOpts.scoreExp, salmonOpts.minAlnProb, msi, - /* - bestScore, - secondBestScore, - bestDecoyScore, - */ jointAlignments); // if we have alignments if (!jointAlignments.empty()) { @@ -1229,14 +1210,25 @@ void processReads( mapType = bestHitDecoy ? salmon::utils::MappingType::DECOY : salmon::utils::MappingType::UNMAPPED; numDecoyFrags += bestHitDecoy ? 1 : 0; ++numFragsDropped; - jointAlignmentGroup.clearAlignments(); // TODO: Create alignment objects for the decoys so that we can write // decoy alignments to file - /** + if (bestHitDecoy) { - salmon::mapping_utils::filterAndCollectAlignmentsDecoy(...); + salmon::mapping_utils::filterAndCollectAlignmentsDecoy(jointHits, + readLen, + mateLen, + false, // true for single-end false otherwise + tryAlign, + hardFilter, + salmonOpts.scoreExp, + salmonOpts.minAlnProb, + msi, + jointAlignments); + } else { + jointAlignmentGroup.clearAlignments(); } - **/ + + //jointAlignmentGroup.clearAlignments(); } } else if (isPaired and noDovetail) { salmonOpts.jointLog->critical("This code path is not yet implemented!"); @@ -1255,6 +1247,24 @@ void processReads( jointAlignments.end()); } + if (writeQuasimappings) { + writeAlignmentsToStream(rp, formatter, + jointAlignments, + sstream, + true, // write orphans + true // transcript ID's already decoded (taking care of short refs) + ); + } + + // We've kept decoy aignments around to this point so that we can + // potentially write these alignments to the SAM file. However, if + // we got to this point and only have decoy mappings, then clear the + // mappings here because none of the procesing below is relevant for + // decoys. + if (mapType == salmon::utils::MappingType::DECOY) { + jointAlignmentGroup.clearAlignments(); + } + bool needBiasSample = salmonOpts.biasCorrect; std::uniform_int_distribution<> dis(0, jointAlignments.size()); @@ -1372,14 +1382,6 @@ void processReads( } } - if (writeQuasimappings) { - writeAlignmentsToStream(rp, formatter, - jointAlignments, - sstream, - true, // write orphans - true // transcript ID's already decoded (taking care of short refs) - ); - } } else { // This read was completely unmapped. mapType = salmon::utils::MappingType::UNMAPPED; @@ -1553,6 +1555,7 @@ void processReads( //******* Setting up pufferfish mapping constexpr const int32_t invalidScore = std::numeric_limits::min(); + constexpr const int32_t invalidIndex = std::numeric_limits::min(); MemCollector memCollector(qidx); ksw2pp::KSW2Aligner aligner; pufferfish::util::AlignmentConfig aconf; @@ -1565,7 +1568,6 @@ void processReads( std::vector jointHits; PairedAlignmentFormatter formatter(qidx); pufferfish::util::QueryCache qc; - phmap::flat_hash_map> bestScorePerTranscript; bool mimicStrictBT2 = salmonOpts.mimicStrictBT2; bool mimicBT2 = salmonOpts.mimicBT2; @@ -1586,14 +1588,16 @@ void processReads( std::string rc1; rc1.reserve(300); - // will hold the permutation to use to put the transcripts in order - std::vector> perm; - size_t numMappingsDropped{0}; size_t numFragsDropped{0}; size_t numDecoyFrags{0}; const double decoyThreshold = salmonOpts.decoyThreshold; + salmon::mapping_utils::MappingScoreInfo msi(decoyThreshold); + // we only collect detailed decoy information if we will be + // writing output to SAM. + msi.collect_decoys(writeQuasimappings); + //std::vector alnCache; alnCache.reserve(15); AlnCacheMap alnCache; alnCache.reserve(16); @@ -1622,7 +1626,6 @@ void processReads( auto& jointAlignments = jointHitGroup.alignments(); mapType = salmon::utils::MappingType::UNMAPPED; - perm.clear(); hits.clear(); jointHits.clear(); memCollector.clear(); @@ -1686,17 +1689,8 @@ void processReads( // clear the aligner for this read puffaligner.clear(); - bestScorePerTranscript.clear(); - - // the best scores start out as invalid - /* - int32_t bestScore = invalidScore; - int32_t secondBestScore = invalidScore; - int32_t bestDecoyScore = invalidScore; - */ - salmon::mapping_utils::MappingScoreInfo msi = {invalidScore, invalidScore, invalidScore, decoyThreshold}; - - std::vector scores(jointHits.size(), invalidScore); + msi.clear(jointHits.size()); + size_t idx{0}; bool isMultimapping = (jointHits.size() > 1); @@ -1710,19 +1704,14 @@ void processReads( (jointHit.orphanClust()->isFw and (expectedLibraryFormat.strandedness == ReadStrandedness::S)) or (!jointHit.orphanClust()->isFw and (expectedLibraryFormat.strandedness == ReadStrandedness::A)); - salmon::mapping_utils::updateRefMappings(tid, hitScore, isCompat, idx, transcripts, invalidScore, - msi, - //bestScore, secondBestScore, bestDecoyScore, - scores, bestScorePerTranscript, perm); + salmon::mapping_utils::updateRefMappings( + tid, hitScore, isCompat, idx, transcripts, invalidScore, msi); ++idx; } - //bool bestHitDecoy = (msi.bestScore < msi.bestDecoyScore); bool bestHitDecoy = msi.haveOnlyDecoyMappings(); if (msi.bestScore > invalidScore and !bestHitDecoy) { salmon::mapping_utils::filterAndCollectAlignments(jointHits, - scores, - perm, readLen, readLen, true, // true for single-end false otherwise @@ -1731,11 +1720,6 @@ void processReads( salmonOpts.scoreExp, salmonOpts.minAlnProb, msi, - /* - bestScore, - secondBestScore, - bestDecoyScore, - */ jointAlignments); // if we have any alignments, then they are // just single mapped. @@ -1747,10 +1731,32 @@ void processReads( mapType = (bestHitDecoy) ? salmon::utils::MappingType::DECOY : salmon::utils::MappingType::UNMAPPED; numDecoyFrags += bestHitDecoy ? 1 : 0; ++numFragsDropped; - jointHitGroup.clearAlignments(); + if (bestHitDecoy) { + salmon::mapping_utils::filterAndCollectAlignmentsDecoy( + jointHits, readLen, readLen, + true, // true for single-end false otherwise + tryAlign, hardFilter, salmonOpts.scoreExp, + salmonOpts.minAlnProb, msi, + jointAlignments); + } else { + jointHitGroup.clearAlignments(); + } } } + if (writeQuasimappings) { + writeAlignmentsToStreamSingle(rp, formatter, jointAlignments, sstream, false, true); + } + + // We've kept decoy aignments around to this point so that we can + // potentially write these alignments to the SAM file. However, if + // we got to this point and only have decoy mappings, then clear the + // mappings here because none of the procesing below is relevant for + // decoys. + if (mapType == salmon::utils::MappingType::DECOY) { + jointHitGroup.clearAlignments(); + } + bool needBiasSample = salmonOpts.biasCorrect; std::uniform_int_distribution<> dis(0, jointAlignments.size()); @@ -1807,10 +1813,6 @@ void processReads( } } - if (writeQuasimappings) { - writeAlignmentsToStreamSingle(rp, formatter, jointAlignments, sstream, false, true); - } - if (writeUnmapped and mapType != salmon::utils::MappingType::SINGLE_MAPPED) { // If we have no mappings --- then there's nothing to do // unless we're outputting names for un-mapped reads From 6b47eb18b9505af79f608c44123c8ce89aeaba3f Mon Sep 17 00:00:00 2001 From: Rob Patro Date: Thu, 11 Jun 2020 23:26:12 -0400 Subject: [PATCH 26/52] update parallel hashmap --- include/parallel_hashmap/phmap.h | 456 +++++------------- include/parallel_hashmap/phmap_base.h | 552 ++++++++++++++++------ include/parallel_hashmap/phmap_bits.h | 208 ++++---- include/parallel_hashmap/phmap_config.h | 22 +- include/parallel_hashmap/phmap_fwd_decl.h | 42 +- include/parallel_hashmap/phmap_utils.h | 78 ++- 6 files changed, 761 insertions(+), 597 deletions(-) diff --git a/include/parallel_hashmap/phmap.h b/include/parallel_hashmap/phmap.h index ffe68b22c..7c1cd95a0 100644 --- a/include/parallel_hashmap/phmap.h +++ b/include/parallel_hashmap/phmap.h @@ -34,6 +34,23 @@ // limitations under the License. // --------------------------------------------------------------------------- +#ifdef _MSC_VER + #pragma warning(push) + + #pragma warning(disable : 4127) // conditional expression is constant + #pragma warning(disable : 4324) // structure was padded due to alignment specifier + #pragma warning(disable : 4514) // unreferenced inline function has been removed + #pragma warning(disable : 4623) // default constructor was implicitly defined as deleted + #pragma warning(disable : 4625) // copy constructor was implicitly defined as deleted + #pragma warning(disable : 4626) // assignment operator was implicitly defined as deleted + #pragma warning(disable : 4710) // function not inlined + #pragma warning(disable : 4711) // selected for automatic inline expansion + #pragma warning(disable : 4820) // '6' bytes padding added after data member + #pragma warning(disable : 4868) // compiler may not enforce left-to-right evaluation order in braced initializer list + #pragma warning(disable : 5027) // move assignment operator was implicitly defined as deleted + #pragma warning(disable : 5045) // Compiler will insert Spectre mitigation for memory load if /Qspectre switch specified +#endif + #include #include #include @@ -45,10 +62,11 @@ #include #include #include +#include +#include "phmap_fwd_decl.h" #include "phmap_utils.h" #include "phmap_base.h" -#include "phmap_fwd_decl.h" #if PHMAP_HAVE_STD_STRING_VIEW #include @@ -77,7 +95,7 @@ class probe_seq offset_ &= mask_; } // 0-based probe index. The i-th probe in the probe sequence. - size_t index() const { return index_; } + size_t getindex() const { return index_; } private: size_t mask_; @@ -267,7 +285,7 @@ inline size_t H1(size_t hash, const ctrl_t* ) { #endif -inline ctrl_t H2(size_t hash) { return hash & 0x7F; } +inline ctrl_t H2(size_t hash) { return (ctrl_t)(hash & 0x7F); } inline bool IsEmpty(ctrl_t c) { return c == kEmpty; } inline bool IsFull(ctrl_t c) { return c >= 0; } @@ -276,6 +294,11 @@ inline bool IsEmptyOrDeleted(ctrl_t c) { return c < kSentinel; } #if PHMAP_HAVE_SSE2 +#ifdef _MSC_VER + #pragma warning(push) + #pragma warning(disable : 4365) // conversion from 'int' to 'T', signed/unsigned mismatch +#endif + // -------------------------------------------------------------------------- // https://github.com/abseil/abseil-cpp/issues/209 // https://gcc.gnu.org/bugzilla/show_bug.cgi?id=87853 @@ -289,7 +312,7 @@ inline __m128i _mm_cmpgt_epi8_fixed(__m128i a, __m128i b) { #pragma GCC diagnostic ignored "-Woverflow" if (std::is_unsigned::value) { - const __m128i mask = _mm_set1_epi8(0x80); + const __m128i mask = _mm_set1_epi8(static_cast(0x80)); const __m128i diff = _mm_subs_epi8(b, a); return _mm_cmpeq_epi8(_mm_and_si128(diff, mask), mask); } @@ -312,7 +335,7 @@ struct GroupSse2Impl // Returns a bitmask representing the positions of slots that match hash. // ---------------------------------------------------------------------- BitMask Match(h2_t hash) const { - auto match = _mm_set1_epi8(hash); + auto match = _mm_set1_epi8((char)hash); return BitMask( _mm_movemask_epi8(_mm_cmpeq_epi8(match, ctrl))); } @@ -361,6 +384,11 @@ struct GroupSse2Impl __m128i ctrl; }; + +#ifdef _MSC_VER + #pragma warning(pop) +#endif + #endif // PHMAP_HAVE_SSE2 // -------------------------------------------------------------------------- @@ -404,7 +432,7 @@ struct GroupPortableImpl uint32_t CountLeadingEmptyOrDeleted() const { constexpr uint64_t gaps = 0x00FEFEFEFEFEFEFEULL; - return (TrailingZeros(((~ctrl & (ctrl >> 7)) | gaps) + 1) + 7) >> 3; + return (uint32_t)((TrailingZeros(((~ctrl & (ctrl >> 7)) | gaps) + 1) + 7) >> 3); } void ConvertSpecialToEmptyAndFullToDeleted(ctrl_t* dst) const { @@ -468,9 +496,12 @@ inline size_t CapacityToGrowth(size_t capacity) { assert(IsValidCapacity(capacity)); // `capacity*7/8` - if (Group::kWidth == 8 && capacity == 7) { - // x-x/8 does not work when x==7. - return 6; + PHMAP_IF_CONSTEXPR (Group::kWidth == 8) { + if (capacity == 7) + { + // x-x/8 does not work when x==7. + return 6; + } } return capacity - capacity / 8; } @@ -482,9 +513,12 @@ inline size_t CapacityToGrowth(size_t capacity) inline size_t GrowthToLowerboundCapacity(size_t growth) { // `growth*8/7` - if (Group::kWidth == 8 && growth == 7) { - // x+(x-1)/7 does not work when x==7. - return 8; + PHMAP_IF_CONSTEXPR (Group::kWidth == 8) { + if (growth == 7) + { + // x+(x-1)/7 does not work when x==7. + return 8; + } } return growth + static_cast((static_cast(growth) - 1) / 7); } @@ -642,78 +676,6 @@ DecomposePairImpl(F&& f, std::pair, V> p) { } // namespace memory_internal -// Helper functions for asan and msan. -// ---------------------------------------------------------------------------- -inline void SanitizerPoisonMemoryRegion(const void* m, size_t s) { -#ifdef ADDRESS_SANITIZER - ASAN_POISON_MEMORY_REGION(m, s); -#endif -#ifdef MEMORY_SANITIZER - __msan_poison(m, s); -#endif - (void)m; - (void)s; -} - -inline void SanitizerUnpoisonMemoryRegion(const void* m, size_t s) { -#ifdef ADDRESS_SANITIZER - ASAN_UNPOISON_MEMORY_REGION(m, s); -#endif -#ifdef MEMORY_SANITIZER - __msan_unpoison(m, s); -#endif - (void)m; - (void)s; -} - -template -inline void SanitizerPoisonObject(const T* object) { - SanitizerPoisonMemoryRegion(object, sizeof(T)); -} - -template -inline void SanitizerUnpoisonObject(const T* object) { - SanitizerUnpoisonMemoryRegion(object, sizeof(T)); -} - -// ---------------------------------------------------------------------------- -// Allocates at least n bytes aligned to the specified alignment. -// Alignment must be a power of 2. It must be positive. -// -// Note that many allocators don't honor alignment requirements above certain -// threshold (usually either alignof(std::max_align_t) or alignof(void*)). -// Allocate() doesn't apply alignment corrections. If the underlying allocator -// returns insufficiently alignment pointer, that's what you are going to get. -// ---------------------------------------------------------------------------- -template -void* Allocate(Alloc* alloc, size_t n) { - static_assert(Alignment > 0, ""); - assert(n && "n must be positive"); - struct alignas(Alignment) M {}; - using A = typename phmap::allocator_traits::template rebind_alloc; - using AT = typename phmap::allocator_traits::template rebind_traits; - A mem_alloc(*alloc); - void* p = AT::allocate(mem_alloc, (n + sizeof(M) - 1) / sizeof(M)); - assert(reinterpret_cast(p) % Alignment == 0 && - "allocator does not respect alignment"); - return p; -} - -// ---------------------------------------------------------------------------- -// The pointer must have been previously obtained by calling -// Allocate(alloc, n). -// ---------------------------------------------------------------------------- -template -void Deallocate(Alloc* alloc, void* p, size_t n) { - static_assert(Alignment > 0, ""); - assert(n && "n must be positive"); - struct alignas(Alignment) M {}; - using A = typename phmap::allocator_traits::template rebind_alloc; - using AT = typename phmap::allocator_traits::template rebind_traits; - A mem_alloc(*alloc); - AT::deallocate(mem_alloc, static_cast(p), - (n + sizeof(M) - 1) / sizeof(M)); -} // ---------------------------------------------------------------------------- // R A W _ H A S H _ S E T @@ -1202,6 +1164,8 @@ class raw_hash_set // compared to destruction of the elements of the container. So we pick the // largest bucket_count() threshold for which iteration is still fast and // past that we simply deallocate the array. + if (empty()) + return; if (capacity_ > 127) { destroy_slots(); } else if (capacity_) { @@ -1532,6 +1496,14 @@ class raw_hash_set } } +#ifndef PHMAP_NON_DETERMINISTIC + template + bool dump(OutputArchive&) const; + + template + bool load(InputArchive&); +#endif + void rehash(size_t n) { if (n == 0 && capacity_ == 0) return; if (n == 0 && size_ == 0) { @@ -1541,7 +1513,7 @@ class raw_hash_set } // bitor is a faster way of doing `max` here. We will round up to the next // power-of-2-minus-1, so bitor is good enough. - auto m = NormalizeCapacity(n | GrowthToLowerboundCapacity(size())); + auto m = NormalizeCapacity((std::max)(n, size())); // n == 0 unconditionally rehashes as per the standard. if (n == 0 || m > capacity_) { resize(m); @@ -1561,7 +1533,7 @@ class raw_hash_set // s.count("abc"); template size_t count(const key_arg& key) const { - return find(key) == end() ? 0 : 1; + return find(key) == end() ? size_t(0) : size_t(1); } // Issues CPU prefetch instructions for the memory needed to find or insert @@ -1595,11 +1567,11 @@ class raw_hash_set auto seq = probe(hash); while (true) { Group g{ctrl_ + seq.offset()}; - for (int i : g.Match(H2(hash))) { + for (int i : g.Match((h2_t)H2(hash))) { if (PHMAP_PREDICT_TRUE(PolicyTraits::apply( EqualElement{key, eq_ref()}, - PolicyTraits::element(slots_ + seq.offset(i))))) - return iterator_at(seq.offset(i)); + PolicyTraits::element(slots_ + seq.offset((size_t)i))))) + return iterator_at(seq.offset((size_t)i)); } if (PHMAP_PREDICT_TRUE(g.MatchEmpty())) return end(); @@ -1770,7 +1742,7 @@ class raw_hash_set void erase_meta_only(const_iterator it) { assert(IsFull(*it.inner_.ctrl_) && "erasing a dangling iterator"); --size_; - const size_t index = it.inner_.ctrl_ - ctrl_; + const size_t index = (size_t)(it.inner_.ctrl_ - ctrl_); const size_t index_before = (index - Group::kWidth) & capacity_; const auto empty_after = Group(it.inner_.ctrl_).MatchEmpty(); const auto empty_before = Group(ctrl_ + index_before).MatchEmpty(); @@ -1929,14 +1901,14 @@ class raw_hash_set auto seq = probe(hash); while (true) { Group g{ctrl_ + seq.offset()}; - for (int i : g.Match(H2(hash))) { - if (PHMAP_PREDICT_TRUE(PolicyTraits::element(slots_ + seq.offset(i)) == + for (int i : g.Match((h2_t)H2(hash))) { + if (PHMAP_PREDICT_TRUE(PolicyTraits::element(slots_ + seq.offset((size_t)i)) == elem)) return true; } if (PHMAP_PREDICT_TRUE(g.MatchEmpty())) return false; seq.next(); - assert(seq.index() < capacity_ && "full table!"); + assert(seq.getindex() < capacity_ && "full table!"); } return false; } @@ -1966,9 +1938,9 @@ class raw_hash_set Group g{ctrl_ + seq.offset()}; auto mask = g.MatchEmptyOrDeleted(); if (mask) { - return {seq.offset(mask.LowestBitSet()), seq.index()}; + return {seq.offset((size_t)mask.LowestBitSet()), seq.getindex()}; } - assert(seq.index() < capacity_ && "full table!"); + assert(seq.getindex() < capacity_ && "full table!"); seq.next(); } } @@ -1991,11 +1963,11 @@ class raw_hash_set auto seq = probe(hash); while (true) { Group g{ctrl_ + seq.offset()}; - for (int i : g.Match(H2(hash))) { + for (int i : g.Match((h2_t)H2(hash))) { if (PHMAP_PREDICT_TRUE(PolicyTraits::apply( EqualElement{key, eq_ref()}, - PolicyTraits::element(slots_ + seq.offset(i))))) - return {seq.offset(i), false}; + PolicyTraits::element(slots_ + seq.offset((size_t)i))))) + return {seq.offset((size_t)i), false}; } if (PHMAP_PREDICT_TRUE(g.MatchEmpty())) break; seq.next(); @@ -2244,14 +2216,16 @@ class raw_hash_map : public raw_hash_set template MappedReference

at(const key_arg& key) { auto it = this->find(key); - if (it == this->end()) std::abort(); + if (it == this->end()) + phmap::base_internal::ThrowStdOutOfRange("phmap at(): lookup non-existent key"); return Policy::value(&*it); } template MappedConstReference

at(const key_arg& key) const { auto it = this->find(key); - if (it == this->end()) std::abort(); + if (it == this->end()) + phmap::base_internal::ThrowStdOutOfRange("phmap at(): lookup non-existent key"); return Policy::value(&*it); } @@ -2351,7 +2325,7 @@ class parallel_hash_set using Lockable = phmap::LockableImpl; // -------------------------------------------------------------------- - struct alignas(64) Inner : public Lockable + struct Inner : public Lockable { bool operator==(const Inner& o) const { @@ -2892,7 +2866,7 @@ class parallel_hash_set Inner& inner = sets_[subidx(hashval)]; auto& set = inner.set_; typename Lockable::UniqueLock m(inner); - return make_iterator(&inner, set.lazy_emplace(key, hashval, std::forward(f))); + return make_iterator(&inner, set.lazy_emplace_with_hash(key, hashval, std::forward(f))); } // Extension API: support for heterogeneous keys. @@ -3008,7 +2982,12 @@ class parallel_hash_set } } - void reserve(size_t n) { rehash(GrowthToLowerboundCapacity(n)); } + void reserve(size_t n) + { + size_t target = GrowthToLowerboundCapacity(n); + size_t normalized = 16 * NormalizeCapacity(n / num_tables); + rehash(normalized > target ? normalized : target); + } // Extension API: support for heterogeneous keys. // @@ -3053,11 +3032,8 @@ class parallel_hash_set // -------------------------------------------------------------------- template iterator find(const key_arg& key, size_t hashval) { - Inner& inner = sets_[subidx(hashval)]; - auto& set = inner.set_; - typename Lockable::SharedLock m(inner); - auto it = set.find(key, hashval); - return make_iterator(&inner, it); + typename Lockable::SharedLock m; + return find(key, hashval, m); } template @@ -3132,6 +3108,14 @@ class parallel_hash_set a.swap(b); } +#ifndef PHMAP_NON_DETERMINISTIC + template + bool dump(OutputArchive& ar) const; + + template + bool load(InputArchive& ar); +#endif + private: template friend struct phmap::container_internal::hashtable_debug_internal::HashtableDebugAccess; @@ -3207,6 +3191,15 @@ class parallel_hash_set } protected: + template + iterator find(const key_arg& key, size_t hashval, typename Lockable::SharedLock &mutexlock) { + Inner& inner = sets_[subidx(hashval)]; + auto& set = inner.set_; + mutexlock = std::move(typename Lockable::SharedLock(inner)); + auto it = set.find(key, hashval); + return make_iterator(&inner, it); + } + template std::tuple find_or_prepare_insert(const K& key, typename Lockable::UniqueLock &mutexlock) { @@ -3232,7 +3225,7 @@ class parallel_hash_set } template - size_t hash(const K& key) { + size_t hash(const K& key) const { return HashElement{hash_ref()}(key); } @@ -3387,17 +3380,32 @@ class parallel_hash_map : public parallel_hash_set MappedReference

at(const key_arg& key) { auto it = this->find(key); - if (it == this->end()) std::abort(); + if (it == this->end()) + phmap::base_internal::ThrowStdOutOfRange("phmap at(): lookup non-existent key"); return Policy::value(&*it); } template MappedConstReference

at(const key_arg& key) const { auto it = this->find(key); - if (it == this->end()) std::abort(); + if (it == this->end()) + phmap::base_internal::ThrowStdOutOfRange("phmap at(): lookup non-existent key"); return Policy::value(&*it); } + template + bool if_contains(const key_arg& key, F&& f) const { +#if __cplusplus >= 201703L + static_assert(std::is_invocable::value); +#endif + typename Lockable::SharedLock m; + auto it = const_cast(this)->find(key, this->hash(key), m); + if (it == this->end()) + return false; + std::forward(f)(Policy::value(&*it)); + return true; + } + template MappedReference

operator[](key_arg&& key) { return Policy::value(&*try_emplace(std::forward(key)).first); @@ -3521,214 +3529,6 @@ DecomposeValue(F&& f, Arg&& arg) { } -namespace memory_internal { - -// ---------------------------------------------------------------------------- -// If Pair is a standard-layout type, OffsetOf::kFirst and -// OffsetOf::kSecond are equivalent to offsetof(Pair, first) and -// offsetof(Pair, second) respectively. Otherwise they are -1. -// -// The purpose of OffsetOf is to avoid calling offsetof() on non-standard-layout -// type, which is non-portable. -// ---------------------------------------------------------------------------- -template -struct OffsetOf { - static constexpr size_t kFirst = -1; - static constexpr size_t kSecond = -1; -}; - -template -struct OffsetOf::type> -{ - static constexpr size_t kFirst = offsetof(Pair, first); - static constexpr size_t kSecond = offsetof(Pair, second); -}; - -// ---------------------------------------------------------------------------- -template -struct IsLayoutCompatible -{ -private: - struct Pair { - K first; - V second; - }; - - // Is P layout-compatible with Pair? - template - static constexpr bool LayoutCompatible() { - return std::is_standard_layout

() && sizeof(P) == sizeof(Pair) && - alignof(P) == alignof(Pair) && - memory_internal::OffsetOf

::kFirst == - memory_internal::OffsetOf::kFirst && - memory_internal::OffsetOf

::kSecond == - memory_internal::OffsetOf::kSecond; - } - -public: - // Whether pair and pair are layout-compatible. If they are, - // then it is safe to store them in a union and read from either. - static constexpr bool value = std::is_standard_layout() && - std::is_standard_layout() && - memory_internal::OffsetOf::kFirst == 0 && - LayoutCompatible>() && - LayoutCompatible>(); -}; - -} // namespace memory_internal - -// ---------------------------------------------------------------------------- -// The internal storage type for key-value containers like flat_hash_map. -// -// It is convenient for the value_type of a flat_hash_map to be -// pair; the "const K" prevents accidental modification of the key -// when dealing with the reference returned from find() and similar methods. -// However, this creates other problems; we want to be able to emplace(K, V) -// efficiently with move operations, and similarly be able to move a -// pair in insert(). -// -// The solution is this union, which aliases the const and non-const versions -// of the pair. This also allows flat_hash_map to work, even though -// that has the same efficiency issues with move in emplace() and insert() - -// but people do it anyway. -// -// If kMutableKeys is false, only the value member can be accessed. -// -// If kMutableKeys is true, key can be accessed through all slots while value -// and mutable_value must be accessed only via INITIALIZED slots. Slots are -// created and destroyed via mutable_value so that the key can be moved later. -// -// Accessing one of the union fields while the other is active is safe as -// long as they are layout-compatible, which is guaranteed by the definition of -// kMutableKeys. For C++11, the relevant section of the standard is -// https://timsong-cpp.github.io/cppwp/n3337/class.mem#19 (9.2.19) -// ---------------------------------------------------------------------------- -template -union map_slot_type -{ - map_slot_type() {} - ~map_slot_type() = delete; - using value_type = std::pair; - using mutable_value_type = std::pair; - - value_type value; - mutable_value_type mutable_value; - K key; -}; - -// ---------------------------------------------------------------------------- -// ---------------------------------------------------------------------------- -template -struct map_slot_policy -{ - using slot_type = map_slot_type; - using value_type = std::pair; - using mutable_value_type = std::pair; - -private: - static void emplace(slot_type* slot) { - // The construction of union doesn't do anything at runtime but it allows us - // to access its members without violating aliasing rules. - new (slot) slot_type; - } - // If pair and pair are layout-compatible, we can accept one - // or the other via slot_type. We are also free to access the key via - // slot_type::key in this case. - using kMutableKeys = memory_internal::IsLayoutCompatible; - -public: - static value_type& element(slot_type* slot) { return slot->value; } - static const value_type& element(const slot_type* slot) { - return slot->value; - } - - static const K& key(const slot_type* slot) { - return kMutableKeys::value ? slot->key : slot->value.first; - } - - template - static void construct(Allocator* alloc, slot_type* slot, Args&&... args) { - emplace(slot); - if (kMutableKeys::value) { - phmap::allocator_traits::construct(*alloc, &slot->mutable_value, - std::forward(args)...); - } else { - phmap::allocator_traits::construct(*alloc, &slot->value, - std::forward(args)...); - } - } - - // Construct this slot by moving from another slot. - template - static void construct(Allocator* alloc, slot_type* slot, slot_type* other) { - emplace(slot); - if (kMutableKeys::value) { - phmap::allocator_traits::construct( - *alloc, &slot->mutable_value, std::move(other->mutable_value)); - } else { - phmap::allocator_traits::construct(*alloc, &slot->value, - std::move(other->value)); - } - } - - template - static void destroy(Allocator* alloc, slot_type* slot) { - if (kMutableKeys::value) { - phmap::allocator_traits::destroy(*alloc, &slot->mutable_value); - } else { - phmap::allocator_traits::destroy(*alloc, &slot->value); - } - } - - template - static void transfer(Allocator* alloc, slot_type* new_slot, - slot_type* old_slot) { - emplace(new_slot); - if (kMutableKeys::value) { - phmap::allocator_traits::construct( - *alloc, &new_slot->mutable_value, std::move(old_slot->mutable_value)); - } else { - phmap::allocator_traits::construct(*alloc, &new_slot->value, - std::move(old_slot->value)); - } - destroy(alloc, old_slot); - } - - template - static void swap(Allocator* alloc, slot_type* a, slot_type* b) { - if (kMutableKeys::value) { - using std::swap; - swap(a->mutable_value, b->mutable_value); - } else { - value_type tmp = std::move(a->value); - phmap::allocator_traits::destroy(*alloc, &a->value); - phmap::allocator_traits::construct(*alloc, &a->value, - std::move(b->value)); - phmap::allocator_traits::destroy(*alloc, &b->value); - phmap::allocator_traits::construct(*alloc, &b->value, - std::move(tmp)); - } - } - - template - static void move(Allocator* alloc, slot_type* src, slot_type* dest) { - if (kMutableKeys::value) { - dest->mutable_value = std::move(src->mutable_value); - } else { - phmap::allocator_traits::destroy(*alloc, &dest->value); - phmap::allocator_traits::construct(*alloc, &dest->value, - std::move(src->value)); - } - } - - template - static void move(Allocator* alloc, slot_type* first, slot_type* last, - slot_type* result) { - for (slot_type *src = first, *dest = result; src != last; ++src, ++dest) - move(alloc, src, dest); - } -}; - // -------------------------------------------------------------------------- // Policy: a policy defines how to perform different operations on // the slots of the hashtable (see hash_policy_traits.h for the full interface @@ -3977,7 +3777,7 @@ struct StringHashT size_t operator()(std::basic_string_view v) const { std::string_view bv{reinterpret_cast(v.data()), v.size() * sizeof(CharT)}; - return phmap::Hash{}(bv); + return std::hash()(bv); } }; @@ -4019,6 +3819,7 @@ struct HashEq : StringHashEqT {}; #endif // Supports heterogeneous lookup for pointers and smart pointers. +// ------------------------------------------------------------- template struct HashEq { @@ -4040,6 +3841,7 @@ struct HashEq private: static const T* ToPtr(const T* ptr) { return ptr; } + template static const T* ToPtr(const std::unique_ptr& ptr) { return ptr.get(); @@ -4078,7 +3880,7 @@ struct HashtableDebugAccess> if (Traits::apply( typename Set::template EqualElement{ key, set.eq_ref()}, - Traits::element(set.slots_ + seq.offset(i)))) + Traits::element(set.slots_ + seq.offset((size_t)i)))) return num_probes; ++num_probes; } @@ -4131,7 +3933,6 @@ struct HashtableDebugAccess> // Its interface is similar to that of `std::unordered_set` with the // following notable differences: // -// * Requires keys that are CopyConstructible // * Supports heterogeneous lookup, through `find()`, `operator[]()` and // `insert()`, provided that the set is provided a compatible heterogeneous // hashing function and equality operator. @@ -4139,7 +3940,7 @@ struct HashtableDebugAccess> // `rehash()`. // * Contains a `capacity()` member function indicating the number of element // slots (open, deleted, and empty) within the hash set. -// * Returns `void` from the `erase(iterator)` overload. +// * Returns `void` from the `_erase(iterator)` overload. // ----------------------------------------------------------------------------- template // default values in phmap_fwd_decl.h class flat_hash_set @@ -4194,8 +3995,6 @@ class flat_hash_set // cases. Its interface is similar to that of `std::unordered_map` with // the following notable differences: // -// * Requires keys that are CopyConstructible -// * Requires values that are MoveConstructible // * Supports heterogeneous lookup, through `find()`, `operator[]()` and // `insert()`, provided that the map is provided a compatible heterogeneous // hashing function and equality operator. @@ -4203,7 +4002,7 @@ class flat_hash_set // `rehash()`. // * Contains a `capacity()` member function indicating the number of element // slots (open, deleted, and empty) within the hash map. -// * Returns `void` from the `erase(iterator)` overload. +// * Returns `void` from the `_erase(iterator)` overload. // ----------------------------------------------------------------------------- template // default values in phmap_fwd_decl.h class flat_hash_map : public phmap::container_internal::raw_hash_map< @@ -4600,4 +4399,9 @@ class parallel_node_hash_map } // namespace phmap +#ifdef _MSC_VER + #pragma warning(pop) +#endif + + #endif // phmap_h_guard_ diff --git a/include/parallel_hashmap/phmap_base.h b/include/parallel_hashmap/phmap_base.h index bbc6712d9..27976826c 100644 --- a/include/parallel_hashmap/phmap_base.h +++ b/include/parallel_hashmap/phmap_base.h @@ -54,6 +54,17 @@ #include // after "phmap_config.h" #endif +#ifdef _MSC_VER + #pragma warning(push) + #pragma warning(disable : 4514) // unreferenced inline function has been removed + #pragma warning(disable : 4582) // constructor is not implicitly called + #pragma warning(disable : 4625) // copy constructor was implicitly defined as deleted + #pragma warning(disable : 4626) // assignment operator was implicitly defined as deleted + #pragma warning(disable : 4710) // function not inlined + #pragma warning(disable : 4711) // selected for automatic inline expansion + #pragma warning(disable : 4820) // '6' bytes padding added after data member +#endif // _MSC_VER + namespace phmap { template using Allocator = typename std::allocator; @@ -69,6 +80,15 @@ struct EqualTo } }; +template +struct Less +{ + inline bool operator()(const T& a, const T& b) const + { + return std::less()(a, b); + } +}; + namespace type_traits_internal { template @@ -213,177 +233,31 @@ struct disjunction : T {}; template <> struct disjunction<> : std::false_type {}; -// --------------------------------------------------------------------------- -// negation -// -// Performs a compile-time logical NOT operation on the passed type (which -// must have `::value` members convertible to `bool`. -// -// This metafunction is designed to be a drop-in replacement for the C++17 -// `std::negation` metafunction. -// --------------------------------------------------------------------------- template struct negation : std::integral_constant {}; -// --------------------------------------------------------------------------- -// is_trivially_destructible() -// -// Determines whether the passed type `T` is trivially destructable. -// -// This metafunction is designed to be a drop-in replacement for the C++11 -// `std::is_trivially_destructible()` metafunction for platforms that have -// incomplete C++11 support (such as libstdc++ 4.x). On any platforms that do -// fully support C++11, we check whether this yields the same result as the std -// implementation. -// -// NOTE: the extensions (__has_trivial_xxx) are implemented in gcc (version >= -// 4.3) and clang. Since we are supporting libstdc++ > 4.7, they should always -// be present. These extensions are documented at -// https://gcc.gnu.org/onlinedocs/gcc/Type-Traits.html#Type-Traits. -// --------------------------------------------------------------------------- template struct is_trivially_destructible : std::integral_constant::value> -{ -#ifdef PHMAP_HAVE_STD_IS_TRIVIALLY_DESTRUCTIBLE -private: - static constexpr bool compliant = std::is_trivially_destructible::value == - is_trivially_destructible::value; - static_assert(compliant || std::is_trivially_destructible::value, - "Not compliant with std::is_trivially_destructible; " - "Standard: false, Implementation: true"); - static_assert(compliant || !std::is_trivially_destructible::value, - "Not compliant with std::is_trivially_destructible; " - "Standard: true, Implementation: false"); -#endif -}; + std::is_destructible::value> {}; -// --------------------------------------------------------------------------- -// is_trivially_default_constructible() -// -// Determines whether the passed type `T` is trivially default constructible. -// -// This metafunction is designed to be a drop-in replacement for the C++11 -// `std::is_trivially_default_constructible()` metafunction for platforms that -// have incomplete C++11 support (such as libstdc++ 4.x). On any platforms that -// do fully support C++11, we check whether this yields the same result as the -// std implementation. -// -// NOTE: according to the C++ standard, Section: 20.15.4.3 [meta.unary.prop] -// "The predicate condition for a template specialization is_constructible shall be satisfied if and only if the following variable -// definition would be well-formed for some invented variable t: -// -// T t(declval()...); -// -// is_trivially_constructible additionally requires that the -// variable definition does not call any operation that is not trivial. -// For the purposes of this check, the call to std::declval is considered -// trivial." -// -// Notes from https://en.cppreference.com/w/cpp/types/is_constructible: -// In many implementations, is_nothrow_constructible also checks if the -// destructor throws because it is effectively noexcept(T(arg)). Same -// applies to is_trivially_constructible, which, in these implementations, also -// requires that the destructor is trivial. -// GCC bug 51452: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=51452 -// LWG issue 2116: http://cplusplus.github.io/LWG/lwg-active.html#2116. -// -// "T obj();" need to be well-formed and not call any nontrivial operation. -// Nontrivially destructible types will cause the expression to be nontrivial. -// --------------------------------------------------------------------------- template struct is_trivially_default_constructible : std::integral_constant::value && - is_trivially_destructible::value> -{ -#ifdef PHMAP_HAVE_STD_IS_TRIVIALLY_CONSTRUCTIBLE -private: - static constexpr bool compliant = - std::is_trivially_default_constructible::value == - is_trivially_default_constructible::value; - static_assert(compliant || std::is_trivially_default_constructible::value, - "Not compliant with std::is_trivially_default_constructible; " - "Standard: false, Implementation: true"); - static_assert(compliant || !std::is_trivially_default_constructible::value, - "Not compliant with std::is_trivially_default_constructible; " - "Standard: true, Implementation: false"); -#endif -}; + is_trivially_destructible::value> {}; -// --------------------------------------------------------------------------- -// is_trivially_copy_constructible() -// -// Determines whether the passed type `T` is trivially copy constructible. -// -// This metafunction is designed to be a drop-in replacement for the C++11 -// `std::is_trivially_copy_constructible()` metafunction for platforms that have -// incomplete C++11 support (such as libstdc++ 4.x). On any platforms that do -// fully support C++11, we check whether this yields the same result as the std -// implementation. -// -// NOTE: `T obj(declval());` needs to be well-formed and not call any -// nontrivial operation. Nontrivially destructible types will cause the -// expression to be nontrivial. -// --------------------------------------------------------------------------- template struct is_trivially_copy_constructible : std::integral_constant::value && - is_trivially_destructible::value> -{ -#ifdef PHMAP_HAVE_STD_IS_TRIVIALLY_CONSTRUCTIBLE -private: - static constexpr bool compliant = - std::is_trivially_copy_constructible::value == - is_trivially_copy_constructible::value; - static_assert(compliant || std::is_trivially_copy_constructible::value, - "Not compliant with std::is_trivially_copy_constructible; " - "Standard: false, Implementation: true"); - static_assert(compliant || !std::is_trivially_copy_constructible::value, - "Not compliant with std::is_trivially_copy_constructible; " - "Standard: true, Implementation: false"); -#endif -}; + is_trivially_destructible::value> {}; -// --------------------------------------------------------------------------- -// is_trivially_copy_assignable() -// -// Determines whether the passed type `T` is trivially copy assignable. -// -// This metafunction is designed to be a drop-in replacement for the C++11 -// `std::is_trivially_copy_assignable()` metafunction for platforms that have -// incomplete C++11 support (such as libstdc++ 4.x). On any platforms that do -// fully support C++11, we check whether this yields the same result as the std -// implementation. -// -// NOTE: `is_assignable::value` is `true` if the expression -// `declval() = declval()` is well-formed when treated as an unevaluated -// operand. `is_trivially_assignable` requires the assignment to call no -// operation that is not trivial. `is_trivially_copy_assignable` is simply -// `is_trivially_assignable`. -// --------------------------------------------------------------------------- template struct is_trivially_copy_assignable : std::integral_constant< bool, __has_trivial_assign(typename std::remove_reference::type) && - phmap::is_copy_assignable::value> -{ -#ifdef PHMAP_HAVE_STD_IS_TRIVIALLY_ASSIGNABLE -private: - static constexpr bool compliant = - std::is_trivially_copy_assignable::value == - is_trivially_copy_assignable::value; - static_assert(compliant || std::is_trivially_copy_assignable::value, - "Not compliant with std::is_trivially_copy_assignable; " - "Standard: false, Implementation: true"); - static_assert(compliant || !std::is_trivially_copy_assignable::value, - "Not compliant with std::is_trivially_copy_assignable; " - "Standard: true, Implementation: false"); -#endif -}; + phmap::is_copy_assignable::value> {}; // ----------------------------------------------------------------------------- // C++14 "_t" trait aliases @@ -447,14 +321,19 @@ using enable_if_t = typename std::enable_if::type; template using conditional_t = typename std::conditional::type; + template using common_type_t = typename std::common_type::type; template using underlying_type_t = typename std::underlying_type::type; -template -using result_of_t = typename std::result_of::type; +template< class F, class... ArgTypes> +#if __cplusplus >= 201703L +using invoke_result_t = typename std::invoke_result_t; +#else +using invoke_result_t = typename std::result_of::type; +#endif namespace type_traits_internal { @@ -1226,6 +1105,11 @@ auto apply(Functor&& functor, Tuple&& t) typename std::remove_reference::type>::value>{}); } +#ifdef _MSC_VER + #pragma warning(push) + #pragma warning(disable : 4365) // '=': conversion from 'T' to 'T', signed/unsigned mismatch +#endif // _MSC_VER + // exchange // // Replaces the value of `obj` with `new_value` and returns the old value of @@ -1247,6 +1131,11 @@ T exchange(T& obj, U&& new_value) return old_value; } +#ifdef _MSC_VER + #pragma warning(pop) +#endif // _MSC_VER + + } // namespace phmap // ----------------------------------------------------------------------------- @@ -2845,6 +2734,12 @@ struct KeyArg using type = key_type; }; +#ifdef _MSC_VER + #pragma warning(push) + // warning C4820: '6' bytes padding added after data member + #pragma warning(disable : 4820) +#endif + // The node_handle concept from C++17. // We specialize node_handle for sets and maps. node_handle_base holds the // common API of both. @@ -2883,10 +2778,25 @@ class node_handle_base protected: friend struct CommonAccess; + struct transfer_tag_t {}; + node_handle_base(transfer_tag_t, const allocator_type& a, slot_type* s) + : alloc_(a) { + PolicyTraits::transfer(alloc(), slot(), s); + } + + struct move_tag_t {}; + node_handle_base(move_tag_t, const allocator_type& a, slot_type* s) + : alloc_(a) { + PolicyTraits::construct(alloc(), slot(), s); + } + node_handle_base(const allocator_type& a, slot_type* s) : alloc_(a) { PolicyTraits::transfer(alloc(), slot(), s); } + //node_handle_base(const node_handle_base&) = delete; + //node_handle_base& operator=(const node_handle_base&) = delete; + void destroy() { if (!empty()) { PolicyTraits::destroy(alloc(), slot()); @@ -2908,10 +2818,13 @@ class node_handle_base private: phmap::optional alloc_; - mutable phmap::aligned_storage_t - slot_space_; + mutable phmap::aligned_storage_t slot_space_; }; +#ifdef _MSC_VER + #pragma warning(pop) +#endif + // For sets. // --------- template private: friend struct CommonAccess; - node_handle(const Alloc& a, typename Base::slot_type* s) : Base(a, s) {} + using Base::Base; }; // For maps. @@ -2961,7 +2874,7 @@ class node_handle + static void Destroy(Node* node) { + node->destroy(); + } + template static void Reset(Node* node) { node->reset(); @@ -2981,6 +2899,16 @@ struct CommonAccess static T Make(Args&&... args) { return T(std::forward(args)...); } + + template + static T Transfer(Args&&... args) { + return T(typename T::transfer_tag_t{}, std::forward(args)...); + } + + template + static T Move(Args&&... args) { + return T(typename T::move_tag_t{}, std::forward(args)...); + } }; // Implement the insert_return_type<> concept of C++17. @@ -4404,6 +4332,98 @@ class PHMAP_INTERNAL_COMPRESSED_TUPLE_DECLSPEC CompressedTuple<> {}; } // namespace container_internal } // namespace phmap + +namespace phmap { +namespace container_internal { + +#ifdef _MSC_VER + #pragma warning(push) + // warning warning C4324: structure was padded due to alignment specifier + #pragma warning(disable : 4324) +#endif + + +// ---------------------------------------------------------------------------- +// Allocates at least n bytes aligned to the specified alignment. +// Alignment must be a power of 2. It must be positive. +// +// Note that many allocators don't honor alignment requirements above certain +// threshold (usually either alignof(std::max_align_t) or alignof(void*)). +// Allocate() doesn't apply alignment corrections. If the underlying allocator +// returns insufficiently alignment pointer, that's what you are going to get. +// ---------------------------------------------------------------------------- +template +void* Allocate(Alloc* alloc, size_t n) { + static_assert(Alignment > 0, ""); + assert(n && "n must be positive"); + struct alignas(Alignment) M {}; + using A = typename phmap::allocator_traits::template rebind_alloc; + using AT = typename phmap::allocator_traits::template rebind_traits; + A mem_alloc(*alloc); + void* p = AT::allocate(mem_alloc, (n + sizeof(M) - 1) / sizeof(M)); + assert(reinterpret_cast(p) % Alignment == 0 && + "allocator does not respect alignment"); + return p; +} + +// ---------------------------------------------------------------------------- +// The pointer must have been previously obtained by calling +// Allocate(alloc, n). +// ---------------------------------------------------------------------------- +template +void Deallocate(Alloc* alloc, void* p, size_t n) { + static_assert(Alignment > 0, ""); + assert(n && "n must be positive"); + struct alignas(Alignment) M {}; + using A = typename phmap::allocator_traits::template rebind_alloc; + using AT = typename phmap::allocator_traits::template rebind_traits; + A mem_alloc(*alloc); + AT::deallocate(mem_alloc, static_cast(p), + (n + sizeof(M) - 1) / sizeof(M)); +} + +#ifdef _MSC_VER + #pragma warning(pop) +#endif + +// Helper functions for asan and msan. +// ---------------------------------------------------------------------------- +inline void SanitizerPoisonMemoryRegion(const void* m, size_t s) { +#ifdef ADDRESS_SANITIZER + ASAN_POISON_MEMORY_REGION(m, s); +#endif +#ifdef MEMORY_SANITIZER + __msan_poison(m, s); +#endif + (void)m; + (void)s; +} + +inline void SanitizerUnpoisonMemoryRegion(const void* m, size_t s) { +#ifdef ADDRESS_SANITIZER + ASAN_UNPOISON_MEMORY_REGION(m, s); +#endif +#ifdef MEMORY_SANITIZER + __msan_unpoison(m, s); +#endif + (void)m; + (void)s; +} + +template +inline void SanitizerPoisonObject(const T* object) { + SanitizerPoisonMemoryRegion(object, sizeof(T)); +} + +template +inline void SanitizerUnpoisonObject(const T* object) { + SanitizerUnpoisonMemoryRegion(object, sizeof(T)); +} + +} // namespace container_internal +} // namespace phmap + + // --------------------------------------------------------------------------- // thread_annotations.h // --------------------------------------------------------------------------- @@ -4513,6 +4533,221 @@ inline T& ts_unchecked_read(T& v) PHMAP_NO_THREAD_SAFETY_ANALYSIS { } } // namespace thread_safety_analysis + +namespace container_internal { + +namespace memory_internal { + +// ---------------------------------------------------------------------------- +// If Pair is a standard-layout type, OffsetOf::kFirst and +// OffsetOf::kSecond are equivalent to offsetof(Pair, first) and +// offsetof(Pair, second) respectively. Otherwise they are -1. +// +// The purpose of OffsetOf is to avoid calling offsetof() on non-standard-layout +// type, which is non-portable. +// ---------------------------------------------------------------------------- +template +struct OffsetOf { + static constexpr size_t kFirst = (size_t)-1; + static constexpr size_t kSecond = (size_t)-1; +}; + +template +struct OffsetOf::type> +{ + static constexpr size_t kFirst = offsetof(Pair, first); + static constexpr size_t kSecond = offsetof(Pair, second); +}; + +// ---------------------------------------------------------------------------- +template +struct IsLayoutCompatible +{ +private: + struct Pair { + K first; + V second; + }; + + // Is P layout-compatible with Pair? + template + static constexpr bool LayoutCompatible() { + return std::is_standard_layout

() && sizeof(P) == sizeof(Pair) && + alignof(P) == alignof(Pair) && + memory_internal::OffsetOf

::kFirst == + memory_internal::OffsetOf::kFirst && + memory_internal::OffsetOf

::kSecond == + memory_internal::OffsetOf::kSecond; + } + +public: + // Whether pair and pair are layout-compatible. If they are, + // then it is safe to store them in a union and read from either. + static constexpr bool value = std::is_standard_layout() && + std::is_standard_layout() && + memory_internal::OffsetOf::kFirst == 0 && + LayoutCompatible>() && + LayoutCompatible>(); +}; + +} // namespace memory_internal + +// ---------------------------------------------------------------------------- +// The internal storage type for key-value containers like flat_hash_map. +// +// It is convenient for the value_type of a flat_hash_map to be +// pair; the "const K" prevents accidental modification of the key +// when dealing with the reference returned from find() and similar methods. +// However, this creates other problems; we want to be able to emplace(K, V) +// efficiently with move operations, and similarly be able to move a +// pair in insert(). +// +// The solution is this union, which aliases the const and non-const versions +// of the pair. This also allows flat_hash_map to work, even though +// that has the same efficiency issues with move in emplace() and insert() - +// but people do it anyway. +// +// If kMutableKeys is false, only the value member can be accessed. +// +// If kMutableKeys is true, key can be accessed through all slots while value +// and mutable_value must be accessed only via INITIALIZED slots. Slots are +// created and destroyed via mutable_value so that the key can be moved later. +// +// Accessing one of the union fields while the other is active is safe as +// long as they are layout-compatible, which is guaranteed by the definition of +// kMutableKeys. For C++11, the relevant section of the standard is +// https://timsong-cpp.github.io/cppwp/n3337/class.mem#19 (9.2.19) +// ---------------------------------------------------------------------------- +template +union map_slot_type +{ + map_slot_type() {} + ~map_slot_type() = delete; + map_slot_type(const map_slot_type&) = delete; + map_slot_type& operator=(const map_slot_type&) = delete; + + using value_type = std::pair; + using mutable_value_type = std::pair; + + value_type value; + mutable_value_type mutable_value; + K key; +}; + +// ---------------------------------------------------------------------------- +// ---------------------------------------------------------------------------- +template +struct map_slot_policy +{ + using slot_type = map_slot_type; + using value_type = std::pair; + using mutable_value_type = std::pair; + +private: + static void emplace(slot_type* slot) { + // The construction of union doesn't do anything at runtime but it allows us + // to access its members without violating aliasing rules. + new (slot) slot_type; + } + // If pair and pair are layout-compatible, we can accept one + // or the other via slot_type. We are also free to access the key via + // slot_type::key in this case. + using kMutableKeys = memory_internal::IsLayoutCompatible; + +public: + static value_type& element(slot_type* slot) { return slot->value; } + static const value_type& element(const slot_type* slot) { + return slot->value; + } + + static const K& key(const slot_type* slot) { + return kMutableKeys::value ? slot->key : slot->value.first; + } + + template + static void construct(Allocator* alloc, slot_type* slot, Args&&... args) { + emplace(slot); + if (kMutableKeys::value) { + phmap::allocator_traits::construct(*alloc, &slot->mutable_value, + std::forward(args)...); + } else { + phmap::allocator_traits::construct(*alloc, &slot->value, + std::forward(args)...); + } + } + + // Construct this slot by moving from another slot. + template + static void construct(Allocator* alloc, slot_type* slot, slot_type* other) { + emplace(slot); + if (kMutableKeys::value) { + phmap::allocator_traits::construct( + *alloc, &slot->mutable_value, std::move(other->mutable_value)); + } else { + phmap::allocator_traits::construct(*alloc, &slot->value, + std::move(other->value)); + } + } + + template + static void destroy(Allocator* alloc, slot_type* slot) { + if (kMutableKeys::value) { + phmap::allocator_traits::destroy(*alloc, &slot->mutable_value); + } else { + phmap::allocator_traits::destroy(*alloc, &slot->value); + } + } + + template + static void transfer(Allocator* alloc, slot_type* new_slot, + slot_type* old_slot) { + emplace(new_slot); + if (kMutableKeys::value) { + phmap::allocator_traits::construct( + *alloc, &new_slot->mutable_value, std::move(old_slot->mutable_value)); + } else { + phmap::allocator_traits::construct(*alloc, &new_slot->value, + std::move(old_slot->value)); + } + destroy(alloc, old_slot); + } + + template + static void swap(Allocator* alloc, slot_type* a, slot_type* b) { + if (kMutableKeys::value) { + using std::swap; + swap(a->mutable_value, b->mutable_value); + } else { + value_type tmp = std::move(a->value); + phmap::allocator_traits::destroy(*alloc, &a->value); + phmap::allocator_traits::construct(*alloc, &a->value, + std::move(b->value)); + phmap::allocator_traits::destroy(*alloc, &b->value); + phmap::allocator_traits::construct(*alloc, &b->value, + std::move(tmp)); + } + } + + template + static void move(Allocator* alloc, slot_type* src, slot_type* dest) { + if (kMutableKeys::value) { + dest->mutable_value = std::move(src->mutable_value); + } else { + phmap::allocator_traits::destroy(*alloc, &dest->value); + phmap::allocator_traits::construct(*alloc, &dest->value, + std::move(src->value)); + } + } + + template + static void move(Allocator* alloc, slot_type* first, slot_type* last, + slot_type* result) { + for (slot_type *src = first, *dest = result; src != last; ++src, ++dest) + move(alloc, src, dest); + } +}; + +} // namespace container_internal } // phmap @@ -4928,4 +5163,9 @@ class LockableImpl: public phmap::NullMutex } // phmap +#ifdef _MSC_VER + #pragma warning(pop) +#endif + + #endif // phmap_base_h_guard_ diff --git a/include/parallel_hashmap/phmap_bits.h b/include/parallel_hashmap/phmap_bits.h index c48e8eac0..7933d8cb5 100644 --- a/include/parallel_hashmap/phmap_bits.h +++ b/include/parallel_hashmap/phmap_bits.h @@ -50,6 +50,11 @@ #include #include "phmap_config.h" +#ifdef _MSC_VER + #pragma warning(push) + #pragma warning(disable : 4514) // unreferenced inline function has been removed +#endif + // ----------------------------------------------------------------------------- // unaligned APIs // ----------------------------------------------------------------------------- @@ -151,6 +156,102 @@ inline void UnalignedStore64(void *p, uint64_t v) { memcpy(p, &v, sizeof v); } #endif +// ----------------------------------------------------------------------------- +// File: optimization.h +// ----------------------------------------------------------------------------- + +#if defined(__pnacl__) + #define PHMAP_BLOCK_TAIL_CALL_OPTIMIZATION() if (volatile int x = 0) { (void)x; } +#elif defined(__clang__) + // Clang will not tail call given inline volatile assembly. + #define PHMAP_BLOCK_TAIL_CALL_OPTIMIZATION() __asm__ __volatile__("") +#elif defined(__GNUC__) + // GCC will not tail call given inline volatile assembly. + #define PHMAP_BLOCK_TAIL_CALL_OPTIMIZATION() __asm__ __volatile__("") +#elif defined(_MSC_VER) + #include + // The __nop() intrinsic blocks the optimisation. + #define PHMAP_BLOCK_TAIL_CALL_OPTIMIZATION() __nop() +#else + #define PHMAP_BLOCK_TAIL_CALL_OPTIMIZATION() if (volatile int x = 0) { (void)x; } +#endif + +#if defined(__GNUC__) + #pragma GCC diagnostic push + #pragma GCC diagnostic ignored "-Wpedantic" +#endif + +#ifdef PHMAP_HAVE_INTRINSIC_INT128 + __extension__ typedef unsigned __int128 phmap_uint128; + inline uint64_t umul128(uint64_t a, uint64_t b, uint64_t* high) + { + auto result = static_cast(a) * static_cast(b); + *high = static_cast(result >> 64); + return static_cast(result); + } + #define PHMAP_HAS_UMUL128 1 +#elif (defined(_MSC_VER)) + #if defined(_M_X64) + #pragma intrinsic(_umul128) + inline uint64_t umul128(uint64_t a, uint64_t b, uint64_t* high) + { + return _umul128(a, b, high); + } + #define PHMAP_HAS_UMUL128 1 + #endif +#endif + +#if defined(__GNUC__) + #pragma GCC diagnostic pop +#endif + +#if defined(__GNUC__) + // Cache line alignment + #if defined(__i386__) || defined(__x86_64__) + #define PHMAP_CACHELINE_SIZE 64 + #elif defined(__powerpc64__) + #define PHMAP_CACHELINE_SIZE 128 + #elif defined(__aarch64__) + // We would need to read special register ctr_el0 to find out L1 dcache size. + // This value is a good estimate based on a real aarch64 machine. + #define PHMAP_CACHELINE_SIZE 64 + #elif defined(__arm__) + // Cache line sizes for ARM: These values are not strictly correct since + // cache line sizes depend on implementations, not architectures. There + // are even implementations with cache line sizes configurable at boot + // time. + #if defined(__ARM_ARCH_5T__) + #define PHMAP_CACHELINE_SIZE 32 + #elif defined(__ARM_ARCH_7A__) + #define PHMAP_CACHELINE_SIZE 64 + #endif + #endif + + #ifndef PHMAP_CACHELINE_SIZE + // A reasonable default guess. Note that overestimates tend to waste more + // space, while underestimates tend to waste more time. + #define PHMAP_CACHELINE_SIZE 64 + #endif + + #define PHMAP_CACHELINE_ALIGNED __attribute__((aligned(PHMAP_CACHELINE_SIZE))) +#elif defined(_MSC_VER) + #define PHMAP_CACHELINE_SIZE 64 + #define PHMAP_CACHELINE_ALIGNED __declspec(align(PHMAP_CACHELINE_SIZE)) +#else + #define PHMAP_CACHELINE_SIZE 64 + #define PHMAP_CACHELINE_ALIGNED +#endif + + +#if PHMAP_HAVE_BUILTIN(__builtin_expect) || \ + (defined(__GNUC__) && !defined(__clang__)) + #define PHMAP_PREDICT_FALSE(x) (__builtin_expect(x, 0)) + #define PHMAP_PREDICT_TRUE(x) (__builtin_expect(!!(x), 1)) +#else + #define PHMAP_PREDICT_FALSE(x) (x) + #define PHMAP_PREDICT_TRUE(x) (x) +#endif + // ----------------------------------------------------------------------------- // File: bits.h // ----------------------------------------------------------------------------- @@ -183,7 +284,7 @@ PHMAP_BASE_INTERNAL_FORCEINLINE int CountLeadingZeros64(uint64_t n) { // MSVC does not have __buitin_clzll. Use _BitScanReverse64. unsigned long result = 0; // NOLINT(runtime/int) if (_BitScanReverse64(&result, n)) { - return 63 - result; + return (int)(63 - result); } return 64; #elif defined(_MSC_VER) @@ -226,7 +327,7 @@ PHMAP_BASE_INTERNAL_FORCEINLINE int CountLeadingZeros32(uint32_t n) { #if defined(_MSC_VER) unsigned long result = 0; // NOLINT(runtime/int) if (_BitScanReverse(&result, n)) { - return 31 - result; + return (int)(31 - result); } return 32; #elif defined(__GNUC__) @@ -263,7 +364,7 @@ PHMAP_BASE_INTERNAL_FORCEINLINE int CountTrailingZerosNonZero64(uint64_t n) { #if defined(_MSC_VER) && defined(_M_X64) unsigned long result = 0; // NOLINT(runtime/int) _BitScanForward64(&result, n); - return result; + return (int)result; #elif defined(_MSC_VER) unsigned long result = 0; // NOLINT(runtime/int) if (static_cast(n) == 0) { @@ -296,7 +397,7 @@ PHMAP_BASE_INTERNAL_FORCEINLINE int CountTrailingZerosNonZero32(uint32_t n) { #if defined(_MSC_VER) unsigned long result = 0; // NOLINT(runtime/int) _BitScanForward(&result, n); - return result; + return (int)result; #elif defined(__GNUC__) static_assert(sizeof(int) == sizeof(n), "__builtin_ctz does not take 32-bit arg"); @@ -311,101 +412,6 @@ PHMAP_BASE_INTERNAL_FORCEINLINE int CountTrailingZerosNonZero32(uint32_t n) { } // namespace base_internal } // namespace phmap -// ----------------------------------------------------------------------------- -// File: optimization.h -// ----------------------------------------------------------------------------- - -#if defined(__pnacl__) - #define PHMAP_BLOCK_TAIL_CALL_OPTIMIZATION() if (volatile int x = 0) { (void)x; } -#elif defined(__clang__) - // Clang will not tail call given inline volatile assembly. - #define PHMAP_BLOCK_TAIL_CALL_OPTIMIZATION() __asm__ __volatile__("") -#elif defined(__GNUC__) - // GCC will not tail call given inline volatile assembly. - #define PHMAP_BLOCK_TAIL_CALL_OPTIMIZATION() __asm__ __volatile__("") -#elif defined(_MSC_VER) - #include - // The __nop() intrinsic blocks the optimisation. - #define PHMAP_BLOCK_TAIL_CALL_OPTIMIZATION() __nop() -#else - #define PHMAP_BLOCK_TAIL_CALL_OPTIMIZATION() if (volatile int x = 0) { (void)x; } -#endif - -#if defined(__GNUC__) - #pragma GCC diagnostic push - #pragma GCC diagnostic ignored "-Wpedantic" -#endif - -#ifdef PHMAP_HAVE_INTRINSIC_INT128 - inline uint64_t umul128(uint64_t a, uint64_t b, uint64_t* high) - { - auto result = static_cast(a) * static_cast(b); - *high = static_cast(result >> 64); - return static_cast(result); - } - #define PHMAP_HAS_UMUL128 1 -#elif (defined(_MSC_VER)) - #if defined(_M_X64) - #pragma intrinsic(_umul128) - inline uint64_t umul128(uint64_t a, uint64_t b, uint64_t* high) - { - return _umul128(a, b, high); - } - #define PHMAP_HAS_UMUL128 1 - #endif -#endif - -#if defined(__GNUC__) - #pragma GCC diagnostic pop -#endif - -#if defined(__GNUC__) - // Cache line alignment - #if defined(__i386__) || defined(__x86_64__) - #define PHMAP_CACHELINE_SIZE 64 - #elif defined(__powerpc64__) - #define PHMAP_CACHELINE_SIZE 128 - #elif defined(__aarch64__) - // We would need to read special register ctr_el0 to find out L1 dcache size. - // This value is a good estimate based on a real aarch64 machine. - #define PHMAP_CACHELINE_SIZE 64 - #elif defined(__arm__) - // Cache line sizes for ARM: These values are not strictly correct since - // cache line sizes depend on implementations, not architectures. There - // are even implementations with cache line sizes configurable at boot - // time. - #if defined(__ARM_ARCH_5T__) - #define PHMAP_CACHELINE_SIZE 32 - #elif defined(__ARM_ARCH_7A__) - #define PHMAP_CACHELINE_SIZE 64 - #endif - #endif - - #ifndef PHMAP_CACHELINE_SIZE - // A reasonable default guess. Note that overestimates tend to waste more - // space, while underestimates tend to waste more time. - #define PHMAP_CACHELINE_SIZE 64 - #endif - - #define PHMAP_CACHELINE_ALIGNED __attribute__((aligned(PHMAP_CACHELINE_SIZE))) -#elif defined(_MSC_VER) - #define PHMAP_CACHELINE_SIZE 64 - #define PHMAP_CACHELINE_ALIGNED __declspec(align(PHMAP_CACHELINE_SIZE)) -#else - #define PHMAP_CACHELINE_SIZE 64 - #define PHMAP_CACHELINE_ALIGNED -#endif - - -#if PHMAP_HAVE_BUILTIN(__builtin_expect) || \ - (defined(__GNUC__) && !defined(__clang__)) - #define PHMAP_PREDICT_FALSE(x) (__builtin_expect(x, 0)) - #define PHMAP_PREDICT_TRUE(x) (__builtin_expect(!!(x), 1)) -#else - #define PHMAP_PREDICT_FALSE(x) (x) - #define PHMAP_PREDICT_TRUE(x) (x) -#endif - // ----------------------------------------------------------------------------- // File: endian.h // ----------------------------------------------------------------------------- @@ -650,4 +656,8 @@ inline void Store64(void *p, uint64_t v) { } // namespace phmap +#ifdef _MSC_VER + #pragma warning(pop) +#endif + #endif // phmap_bits_h_guard_ diff --git a/include/parallel_hashmap/phmap_config.h b/include/parallel_hashmap/phmap_config.h index ff10b4393..dd5a0bf54 100644 --- a/include/parallel_hashmap/phmap_config.h +++ b/include/parallel_hashmap/phmap_config.h @@ -296,7 +296,8 @@ #endif #ifdef __has_include - #if __has_include() && __cplusplus >= 201703L + #if __has_include() && __cplusplus >= 201703L && \ + (!defined(_MSC_VER) || _MSC_VER >= 1920) // vs2019 #define PHMAP_HAVE_STD_STRING_VIEW 1 #endif #endif @@ -304,14 +305,16 @@ // #pragma message(PHMAP_VAR_NAME_VALUE(_MSVC_LANG)) #if defined(_MSC_VER) && _MSC_VER >= 1910 && \ - ((defined(_MSVC_LANG) && _MSVC_LANG > 201402) || __cplusplus > 201402) + ((defined(_MSVC_LANG) && _MSVC_LANG >= 201703) || __cplusplus >= 201703) // #define PHMAP_HAVE_STD_ANY 1 #define PHMAP_HAVE_STD_OPTIONAL 1 #define PHMAP_HAVE_STD_VARIANT 1 - #define PHMAP_HAVE_STD_STRING_VIEW 1 + #if !defined(PHMAP_HAVE_STD_STRING_VIEW) && _MSC_VER >= 1920 + #define PHMAP_HAVE_STD_STRING_VIEW 1 + #endif #endif -#if (defined(_MSVC_LANG) && _MSVC_LANG >= 201402) || __cplusplus >= 201703 +#if (defined(_MSVC_LANG) && _MSVC_LANG >= 201703) || __cplusplus >= 201703 #define PHMAP_HAVE_SHARED_MUTEX 1 #endif @@ -391,7 +394,7 @@ #define PHMAP_ATTRIBUTE_ALWAYS_INLINE #endif -#if PHMAP_HAVE_ATTRIBUTE(noinline) || (defined(__GNUC__) && !defined(__clang__)) +#if !defined(__INTEL_COMPILER) && (PHMAP_HAVE_ATTRIBUTE(noinline) || (defined(__GNUC__) && !defined(__clang__))) #define PHMAP_ATTRIBUTE_NOINLINE __attribute__((noinline)) #define PHMAP_HAVE_ATTRIBUTE_NOINLINE 1 #else @@ -627,6 +630,15 @@ #endif +// ---------------------------------------------------------------------- +// constexpr if +// ---------------------------------------------------------------------- +#if __cplusplus >= 201703 || (defined(_MSVC_LANG) && _MSVC_LANG >= 201703) + #define PHMAP_IF_CONSTEXPR(expr) if constexpr ((expr)) +#else + #define PHMAP_IF_CONSTEXPR(expr) if ((expr)) +#endif + // ---------------------------------------------------------------------- // base/macros.h // ---------------------------------------------------------------------- diff --git a/include/parallel_hashmap/phmap_fwd_decl.h b/include/parallel_hashmap/phmap_fwd_decl.h index a9fb51418..f90417d53 100644 --- a/include/parallel_hashmap/phmap_fwd_decl.h +++ b/include/parallel_hashmap/phmap_fwd_decl.h @@ -11,13 +11,30 @@ // https://www.apache.org/licenses/LICENSE-2.0 // --------------------------------------------------------------------------- +#ifdef _MSC_VER + #pragma warning(push) + #pragma warning(disable : 4514) // unreferenced inline function has been removed + #pragma warning(disable : 4710) // function not inlined + #pragma warning(disable : 4711) // selected for automatic inline expansion +#endif + #include #include +#if defined(PHMAP_USE_ABSL_HASH) && !defined(ABSL_HASH_HASH_H_) + namespace absl { template struct Hash; }; +#endif + namespace phmap { +#if defined(PHMAP_USE_ABSL_HASH) + template using Hash = ::absl::Hash; +#else template struct Hash; +#endif + template struct EqualTo; + template struct Less; template using Allocator = typename std::allocator; template using Pair = typename std::pair; @@ -29,13 +46,8 @@ namespace phmap { template struct HashEq { -#if defined(PHMAP_USE_ABSL_HASHEQ) - using Hash = absl::Hash; - using Eq = phmap::EqualTo; -#else using Hash = phmap::Hash; using Eq = phmap::EqualTo; -#endif }; template @@ -54,6 +66,7 @@ namespace phmap { } // namespace container_internal + // ------------- forward declarations for hash containers ---------------------------------- template , class Eq = phmap::container_internal::hash_default_eq, @@ -114,9 +127,28 @@ namespace phmap { class Mutex = phmap::NullMutex> // use std::mutex to enable internal locks class parallel_node_hash_map; + // ------------- forward declarations for btree containers ---------------------------------- + template , + typename Alloc = phmap::Allocator> + class btree_set; + template , + typename Alloc = phmap::Allocator> + class btree_multiset; + + template , + typename Alloc = phmap::Allocator>> + class btree_map; + + template , + typename Alloc = phmap::Allocator>> + class btree_multimap; } // namespace phmap +#ifdef _MSC_VER + #pragma warning(pop) +#endif + #endif // phmap_fwd_decl_h_guard_ diff --git a/include/parallel_hashmap/phmap_utils.h b/include/parallel_hashmap/phmap_utils.h index 96fb6ff39..72d4e716f 100644 --- a/include/parallel_hashmap/phmap_utils.h +++ b/include/parallel_hashmap/phmap_utils.h @@ -4,7 +4,9 @@ // --------------------------------------------------------------------------- // Copyright (c) 2019, Gregory Popovitch - greg7mdp@gmail.com // -// minimal header providing phmap::hash_combine +// minimal header providing phmap::HashState +// +// use as: phmap::HashState().combine(0, _first_name, _last_name, _age); // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -19,10 +21,25 @@ // limitations under the License. // --------------------------------------------------------------------------- +#ifdef _MSC_VER + #pragma warning(push) + #pragma warning(disable : 4514) // unreferenced inline function has been removed + #pragma warning(disable : 4710) // function not inlined + #pragma warning(disable : 4711) // selected for automatic inline expansion +#endif + #include #include +#include #include "phmap_bits.h" +// --------------------------------------------------------------- +// Absl forward declaration requires global scope. +// --------------------------------------------------------------- +#if defined(PHMAP_USE_ABSL_HASH) && !defined(phmap_fwd_decl_h_guard_) && !defined(ABSL_HASH_HASH_H_) + namespace absl { template struct Hash; }; +#endif + namespace phmap { @@ -113,14 +130,17 @@ struct has_hash_value typedef std::true_type yes; typedef std::false_type no; - template static auto test(int) -> decltype(U::hash_value(std::declval()) == 1, yes()); + template static auto test(int) -> decltype(hash_value(std::declval()) == 1, yes()); template static no test(...); public: static constexpr bool value = std::is_same(0)), yes>::value; }; - + +#if defined(PHMAP_USE_ABSL_HASH) && !defined(phmap_fwd_decl_h_guard_) + template using Hash = ::absl::Hash; +#elif !defined(PHMAP_USE_ABSL_HASH) // --------------------------------------------------------------- // phmap::Hash // --------------------------------------------------------------- @@ -130,7 +150,7 @@ struct Hash template ::value, int>::type = 0> size_t _hash(const T& val) const { - return U::hash_value(val); + return hash_value(val); } template ::value, int>::type = 0> @@ -262,6 +282,8 @@ struct Hash : public phmap_unary_function } }; +#endif + template struct Combiner { H operator()(H seed, size_t value); @@ -283,7 +305,7 @@ template struct Combiner } }; - +// define HashState to combine member hashes... see example below // ----------------------------------------------------------------------------- template class HashStateBase { @@ -299,12 +321,56 @@ template H HashStateBase::combine(H seed, const T& v, const Ts&... vs) { return HashStateBase::combine(Combiner()( - seed, phmap::Hash()(v)), vs...); + seed, phmap::Hash()(v)), + vs...); } using HashState = HashStateBase; +// ----------------------------------------------------------------------------- + +#if !defined(PHMAP_USE_ABSL_HASH) + +// define Hash for std::pair +// ------------------------- +template +struct Hash> { + size_t operator()(std::pair const& p) const noexcept { + return phmap::HashState().combine(phmap::Hash()(p.first), p.second); + } +}; + +// define Hash for std::tuple +// -------------------------- +template +struct Hash> { + size_t operator()(std::tuple const& t) const noexcept { + return _hash_helper(t); + } + +private: + template + typename std::enable_if::type + _hash_helper(const std::tuple &) const noexcept { return 0; } + + template + typename std::enable_if::type + _hash_helper(const std::tuple &t) const noexcept { + const auto &el = std::get(t); + using el_type = typename std::remove_cv::type>::type; + return Combiner()( + phmap::Hash()(el), _hash_helper(t)); + } +}; + + +#endif + + } // namespace phmap +#ifdef _MSC_VER + #pragma warning(pop) +#endif #endif // phmap_utils_h_guard_ From 244b72b006e6e45f668493b1adf718390140aad8 Mon Sep 17 00:00:00 2001 From: Rob Patro Date: Fri, 12 Jun 2020 22:43:44 -0400 Subject: [PATCH 27/52] change link order --- src/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 5ea2f4146..768e81cdf 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -204,12 +204,12 @@ target_link_libraries(salmon #${SUFFARRAY_LIB64} #${GAT_SOURCE_DIR}/external/install/lib/libbwa.a m + ${FAST_MALLOC_LIB} ${LIBLZMA_LIBRARIES} ${BZIP2_LIBRARIES} ${TBB_LIBRARIES} ${LIBSALMON_LINKER_FLAGS} ${NON_APPLECLANG_LIBS} - ${FAST_MALLOC_LIB} ${LIBRT} ksw2pp ## PUFF_INTEGRATION From 2e65f36a16d46c119f37addc1cbc5d87cbbeac5f Mon Sep 17 00:00:00 2001 From: Avi Srivastava Date: Sun, 14 Jun 2020 18:12:19 -0400 Subject: [PATCH 28/52] adding alevin bam tags --- src/SalmonAlevin.cpp | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/SalmonAlevin.cpp b/src/SalmonAlevin.cpp index d615a2627..6cdd02d49 100644 --- a/src/SalmonAlevin.cpp +++ b/src/SalmonAlevin.cpp @@ -509,6 +509,7 @@ void processReadsQuasi( auto& rp = rg[i]; readLenLeft = rp.first.seq.length(); readLenRight= rp.second.seq.length(); + std::string extraBAMtags(""); bool tooShortRight = (readLenRight < (minK+alevinOpts.trimRight)); //localUpperBoundHits = 0; @@ -588,6 +589,12 @@ void processReadsQuasi( if(isUmiIdxOk){ jointHitGroup.setUMI(umiIdx.word(0)); + if (writeQuasimappings) { + extraBAMtags += "\tCR:Z:"; + extraBAMtags += *barcode; + extraBAMtags += "\tUR:Z:"; + extraBAMtags += umi; + } auto seq_len = rp.second.seq.size(); if (alevinOpts.trimRight > 0) { @@ -712,7 +719,7 @@ void processReadsQuasi( } //end-if validate mapping if (writeQuasimappings) { - writeAlignmentsToStream(rp, formatter, jointAlignments, sstream, true, true); + writeAlignmentsToStream(rp, formatter, jointAlignments, sstream, true, true, extraBAMtags); } // We've kept decoy aignments around to this point so that we can From a4d19e03697324ca494a5bbb61dc2868ed0d456d Mon Sep 17 00:00:00 2001 From: Avi Srivastava Date: Sun, 14 Jun 2020 19:21:06 -0400 Subject: [PATCH 29/52] adding correcting CB --- src/SalmonAlevin.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/SalmonAlevin.cpp b/src/SalmonAlevin.cpp index 6cdd02d49..9a1e12f00 100644 --- a/src/SalmonAlevin.cpp +++ b/src/SalmonAlevin.cpp @@ -568,6 +568,7 @@ void processReadsQuasi( exit(1); } else{ barcodeIdx = trItLoc->second; + *barcode = trItLoc->first; } } // If it wasn't in the barcode map, it's not valid @@ -590,7 +591,7 @@ void processReadsQuasi( if(isUmiIdxOk){ jointHitGroup.setUMI(umiIdx.word(0)); if (writeQuasimappings) { - extraBAMtags += "\tCR:Z:"; + extraBAMtags += "\tCB:Z:"; extraBAMtags += *barcode; extraBAMtags += "\tUR:Z:"; extraBAMtags += umi; From 4572c8a5c9fd60d4e2fa77a30fc0f5c2fa2baa94 Mon Sep 17 00:00:00 2001 From: Avi Srivastava Date: Sun, 14 Jun 2020 19:30:22 -0400 Subject: [PATCH 30/52] unnecessary overwriting --- src/SalmonAlevin.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/src/SalmonAlevin.cpp b/src/SalmonAlevin.cpp index 9a1e12f00..51f921827 100644 --- a/src/SalmonAlevin.cpp +++ b/src/SalmonAlevin.cpp @@ -568,7 +568,6 @@ void processReadsQuasi( exit(1); } else{ barcodeIdx = trItLoc->second; - *barcode = trItLoc->first; } } // If it wasn't in the barcode map, it's not valid From 32f940896bebd32ea3007cf017e1e054c25f7be7 Mon Sep 17 00:00:00 2001 From: Avi Srivastava Date: Sun, 14 Jun 2020 21:30:47 -0400 Subject: [PATCH 31/52] correcting for alevin decoy dump --- src/SalmonAlevin.cpp | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/SalmonAlevin.cpp b/src/SalmonAlevin.cpp index 51f921827..721814eb1 100644 --- a/src/SalmonAlevin.cpp +++ b/src/SalmonAlevin.cpp @@ -477,7 +477,6 @@ void processReadsQuasi( } */ - size_t numDropped{0}; size_t numMappingsDropped{0}; size_t numDecoyFrags{0}; const double decoyThreshold = salmonOpts.decoyThreshold; @@ -702,10 +701,9 @@ void processReadsQuasi( } } else { numDecoyFrags += bestHitDecoy ? 1 : 0; - ++numDropped; mapType = (bestHitDecoy) ? salmon::utils::MappingType::DECOY : salmon::utils::MappingType::UNMAPPED; if (bestHitDecoy) { - salmon::mapping_utils::filterAndCollectAlignments( + salmon::mapping_utils::filterAndCollectAlignmentsDecoy( jointHits, readSubSeq.length(), readSubSeq.length(), false, // true for single-end false otherwise From f7b2643575a67110cb0ce85583728c4120ec227b Mon Sep 17 00:00:00 2001 From: Avi Srivastava Date: Sun, 14 Jun 2020 21:41:12 -0400 Subject: [PATCH 32/52] reserving string size --- src/SalmonAlevin.cpp | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/SalmonAlevin.cpp b/src/SalmonAlevin.cpp index 721814eb1..6499be605 100644 --- a/src/SalmonAlevin.cpp +++ b/src/SalmonAlevin.cpp @@ -504,11 +504,14 @@ void processReadsQuasi( LibraryFormat expectedLibraryFormat = rl.format(); + std::string extraBAMtags; + size_t reserveSize { alevinOpts.protocol.barcodeLength + alevinOpts.protocol.umiLength + 12}; + extraBAMtags.reserve(reserveSize); + for (size_t i = 0; i < rangeSize; ++i) { // For all the read in this batch auto& rp = rg[i]; readLenLeft = rp.first.seq.length(); readLenRight= rp.second.seq.length(); - std::string extraBAMtags(""); bool tooShortRight = (readLenRight < (minK+alevinOpts.trimRight)); //localUpperBoundHits = 0; @@ -530,6 +533,7 @@ void processReadsQuasi( std::string umi;//, barcode; nonstd::optional barcode; nonstd::optional barcodeIdx; + extraBAMtags.clear(); bool seqOk; if (alevinOpts.protocol.end == bcEnd::FIVE || From 15fa825bba5b4c805121e4c25c0e4f56f99b0744 Mon Sep 17 00:00:00 2001 From: Avi Srivastava Date: Sun, 14 Jun 2020 21:56:17 -0400 Subject: [PATCH 33/52] making the reserve conitinoal --- src/SalmonAlevin.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/SalmonAlevin.cpp b/src/SalmonAlevin.cpp index 6499be605..6d78d043f 100644 --- a/src/SalmonAlevin.cpp +++ b/src/SalmonAlevin.cpp @@ -505,8 +505,10 @@ void processReadsQuasi( LibraryFormat expectedLibraryFormat = rl.format(); std::string extraBAMtags; - size_t reserveSize { alevinOpts.protocol.barcodeLength + alevinOpts.protocol.umiLength + 12}; - extraBAMtags.reserve(reserveSize); + if(writeQuasimappings) { + size_t reserveSize { alevinOpts.protocol.barcodeLength + alevinOpts.protocol.umiLength + 12}; + extraBAMtags.reserve(reserveSize); + } for (size_t i = 0; i < rangeSize; ++i) { // For all the read in this batch auto& rp = rg[i]; From 7fcaa6f554e9dbae615d7f7b153d127485b27284 Mon Sep 17 00:00:00 2001 From: Rob Patro Date: Tue, 16 Jun 2020 16:49:09 -0400 Subject: [PATCH 34/52] Fix possible NaN in marginal seq bias terms This adds a small prior to the marginal and probability matrices in the sequence-specific bias model to avoid the problem of having a NaN appear as a marginal probability if there were no observations with that nucleotide in the given position. Note that there are no NaNs in the underlying probability matrix as used in the VLMM itself (since we check if a probability is 0 and set to a small log prob if so). However, the computation of the marginal was done before this transformation and could be NaN of a nucelotide was not at all present in a given position. Thanks to @PRopp42 for finding and reporting this issue. --- include/SBModel.hpp | 1 + src/SBModel.cpp | 4 +++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/include/SBModel.hpp b/include/SBModel.hpp index 50ba9af5b..772956dce 100644 --- a/include/SBModel.hpp +++ b/include/SBModel.hpp @@ -94,6 +94,7 @@ class SBModel { std::vector _order; std::vector _shifts; std::vector _widths; + constexpr static const double _prior_prob = 1e-10; }; #endif //__SB_MODEL_HPP__ diff --git a/src/SBModel.cpp b/src/SBModel.cpp index f635c6ee8..56b08020c 100644 --- a/src/SBModel.cpp +++ b/src/SBModel.cpp @@ -45,6 +45,7 @@ SBModel::SBModel() : _trained(false) { _marginals = Eigen::MatrixXd(4, _contextLength); _marginals.setZero(); + _marginals.array() += _prior_prob; _shifts.clear(); _widths.clear(); @@ -70,6 +71,7 @@ SBModel::SBModel() : _trained(false) { _probs = Eigen::MatrixXd(constExprPow(4, maxOrder + 1), _contextLength); // We have no intial observations _probs.setZero(); + _probs.array() += _prior_prob; } bool SBModel::writeBinary(boost::iostreams::filtering_ostream& out) const { @@ -244,7 +246,7 @@ bool SBModel::normalize() { // std::cerr << "pos = " << pos << ", marginals = " << _marginals.col(pos) // << '\n'; } - + double logSmall = std::log(1e-5); auto takeLog = [logSmall](double x) -> double { return (x > 0.0) ? std::log(x) : logSmall; From 493e3d4e4e1864c65804b386020bcbe7428d6297 Mon Sep 17 00:00:00 2001 From: Rob Patro Date: Tue, 16 Jun 2020 22:11:23 -0400 Subject: [PATCH 35/52] check the header for decoy references, and quit if we see them --- include/AlignmentLibrary.hpp | 41 +++++++++++++++++++++++++++++++++++- include/BAMQueue.tpp | 2 +- 2 files changed, 41 insertions(+), 2 deletions(-) diff --git a/include/AlignmentLibrary.hpp b/include/AlignmentLibrary.hpp index ec738b1c6..1cd89f913 100644 --- a/include/AlignmentLibrary.hpp +++ b/include/AlignmentLibrary.hpp @@ -27,9 +27,10 @@ extern "C" { #include "SalmonOpts.hpp" #include "SalmonUtils.hpp" #include "SimplePosBias.hpp" -#include "SpinLock.hpp" // RapMap's with try_lock +#include "SpinLock.hpp" // From pufferfish, with try_lock #include "Transcript.hpp" #include "concurrentqueue.h" +#include "parallel_hashmap/phmap.h" // Boost includes #include @@ -38,6 +39,7 @@ extern "C" { #include #include #include +#include template class NullFragmentFilter; @@ -109,6 +111,43 @@ template class AlignmentLibrary { // Figure out aligner information from the header if we can aligner_ = salmon::bam_utils::inferAlignerFromHeader(header); + // in this case check for decoys and make a list of their names + phmap::flat_hash_set decoys; + if (aligner_ == salmon::bam_utils::AlignerDetails::PUFFERFISH) { + // for each reference + for (decltype(header->nref) i = 0; i < header->nref; ++i) { + // for each tag + SAM_hdr_tag *tag; + for (tag = header->ref[i].tag; tag; tag = tag->next) { + // if this tag marks it as a decoy + if ((tag->len == 4) and (std::strncmp(tag->str, "DS:D", 4) == 0)) { + decoys.insert(header->ref[i].name); + break; + } // end if decoy tag + + } // end for each tag + } // end for each referecne + } + + if (!decoys.empty()) { + bq->forceEndParsing(); + bq.reset(); + salmonOpts.jointLog->error( + "Salmon is being run in alignment-mode with a SAM/BAM file that contains decoy\n" + "sequences (marked as such during salmon indexing). This SAM/BAM file had {}\n" + "such sequences tagged in the header. Since alignments to decoys are not\n" + "intended for decoy-level quantification, this functionality is not currently\n" + "supported. If you wish to run salmon with this SAM/BAM file, please \n" + "filter out / remove decoy transcripts (those tagged with `DS:D`) from the \n" + "header, and all SAM/BAM records that represent alignments to decoys \n" + "(those tagged with `XT:A:D`). If you believe you are receiving this message\n" + "in error, please report this issue on GitHub.", decoys.size()); + salmonOpts.jointLog->flush(); + std::stringstream ss; + ss << "\nCannot quantify from SAM/BAM file containing decoy transcripts or alignment records!\n"; + throw std::runtime_error(ss.str()); + } + // The transcript file existed, so load up the transcripts double alpha = 0.005; // we know how many we will have, so reserve the space for diff --git a/include/BAMQueue.tpp b/include/BAMQueue.tpp index 8efb4552d..76d8e5e72 100644 --- a/include/BAMQueue.tpp +++ b/include/BAMQueue.tpp @@ -116,7 +116,7 @@ void BAMQueue::reset() { template BAMQueue::~BAMQueue() { fmt::print(stderr, "\nFreeing memory used by read queue . . . "); - parsingThread_->join(); + if (parsingThread_) { parsingThread_->join(); } fmt::print(stderr, "\nJoined parsing thread . . . "); for (auto& file : files_) { From 6e38a8d84d1f1bbdb69e3c0147b0651a68faaa46 Mon Sep 17 00:00:00 2001 From: Rob Patro Date: Fri, 19 Jun 2020 02:09:19 -0400 Subject: [PATCH 36/52] link order and remove rogue +1 --- src/CMakeLists.txt | 4 ++-- src/DistributionUtils.cpp | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 768e81cdf..8bc849b9b 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -204,10 +204,8 @@ target_link_libraries(salmon #${SUFFARRAY_LIB64} #${GAT_SOURCE_DIR}/external/install/lib/libbwa.a m - ${FAST_MALLOC_LIB} ${LIBLZMA_LIBRARIES} ${BZIP2_LIBRARIES} - ${TBB_LIBRARIES} ${LIBSALMON_LINKER_FLAGS} ${NON_APPLECLANG_LIBS} ${LIBRT} @@ -216,6 +214,8 @@ target_link_libraries(salmon alevin_core ${CMAKE_DL_LIBS} ${ASAN_LIB} + ${FAST_MALLOC_LIB} + ${TBB_LIBRARIES} #ubsan ) diff --git a/src/DistributionUtils.cpp b/src/DistributionUtils.cpp index 77457b580..0924d97e3 100644 --- a/src/DistributionUtils.cpp +++ b/src/DistributionUtils.cpp @@ -41,7 +41,7 @@ void computeSmoothedEffectiveLengths(size_t maxLength, ? correctionFactors[maxLen - 1] : correctionFactors[origLen]; - double effLen = origLen - correctionFactor + 1.0; + double effLen = origLen - correctionFactor; if (effLen < 1.0) { effLen = origLen; } From 8dd16c2f0ec54cacb2fb7cda1c9ff71d36bfa785 Mon Sep 17 00:00:00 2001 From: Rob Patro Date: Mon, 22 Jun 2020 11:59:40 -0400 Subject: [PATCH 37/52] version bump --- current_version.txt | 4 ++-- doc/source/conf.py | 4 ++-- docker/Dockerfile | 2 +- docker/build_test.sh | 2 +- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/current_version.txt b/current_version.txt index 8e8883241..f3da0dca9 100644 --- a/current_version.txt +++ b/current_version.txt @@ -1,3 +1,3 @@ VERSION_MAJOR 1 -VERSION_MINOR 2 -VERSION_PATCH 1 +VERSION_MINOR 3 +VERSION_PATCH 0 diff --git a/doc/source/conf.py b/doc/source/conf.py index 0ff2ead23..d3bbf85dd 100644 --- a/doc/source/conf.py +++ b/doc/source/conf.py @@ -55,9 +55,9 @@ # built documents. # # The short X.Y version. -version = '1.2' +version = '1.3' # The full version, including alpha/beta/rc tags. -release = '1.2.1' +release = '1.3.0' # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. diff --git a/docker/Dockerfile b/docker/Dockerfile index 8888731f4..4021bd5a2 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -6,7 +6,7 @@ MAINTAINER salmon.maintainer@gmail.com ENV PACKAGES git gcc make g++ libboost-all-dev liblzma-dev libbz2-dev \ ca-certificates zlib1g-dev libcurl4-openssl-dev curl unzip autoconf apt-transport-https ca-certificates gnupg software-properties-common wget -ENV SALMON_VERSION 1.2.1 +ENV SALMON_VERSION 1.3.0 # salmon binary will be installed in /home/salmon/bin/salmon diff --git a/docker/build_test.sh b/docker/build_test.sh index 9a001d209..5bbe6dc76 100644 --- a/docker/build_test.sh +++ b/docker/build_test.sh @@ -1,3 +1,3 @@ #! /bin/bash -SALMON_VERSION=1.2.1 +SALMON_VERSION=1.3.0 docker build --no-cache -t combinelab/salmon:${SALMON_VERSION} -t combinelab/salmon:latest . From 29dc1f72f17f68121c4210a23399158505c66dd9 Mon Sep 17 00:00:00 2001 From: Rob Patro Date: Mon, 22 Jun 2020 15:43:30 -0400 Subject: [PATCH 38/52] [CI SKIP] align commands when printing help --- src/Salmon.cpp | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/src/Salmon.cpp b/src/Salmon.cpp index a7b6d3980..147a15d51 100644 --- a/src/Salmon.cpp +++ b/src/Salmon.cpp @@ -56,12 +56,11 @@ int help(const std::vector& /*opts*/) { " salmon -c|--cite or \n" " salmon [--no-version-check] [-h | options]\n\n"); helpMsg.write("Commands:\n"); - helpMsg.write(" index Create a salmon index\n"); - helpMsg.write(" quant Quantify a sample\n"); - helpMsg.write(" alevin single cell analysis\n"); - helpMsg.write(" swim Perform super-secret operation\n"); - helpMsg.write( - " quantmerge Merge multiple quantifications into a single file\n"); + helpMsg.write(" index : create a salmon index\n"); + helpMsg.write(" quant : quantify a sample\n"); + helpMsg.write(" alevin : single cell analysis\n"); + helpMsg.write(" swim : perform super-secret operation\n"); + helpMsg.write(" quantmerge : merge multiple quantifications into a single file\n"); std::cout << helpMsg.str(); return 0; From d2f5436ec3cfd29d7d8599df363cfdfebd0f69b8 Mon Sep 17 00:00:00 2001 From: Rob Patro Date: Mon, 22 Jun 2020 22:18:27 -0400 Subject: [PATCH 39/52] doc --- README.md | 2 -- 1 file changed, 2 deletions(-) diff --git a/README.md b/README.md index 0f4326f05..bbebb8cf9 100644 --- a/README.md +++ b/README.md @@ -36,8 +36,6 @@ Give salmon a try! You can find the latest binary releases [here](https://githu The current version number of the master branch of Salmon can be found [**here**](http://combine-lab.github.io/salmon/version_info/latest) -**NOTE**: Salmon works by (quasi)-mapping sequencing reads directly to the *transcriptome*. This means the Salmon index should be built on a set of target transcripts, **not** on the *genome* of the underlying organism. If indexing appears to be taking a very long time, or using a tremendous amount of memory (which it should not), please ensure that you are not attempting to build an index on the genome of your organism! - Documentation ============== From f9256beeec82584ac0bf572d4364033915ac535f Mon Sep 17 00:00:00 2001 From: Rob Patro Date: Thu, 25 Jun 2020 11:33:24 -0400 Subject: [PATCH 40/52] Update linking order I *love* C++ linking order related issues! Thanks @gmarcais. --- src/CMakeLists.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 8bc849b9b..7c9f9f2f1 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -212,10 +212,10 @@ target_link_libraries(salmon ksw2pp ## PUFF_INTEGRATION alevin_core - ${CMAKE_DL_LIBS} ${ASAN_LIB} ${FAST_MALLOC_LIB} ${TBB_LIBRARIES} + ${CMAKE_DL_LIBS} #ubsan ) @@ -238,8 +238,8 @@ target_link_libraries(unitTests ${LIBSALMON_LINKER_FLAGS} ${NON_APPLECLANG_LIBS} ${LIBRT} - ${CMAKE_DL_LIBS} ${ASAN_LIB} + ${CMAKE_DL_LIBS} #ubsan ) From f8552fc691041931d8a42a9a62e3741ef8624c5f Mon Sep 17 00:00:00 2001 From: Rob Patro Date: Thu, 25 Jun 2020 12:13:56 -0400 Subject: [PATCH 41/52] Update CMakeLists.txt --- CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 71b0fb780..899a2a2e3 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -68,7 +68,7 @@ message(STATUS "CMAKE_BUILD_TYPE = ${CMAKE_BUILD_TYPE}") ## Set the standard required compile flags # Nov 18th --- removed -DHAVE_CONFIG_H set(REMOVE_WARNING_FLAGS "-Wno-unused-function;-Wno-unused-local-typedefs") -set(TGT_COMPILE_FLAGS "-ftree-vectorize;-funroll-loops;-fPIC;-fomit-frame-pointer;-O3;-DNDEBUG;-DSTX_NO_STD_STRING_VIEW") +set(TGT_COMPILE_FLAGS "-ftree-vectorize;-funroll-loops;-fPIC;-fomit-frame-pointer;-O3;-DNDEBUG;-DSTX_NO_STD_STRING_VIEW;-D__STDC_FORMAT_MACROS") set(TGT_WARN_FLAGS "-Wall;-Wno-unknown-pragmas;-Wno-reorder;-Wno-unused-variable;-Wreturn-type;-Werror=return-type;${REMOVE_WARNING_FLAGS}") #set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fsanitize=address") #set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsanitize=address") From 1403d48845a4a45c304c1d8cfec22dcf836a4cfd Mon Sep 17 00:00:00 2001 From: Rob Patro Date: Mon, 29 Jun 2020 16:30:51 -0400 Subject: [PATCH 42/52] expose and set new selective-alignment filtering options --- include/SalmonDefaults.hpp | 5 ++- include/SalmonMappingUtils.hpp | 9 ++++- include/SalmonOpts.hpp | 3 ++ src/ProgramOptionsGenerator.cpp | 37 +++++++++++++++++--- src/Salmon.cpp | 2 +- src/SalmonQuantify.cpp | 32 ++++++++++++++++-- src/SalmonUtils.cpp | 60 +++++++++++++++++++++++++++++---- 7 files changed, 132 insertions(+), 16 deletions(-) diff --git a/include/SalmonDefaults.hpp b/include/SalmonDefaults.hpp index 25402f86b..083147a98 100644 --- a/include/SalmonDefaults.hpp +++ b/include/SalmonDefaults.hpp @@ -24,13 +24,16 @@ namespace defaults { constexpr const bool disableSA{false}; constexpr const float consensusSlack{0.0}; constexpr const double minScoreFraction{0.65}; + constexpr const double pre_merge_chain_sub_thresh{0.75}; + constexpr const double post_merge_chain_sub_thresh{0.9}; + constexpr const double orphan_chain_sub_thresh{1.0}; constexpr const double scoreExp{1.0}; constexpr const int16_t matchScore{2}; constexpr const int16_t mismatchPenalty{-4}; constexpr const int16_t gapOpenPenalty{4}; constexpr const int16_t gapExtendPenalty{2}; constexpr const int32_t dpBandwidth{15}; - constexpr const uint32_t mismatchSeedSkip{5}; + constexpr const uint32_t mismatchSeedSkip{3}; constexpr const bool disableChainingHeuristic{false}; constexpr const bool eqClassMode{false}; constexpr const bool hardFilter{false}; diff --git a/include/SalmonMappingUtils.hpp b/include/SalmonMappingUtils.hpp index 82aa93e82..df9fad683 100644 --- a/include/SalmonMappingUtils.hpp +++ b/include/SalmonMappingUtils.hpp @@ -137,7 +137,11 @@ inline bool initMapperSettings(SalmonOpts& salmonOpts, MemCollector& mem memCollector.setConsensusFraction(consensusFraction); memCollector.setHitFilterPolicy(salmonOpts.hitFilterPolicy); memCollector.setAltSkip(salmonOpts.mismatchSeedSkip); - + memCollector.setChainSubOptThresh(salmonOpts.pre_merge_chain_sub_thresh); + + double pre_merge_chain_sub_thresh; + double post_merge_chain_sub_thresh; + //Initialize ksw aligner ksw2pp::KSW2Config config; config.dropoff = -1; @@ -183,6 +187,9 @@ inline bool initMapperSettings(SalmonOpts& salmonOpts, MemCollector& mem mpol.noDovetail = !salmonOpts.allowDovetail; aconf.noDovetail = mpol.noDovetail; + mpol.setPostMergeChainSubThresh(salmonOpts.post_merge_chain_sub_thresh); + mpol.setOrphanChainSubThresh(salmonOpts.orphan_chain_sub_thresh); + return true; } diff --git a/include/SalmonOpts.hpp b/include/SalmonOpts.hpp index 73121635e..8b2473dd7 100644 --- a/include/SalmonOpts.hpp +++ b/include/SalmonOpts.hpp @@ -260,6 +260,9 @@ struct SalmonOpts { bool validateMappings; bool disableSA{false}; // this cannot be done right now. float consensusSlack; + double pre_merge_chain_sub_thresh; + double post_merge_chain_sub_thresh; + double orphan_chain_sub_thresh; bool disableChainingHeuristic; bool disableAlignmentCache; double minScoreFraction; diff --git a/src/ProgramOptionsGenerator.cpp b/src/ProgramOptionsGenerator.cpp index 6bca41daf..961f2bd0b 100644 --- a/src/ProgramOptionsGenerator.cpp +++ b/src/ProgramOptionsGenerator.cpp @@ -102,12 +102,39 @@ namespace salmon { "[*deprecated* (no effect; selective-alignment is the default)]") ("consensusSlack", po::value(&(sopt.consensusSlack))->default_value(salmon::defaults::consensusSlack), - "[selective-alignment mode only] : The amount of slack allowed in the quasi-mapping consensus " - "mechanism. Normally, a transcript must cover all hits to be considered for mapping. " - "If this is set to a fraction, X, greater than 0 (and in [0,1)), then a transcript can fail to cover up to " - "(100 * X)% of the hits before it is discounted as a mapping candidate. The default value of this option " - "is 0.2 if --validateMappings is given and 0 otherwise." + "[selective-alignment mode only] : The amount of slack allowed in the selective-alignment " + "filtering mechanism. If this is set to a fraction, X, greater than 0 (and in [0,1)), then " + "uniMEM chains with scores below (100 * X)% of the best chain score for a read, and read pairs " + "with a sum of chain scores below (100 * X)% of the best chain score for a read pair " + "will be discounted as a mapping candidates. The default value of this option is 0.35." ) + ("preMergeChainSubThresh", + po::value(&(sopt.pre_merge_chain_sub_thresh))->default_value(salmon::defaults::pre_merge_chain_sub_thresh), + "[selective-alignment mode only] : The threshold of sub-optimal chains, compared to the best chain on a given " + "target, that will be retained and passed to the next phase of mapping. Specifically, if the best chain " + "for a read (or read-end in paired-end mode) to target t has score X_t, then all chains for this read with " + "score >= X_t * preMergeChainSubThresh will be retained and passed to subsequent mapping phases. This value " + "must be in the range [0, 1]." + ) + ("postMergeChainSubThresh", + po::value(&(sopt.post_merge_chain_sub_thresh))->default_value(salmon::defaults::post_merge_chain_sub_thresh), + "[selective-alignment mode only] : The threshold of sub-optimal chain pairs, compared to the best chain pair " + "on a given target, that will be retained and passed to the next phase of mapping. This is different than " + "preMergeChainSubThresh, because this is applied to pairs of chains (from the ends of paired-end reads) after " + "merging (i.e. after checking concordancy constraints etc.). Specifically, if the best chain pair " + "to target t has score X_t, then all chain pairs for this read pair with score " + ">= X_t * postMergeChainSubThresh will be retained and passed to subsequent mapping phases. This value " + "must be in the range [0, 1]. Note: This option is only meaningful for paired-end libraries, and is ignored " + "for single-end libraries." + ) + ("orphanChainSubThresh", + po::value(&(sopt.orphan_chain_sub_thresh))->default_value(salmon::defaults::orphan_chain_sub_thresh), + "[selective-alignment mode only] : This threshold sets a global sub-optimality threshold for chains " + "corresponding to orphan mappings. That is, if the merging procedure results in no concordant mappings " + "then only orphan mappings with a chain score >= orphanChainSubThresh * bestChainScore will be " + "retained and passed to subsequent mapping phases. This value must be in the range [0, 1]. Note: This " + "option is only meaningful for paired-end libraries, and is ignored for single-end libraries." + ) ("scoreExp", po::value(&sopt.scoreExp)->default_value(salmon::defaults::scoreExp), "[selective-alignment mode only] : The factor by which sub-optimal alignment scores are " diff --git a/src/Salmon.cpp b/src/Salmon.cpp index 147a15d51..4274e0575 100644 --- a/src/Salmon.cpp +++ b/src/Salmon.cpp @@ -77,7 +77,7 @@ int dualModeMessage() { alignment-based algorithm will be used, otherwise the algorithm for quantifying from raw reads will be used. - to view the help for salmon's quasi-mapping-based mode, use the command + to view the help for salmon's selective-alignment-based mode, use the command salmon quant --help-reads diff --git a/src/SalmonQuantify.cpp b/src/SalmonQuantify.cpp index 9392a1b52..6d3962059 100644 --- a/src/SalmonQuantify.cpp +++ b/src/SalmonQuantify.cpp @@ -954,6 +954,24 @@ void processReads( ); hctr.numMappedAtLeastAKmer += (leftHits.size() > 0 || rightHits.size() > 0) ? 1 : 0; + /* + salmonOpts.jointLog->info("\n\n mappings for left end \n\n"); + + for (auto&& h : leftHits) { + salmonOpts.jointLog->info("hit to : {}", qidx->refName( h.first )); + for (auto&& mc : *h.second) { + salmonOpts.jointLog->info("\t ori : {}, pos : {}, score : {}", (mc.isFw ? "fw" : "rc"), mc.getTrFirstHitPos(), mc.score); + } + } + + salmonOpts.jointLog->info("\n\n mappings for right end \n\n"); + for (auto&& h : rightHits) { + salmonOpts.jointLog->info("hit to : {}", qidx->refName( h.first )); + for (auto&& mc : *h.second) { + salmonOpts.jointLog->info("\t ori : {}, pos : {}, score : {}", (mc.isFw ? "fw" : "rc"), mc.getTrFirstHitPos(), mc.score); + } + } + */ // TODO : PF_INTEGRATION /* @@ -1014,6 +1032,16 @@ void processReads( upperBoundHits += (jointHits.size() > 0); } + /* + salmonOpts.jointLog->info("\n\n mappings for joined ends \n\n"); + for (auto&& h : jointHits) { + salmonOpts.jointLog->info("hit to : {}", qidx->refName( h.tid )); + salmonOpts.jointLog->info("\t lpos : {}, rpos : {}, score : {}", h.leftClust->getTrFirstHitPos(), + h.rightClust->getTrFirstHitPos(), h.coverage()); + } + */ + + // FIXME: This clears the alignment group, but that contains nothing // at this point. We should either check only once we are at the alignment // phase (and therefore filter nothing based on pre-alignment hits), or @@ -2387,7 +2415,7 @@ int salmonQuantify(int argc, const char* argv[]) { auto hstring = R"( Quant ========== -Perform dual-phase, mapping-based estimation of +Perform dual-phase, selective-alignment-based estimation of transcript abundance from RNA-seq reads )"; std::cout << hstring << std::endl; @@ -2403,7 +2431,7 @@ transcript abundance from RNA-seq reads } std::stringstream commentStream; - commentStream << "### salmon (mapping-based) v" << salmon::version << "\n"; + commentStream << "### salmon (selective-alignment-based) v" << salmon::version << "\n"; commentStream << "### [ program ] => salmon \n"; commentStream << "### [ command ] => quant \n"; for (auto& opt : orderedOptions.options) { diff --git a/src/SalmonUtils.cpp b/src/SalmonUtils.cpp index 291dbb8a8..006cc895b 100644 --- a/src/SalmonUtils.cpp +++ b/src/SalmonUtils.cpp @@ -1423,23 +1423,24 @@ std::string getCurrentTimeAsString() { "`--validateMappings` is generally recommended.\n"); } + bool is_pe_library = (numLeft + numRight > 0); + bool is_se_library = (numUnpaired > 0); + // currently there is some strange use for this in alevin, I think ... // check with avi. - if (numLeft + numRight > 0 and numUnpaired > 0) { + if (is_pe_library and is_se_library) { sopt.jointLog->warn("You seem to have passed in both un-paired reads and paired-end reads. " "It is not currently possible to quantify hybrid library types in salmon."); } - - if (numLeft + numRight > 0) { + if (is_pe_library) { if (numLeft != numRight) { sopt.jointLog->error("You passed paired-end files to salmon, but you passed {} files to --mates1 " "and {} files to --mates2. You must pass the same number of files to both flags", numLeft, numRight); return false; } - } - + } auto checkScoreValue = [&sopt](int16_t score, std::string sname) -> bool { using score_t = int8_t; @@ -1510,7 +1511,7 @@ std::string getCurrentTimeAsString() { sopt.useRangeFactorization = true; } - // If the consensus slack was not set explicitly, then it defaults to 0.2 with + // If the consensus slack was not set explicitly, then it defaults to 0.35 with // validateMappings bool consensusSlackExplicit = !vm["consensusSlack"].defaulted(); if (!consensusSlackExplicit) { @@ -1520,6 +1521,52 @@ std::string getCurrentTimeAsString() { "Setting consensusSlack to {}.", sopt.consensusSlack); } + bool pre_merge_chain_sub_thresh_explicit = !vm["preMergeChainSubThresh"].defaulted(); + bool post_merge_chain_sub_thresh_explicit = !vm["postMergeChainSubThresh"].defaulted(); + bool orphan_chain_sub_thresh_explicit = !vm["orphanChainSubThresh"].defaulted(); + + // for a single-end library, we set + if ( is_se_library ) { + + // The default of preMergeChainSubThresh for single-end libraries is 1.0, so set that here + if (!pre_merge_chain_sub_thresh_explicit) { + sopt.pre_merge_chain_sub_thresh = 1.0; + } + + // for single-end libraries, postMergeChainSubThresh and orphanChainSubThresh are meaningless + if (post_merge_chain_sub_thresh_explicit) { + sopt.jointLog->warn("The postMergeChainSubThresh is not meaningful for single-end " + "libraries. Setting this value to 1.0 and ignoring"); + } + if (orphan_chain_sub_thresh_explicit) { + sopt.jointLog->warn("The orphanChainSubThresh is not meaningful for single-end " + "libraries. Setting this value to 1.0 and ignoring"); + } + sopt.post_merge_chain_sub_thresh = 1.0; + sopt.orphan_chain_sub_thresh = 1.0; + } + + // value range check for filters + // pre-merge + if (sopt.pre_merge_chain_sub_thresh < 0 or sopt.pre_merge_chain_sub_thresh > 1.0) { + sopt.jointLog->error("You set preMergeChainSubThresh as {}, but it must in [0,1].", + sopt.pre_merge_chain_sub_thresh); + return false; + } + // post-merge + if (sopt.post_merge_chain_sub_thresh < 0 or sopt.post_merge_chain_sub_thresh > 1.0) { + sopt.jointLog->error("You set postMergeChainSubThresh as {}, but it must in [0,1].", + sopt.post_merge_chain_sub_thresh); + return false; + } + // orphan + if (sopt.orphan_chain_sub_thresh < 0 or sopt.orphan_chain_sub_thresh > 1.0) { + sopt.jointLog->error("You set orphanChainSubThresh as {}, but it must in [0,1].", + sopt.orphan_chain_sub_thresh); + return false; + } + + if (sopt.mimicBT2 and sopt.mimicStrictBT2) { sopt.jointLog->error("You passed both the --mimicBT2 and --mimicStrictBT2 parameters. These are mutually exclusive. " "Please select only one of these flags."); @@ -1875,6 +1922,7 @@ bool processQuantOptions(SalmonOpts& sopt, // std::make_shared(rawConsoleSink); auto consoleSink = std::make_shared(); + consoleSink->set_color(spdlog::level::warn, consoleSink->magenta); auto consoleLog = spdlog::create("stderrLog", {consoleSink}); auto fileLog = spdlog::create("fileLog", {fileSink}); std::vector sinks{consoleSink, fileSink}; From a2909da57f7ba82305ab33bf1c01e6c9f05150f4 Mon Sep 17 00:00:00 2001 From: Rob Patro Date: Mon, 29 Jun 2020 16:55:37 -0400 Subject: [PATCH 43/52] remove unused member variables --- include/SalmonMappingUtils.hpp | 3 --- 1 file changed, 3 deletions(-) diff --git a/include/SalmonMappingUtils.hpp b/include/SalmonMappingUtils.hpp index df9fad683..a7c21c787 100644 --- a/include/SalmonMappingUtils.hpp +++ b/include/SalmonMappingUtils.hpp @@ -139,9 +139,6 @@ inline bool initMapperSettings(SalmonOpts& salmonOpts, MemCollector& mem memCollector.setAltSkip(salmonOpts.mismatchSeedSkip); memCollector.setChainSubOptThresh(salmonOpts.pre_merge_chain_sub_thresh); - double pre_merge_chain_sub_thresh; - double post_merge_chain_sub_thresh; - //Initialize ksw aligner ksw2pp::KSW2Config config; config.dropoff = -1; From 90cac952acda42702ac11c11ab1b71d4916dd117 Mon Sep 17 00:00:00 2001 From: Rob Patro Date: Wed, 1 Jul 2020 21:32:29 -0400 Subject: [PATCH 44/52] changes relevant for updated pufferfish --- include/SalmonMappingUtils.hpp | 3 +++ src/SalmonAlevin.cpp | 1 + src/SalmonQuantify.cpp | 2 ++ 3 files changed, 6 insertions(+) diff --git a/include/SalmonMappingUtils.hpp b/include/SalmonMappingUtils.hpp index a7c21c787..e8ec4f8c5 100644 --- a/include/SalmonMappingUtils.hpp +++ b/include/SalmonMappingUtils.hpp @@ -156,6 +156,9 @@ inline bool initMapperSettings(SalmonOpts& salmonOpts, MemCollector& mem aconf.refExtendLength = 20; aconf.fullAlignment = salmonOpts.fullLengthAlignment; + aconf.mismatchPenalty = salmonOpts.mismatchPenalty; + aconf.bestStrata = false; + aconf.decoyPresent = false; aconf.matchScore = salmonOpts.matchScore; aconf.gapExtendPenalty = salmonOpts.gapExtendPenalty; aconf.gapOpenPenalty = salmonOpts.gapOpenPenalty; diff --git a/src/SalmonAlevin.cpp b/src/SalmonAlevin.cpp index 6d78d043f..ed7914f7b 100644 --- a/src/SalmonAlevin.cpp +++ b/src/SalmonAlevin.cpp @@ -663,6 +663,7 @@ void processReadsQuasi( // adding validate mapping code if (tryAlign and !jointHits.empty()) { puffaligner.clear(); + puffaligner.getScoreStatus().reset(); msi.clear(jointHits.size()); size_t idx{0}; diff --git a/src/SalmonQuantify.cpp b/src/SalmonQuantify.cpp index 6d3962059..ed508578d 100644 --- a/src/SalmonQuantify.cpp +++ b/src/SalmonQuantify.cpp @@ -1104,6 +1104,7 @@ void processReads( if (tryAlign and !jointHits.empty()) { // clear the aligner for this read puffaligner.clear(); + puffaligner.getScoreStatus().reset(); msi.clear(jointHits.size()); size_t idx{0}; @@ -1717,6 +1718,7 @@ void processReads( // clear the aligner for this read puffaligner.clear(); + puffaligner.getScoreStatus().reset(); msi.clear(jointHits.size()); size_t idx{0}; From b5e07e4889a3b334510d8b5d9907851d14d13bd8 Mon Sep 17 00:00:00 2001 From: Rob Patro Date: Thu, 2 Jul 2020 13:49:55 -0400 Subject: [PATCH 45/52] update the libgff used --- CMakeLists.txt | 23 ++++++++++++++++------- src/CMakeLists.txt | 2 ++ src/SalmonUtils.cpp | 2 +- 3 files changed, 19 insertions(+), 8 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 899a2a2e3..a413631e4 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -723,14 +723,22 @@ message("TBB_LIBRARIES = ${TBB_LIBRARIES}") #message("TBB_LIBRARY_DIRS ${TBB_LIBRARY_DIRS}") #message("TBB_LIBRARIES ${TBB_LIBRARIES} ") -find_package(libgff) -if(NOT LIBGFF_FOUND) +find_package(libgff 2.0.0 +HINTS ${LIB_GFF_PATH} ${GFF_ROOT} +) +if(libgff_FOUND) + message(STATUS "libgff ver. ${LIB_GFF_VERSION} found.") + message(STATUS " include: ${LIB_GFF_INCLUDE_DIR}") + message(STATUS " lib : ${LIB_GFF_LIBRARY_DIR}") +endif() + +if(NOT libgff_FOUND) message("Build system will compile libgff") message("==================================================================") externalproject_add(libgff DOWNLOAD_DIR ${CMAKE_CURRENT_SOURCE_DIR}/external - DOWNLOAD_COMMAND curl -k -L https://github.com/COMBINE-lab/libgff/archive/v1.2.tar.gz -o libgff.tgz && - ${SHASUM} bfabf143da828e8db251104341b934458c19d3e3c592d418d228de42becf98eb libgff.tgz && + DOWNLOAD_COMMAND curl -k -L https://github.com/COMBINE-lab/libgff/archive/v2.0.0.tar.gz -o libgff.tgz && + ${SHASUM} 7656b19459a7ca7d2fd0fcec4f2e0fd0deec1b4f39c703a114e8f4c22d82a99c libgff.tgz && tar -xzvf libgff.tgz ## #URL https://github.com/COMBINE-lab/libgff/archive/v1.1.tar.gz @@ -738,11 +746,11 @@ if(NOT LIBGFF_FOUND) #URL_HASH SHA1=37b3147d78391d1fabbe6a0df313fbf516abbc6f #TLS_VERIFY FALSE ## - SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/external/libgff-1.2 + SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/external/libgff-2.0.0 #UPDATE_COMMAND sh -c "mkdir -p /build" INSTALL_DIR ${CMAKE_CURRENT_SOURCE_DIR}/external/install - BINARY_DIR ${CMAKE_CURRENT_SOURCE_DIR}/external/libgff-1.2/build - CMAKE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${CMAKE_CURRENT_SOURCE_DIR}/external/install + BINARY_DIR ${CMAKE_CURRENT_SOURCE_DIR}/external/libgff-2.0.0/build + CMAKE_ARGS -DCMAKE_INSTALL_PREFIX:PATH= -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER} -DCMAKE_C_COMPILER=${CMAKE_C_COMPILER} ) externalproject_add_step(libgff makedir COMMAND mkdir -p /build @@ -750,6 +758,7 @@ if(NOT LIBGFF_FOUND) DEPENDEES download DEPENDERS configure) set(FETCHED_GFF TRUE) + set(LIB_GFF_PATH ${CMAKE_CURRENT_SOURCE_DIR}/external/install) endif() # Because of the way that Apple has changed SIP diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 7c9f9f2f1..e0309ec76 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -10,6 +10,7 @@ ${Boost_INCLUDE_DIRS} ${GAT_SOURCE_DIR}/external/install/include ${GAT_SOURCE_DIR}/external/install/include/pufferfish ${GAT_SOURCE_DIR}/external/install/include/pufferfish/digestpp +${LIB_GFF_INCLUDE_DIR} #${GAT_SOURCE_DIR}/external/install/include/rapmap #${GAT_SOURCE_DIR}/external/install/include/rapmap/digestpp ${ICU_INC_DIRS} @@ -102,6 +103,7 @@ ${Boost_LIBRARY_DIRS} ${TBB_LIBRARY_DIRS} ${LAPACK_LIBRARY_DIR} ${BLAS_LIBRARY_DIR} +${LIB_GFF_LIBRARY_DIR} ) message("TBB_LIBRARIES = ${TBB_LIBRARIES}") diff --git a/src/SalmonUtils.cpp b/src/SalmonUtils.cpp index 006cc895b..22962a545 100644 --- a/src/SalmonUtils.cpp +++ b/src/SalmonUtils.cpp @@ -1075,7 +1075,7 @@ TranscriptGeneMap transcriptGeneMapFromGTF(const std::string& fname, auto logger = spdlog::get("jointLog"); // Use GffReader to read the file - GffReader reader(const_cast(fname.c_str())); + GffReader reader(const_cast(fname.c_str()), true, false); // Remember the optional attributes reader.readAll(true); From 0c61c79bc64fdd67837168ef013cf6fcbce2e524 Mon Sep 17 00:00:00 2001 From: Rob Patro Date: Thu, 2 Jul 2020 14:08:32 -0400 Subject: [PATCH 46/52] get rid of unused variable warning --- include/SalmonMappingUtils.hpp | 1 + 1 file changed, 1 insertion(+) diff --git a/include/SalmonMappingUtils.hpp b/include/SalmonMappingUtils.hpp index e8ec4f8c5..8f4c8c261 100644 --- a/include/SalmonMappingUtils.hpp +++ b/include/SalmonMappingUtils.hpp @@ -370,6 +370,7 @@ inline void filterAndCollectAlignmentsDecoy( // regardless of the the status of hardFilter (i.e. no sub-optimal decoy mappings will be reported). (void) hardFilter; (void) minAlnProb; +(void) scoreExp; double estAlnProb = 1.0; //std::exp(-scoreExp * 0.0); for (auto& idxTxp : msi.best_decoy_hits) { int32_t ctr = idxTxp.first; From 3596e6b87374040c20081d6eae389766d60b7aa1 Mon Sep 17 00:00:00 2001 From: Rob Patro Date: Thu, 2 Jul 2020 21:50:03 -0400 Subject: [PATCH 47/52] Update defaults for alevin Make sure that alevin is treated as single-end by default (with respect to the new chaining filters). --- include/SalmonDefaults.hpp | 2 +- src/SalmonAlevin.cpp | 4 ++-- src/SalmonQuantify.cpp | 8 ++++---- src/SalmonUtils.cpp | 11 +++++++---- 4 files changed, 14 insertions(+), 11 deletions(-) diff --git a/include/SalmonDefaults.hpp b/include/SalmonDefaults.hpp index 083147a98..f810e75c0 100644 --- a/include/SalmonDefaults.hpp +++ b/include/SalmonDefaults.hpp @@ -26,7 +26,7 @@ namespace defaults { constexpr const double minScoreFraction{0.65}; constexpr const double pre_merge_chain_sub_thresh{0.75}; constexpr const double post_merge_chain_sub_thresh{0.9}; - constexpr const double orphan_chain_sub_thresh{1.0}; + constexpr const double orphan_chain_sub_thresh{0.9}; constexpr const double scoreExp{1.0}; constexpr const int16_t matchScore{2}; constexpr const int16_t mismatchPenalty{-4}; diff --git a/src/SalmonAlevin.cpp b/src/SalmonAlevin.cpp index ed7914f7b..5446493ce 100644 --- a/src/SalmonAlevin.cpp +++ b/src/SalmonAlevin.cpp @@ -667,12 +667,12 @@ void processReadsQuasi( msi.clear(jointHits.size()); size_t idx{0}; - bool isMultimapping = (jointHits.size() > 1); + bool is_multimapping = (jointHits.size() > 1); for (auto &&jointHit : jointHits) { // for alevin, currently, we need these to have a mate status of PAIRED_END_RIGHT jointHit.mateStatus = MateStatus::PAIRED_END_RIGHT; - auto hitScore = puffaligner.calculateAlignments(readSubSeq, jointHit, hctr, isMultimapping, false); + auto hitScore = puffaligner.calculateAlignments(readSubSeq, jointHit, hctr, is_multimapping, false); bool validScore = (hitScore != invalidScore); numMappingsDropped += validScore ? 0 : 1; auto tid = qidx->getRefId(jointHit.tid); diff --git a/src/SalmonQuantify.cpp b/src/SalmonQuantify.cpp index ed508578d..98cf15f35 100644 --- a/src/SalmonQuantify.cpp +++ b/src/SalmonQuantify.cpp @@ -1108,10 +1108,10 @@ void processReads( msi.clear(jointHits.size()); size_t idx{0}; - bool isMultimapping = (jointHits.size() > 1); + bool is_multimapping = (jointHits.size() > 1); for (auto &&jointHit : jointHits) { - auto hitScore = puffaligner.calculateAlignments(rp.first.seq, rp.second.seq, jointHit, hctr, isMultimapping, false); + auto hitScore = puffaligner.calculateAlignments(rp.first.seq, rp.second.seq, jointHit, hctr, is_multimapping, false); bool validScore = (hitScore != invalidScore); numMappingsDropped += validScore ? 0 : 1; auto tid = qidx->getRefId(jointHit.tid); @@ -1722,10 +1722,10 @@ void processReads( msi.clear(jointHits.size()); size_t idx{0}; - bool isMultimapping = (jointHits.size() > 1); + bool is_multimapping = (jointHits.size() > 1); for (auto &&jointHit : jointHits) { - auto hitScore = puffaligner.calculateAlignments(rp.seq, jointHit, hctr, isMultimapping, false); + auto hitScore = puffaligner.calculateAlignments(rp.seq, jointHit, hctr, is_multimapping, false); bool validScore = (hitScore != invalidScore); numMappingsDropped += validScore ? 0 : 1; auto tid = qidx->getRefId(jointHit.tid); diff --git a/src/SalmonUtils.cpp b/src/SalmonUtils.cpp index 22962a545..d49000546 100644 --- a/src/SalmonUtils.cpp +++ b/src/SalmonUtils.cpp @@ -1525,8 +1525,9 @@ std::string getCurrentTimeAsString() { bool post_merge_chain_sub_thresh_explicit = !vm["postMergeChainSubThresh"].defaulted(); bool orphan_chain_sub_thresh_explicit = !vm["orphanChainSubThresh"].defaulted(); - // for a single-end library, we set - if ( is_se_library ) { + // for a single-end library (or effectively so by being single-cell), we set + // pre_merge_chain_sub_thresh to 1.0 by default + if ( is_se_library or sopt.alevinMode ) { // The default of preMergeChainSubThresh for single-end libraries is 1.0, so set that here if (!pre_merge_chain_sub_thresh_explicit) { @@ -1536,11 +1537,13 @@ std::string getCurrentTimeAsString() { // for single-end libraries, postMergeChainSubThresh and orphanChainSubThresh are meaningless if (post_merge_chain_sub_thresh_explicit) { sopt.jointLog->warn("The postMergeChainSubThresh is not meaningful for single-end " - "libraries. Setting this value to 1.0 and ignoring"); + "(or effectively single-end — e.g. tagged-end single-cell) libraries. Setting this value " + "to 1.0 and ignoring"); } if (orphan_chain_sub_thresh_explicit) { sopt.jointLog->warn("The orphanChainSubThresh is not meaningful for single-end " - "libraries. Setting this value to 1.0 and ignoring"); + "(or effectively single-end — e.g. tagged-end single-cell) libraries. Setting this value " + "to 1.0 and ignoring"); } sopt.post_merge_chain_sub_thresh = 1.0; sopt.orphan_chain_sub_thresh = 1.0; From 0aae20d050e60c543cb20270a335a713e29c8f96 Mon Sep 17 00:00:00 2001 From: Rob Patro Date: Thu, 2 Jul 2020 23:14:11 -0400 Subject: [PATCH 48/52] change default --- include/SalmonDefaults.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/SalmonDefaults.hpp b/include/SalmonDefaults.hpp index f810e75c0..69831c665 100644 --- a/include/SalmonDefaults.hpp +++ b/include/SalmonDefaults.hpp @@ -26,7 +26,7 @@ namespace defaults { constexpr const double minScoreFraction{0.65}; constexpr const double pre_merge_chain_sub_thresh{0.75}; constexpr const double post_merge_chain_sub_thresh{0.9}; - constexpr const double orphan_chain_sub_thresh{0.9}; + constexpr const double orphan_chain_sub_thresh{0.95}; constexpr const double scoreExp{1.0}; constexpr const int16_t matchScore{2}; constexpr const int16_t mismatchPenalty{-4}; From 1446ee31e6dc7cf4b25b77339cbb0abd2b6db1b6 Mon Sep 17 00:00:00 2001 From: Rob Patro Date: Fri, 3 Jul 2020 16:53:51 -0400 Subject: [PATCH 49/52] wording changes, change default gap open --- include/SalmonDefaults.hpp | 2 +- src/ProgramOptionsGenerator.cpp | 6 +++--- src/SalmonUtils.cpp | 6 +++--- tests/basic_alevin_test.sh | 3 ++- 4 files changed, 9 insertions(+), 8 deletions(-) diff --git a/include/SalmonDefaults.hpp b/include/SalmonDefaults.hpp index 69831c665..dccd7083d 100644 --- a/include/SalmonDefaults.hpp +++ b/include/SalmonDefaults.hpp @@ -30,7 +30,7 @@ namespace defaults { constexpr const double scoreExp{1.0}; constexpr const int16_t matchScore{2}; constexpr const int16_t mismatchPenalty{-4}; - constexpr const int16_t gapOpenPenalty{4}; + constexpr const int16_t gapOpenPenalty{8}; constexpr const int16_t gapExtendPenalty{2}; constexpr const int32_t dpBandwidth{15}; constexpr const uint32_t mismatchSeedSkip{3}; diff --git a/src/ProgramOptionsGenerator.cpp b/src/ProgramOptionsGenerator.cpp index 961f2bd0b..0d89c9e88 100644 --- a/src/ProgramOptionsGenerator.cpp +++ b/src/ProgramOptionsGenerator.cpp @@ -84,7 +84,7 @@ namespace salmon { mapspec.add_options() ("discardOrphansQuasi", po::bool_switch(&(sopt.discardOrphansQuasi))->default_value(salmon::defaults::discardOrphansQuasi), - "[selective-alignment mode only] : Discard orphan mappings in quasi-mapping " + "[selective-alignment mode only] : Discard orphan mappings in selective-alignment " "mode. If this flag is passed " "then only paired mappings will be considered toward quantification " "estimates. The default behavior is " @@ -264,7 +264,7 @@ namespace salmon { ("writeMappings,z", po::value(&sopt.qmFileName) ->default_value(salmon::defaults::quasiMappingDefaultFile) ->implicit_value(salmon::defaults::quasiMappingImplicitFile), - "If this option is provided, then the quasi-mapping " + "If this option is provided, then the selective-alignment " "results will be written out in SAM-compatible " "format. By default, output will be directed to " "stdout, but an alternative file name can be " @@ -358,7 +358,7 @@ namespace salmon { "Run naive per equivalence class deduplication, generating only total number of UMIs") ( "noDedup", po::bool_switch()->default_value(alevin::defaults::noDedup), - "Stops the pipeline after CB sequence correction and quasi-mapping reads."); + "Stops the pipeline after CB sequence correction and selective-alignment of reads."); return alevindevs; } diff --git a/src/SalmonUtils.cpp b/src/SalmonUtils.cpp index d49000546..6c55178d7 100644 --- a/src/SalmonUtils.cpp +++ b/src/SalmonUtils.cpp @@ -1705,7 +1705,7 @@ bool createAuxMapLoggers_(SalmonOpts& sopt, sopt.orphanLinkLog = outLog; } - // Determine what we'll do with quasi-mapping results + // Determine what we'll do with selective-alignment results bool writeQuasimappings = (sopt.qmFileName != ""); if (writeQuasimappings) { @@ -1732,7 +1732,7 @@ bool createAuxMapLoggers_(SalmonOpts& sopt, // Make sure file opened successfully. if (!sopt.qmFile.is_open()) { jointLog->error( - "Could not create file for writing quasi-mappings [{}]", + "Could not create file for writing selective-alignments [{}]", sopt.qmFileName); return false; } @@ -1906,7 +1906,7 @@ bool processQuantOptions(SalmonOpts& sopt, bfs::path indexDirectory(vm["index"].as()); sopt.indexDirectory = indexDirectory; - // Determine what we'll do with quasi-mapping results + // Determine what we'll do with selective-alignment results bool writeQuasimappings = (sopt.qmFileName != ""); // make it larger if we're writing mappings or diff --git a/tests/basic_alevin_test.sh b/tests/basic_alevin_test.sh index f1464dd7c..2bf1fd9e3 100755 --- a/tests/basic_alevin_test.sh +++ b/tests/basic_alevin_test.sh @@ -1,9 +1,10 @@ ALVBIN=$1 #"/mnt/scratch6/salmon_ci/COMBINE-lab/salmon/master/build/salmon-latest_linux_x86_64/bin/salmon" +BASEDIR="/mnt/scratch6/avi/alevin/alevin" OUT=$PWD tfile=$(mktemp /tmp/foo.XXXXXXXXX) -/usr/bin/time -o $tfile $ALVBIN alevin -lISR --chromium -1 /mnt/scratch5/avi/alevin/data/10x/v2/mohu/100/all_bcs.fq -2 /mnt/scratch5/avi/alevin/data/10x/v2/mohu/100/all_reads.fq -o $OUT/prediction -i /mnt/scratch5/avi/alevin/data/mohu/salmon_index -p 20 --tgMap /mnt/scratch5/avi/alevin/data/mohu/gtf/txp2gene.tsv --dumpMtx --no-version-check --dumpFeatures --dumpArborescence #--whitelist ./alevin_test_data/alevin/quants_mat_rows.txt +/usr/bin/time -o $tfile $ALVBIN alevin -lISR --chromium -1 ${BASEDIR}/data/10x/v2/mohu/100/all_bcs.fq -2 ${BASEDIR}/data/10x/v2/mohu/100/all_reads.fq -o $OUT/prediction -i ${BASEDIR}/data/mohu/salmon_index -p 20 --tgMap ${BASEDIR}/data/mohu/gtf/txp2gene.tsv --dumpMtx --no-version-check --dumpFeatures --dumpArborescence --writeMappings=$OUT/prediction/with_bug.sam #--whitelist ./alevin_test_data/alevin/quants_mat_rows.txt #--dumpBfh --whitelist /mnt/scratch5/avi/alevin/bin/salmon/tests/whitelist.txt #--dumpUmiGraph --numCellBootstraps 100 --dumpBfh --dumpBarcodeEq --dumpMtx --expectCells 1001 --end 6 --umiLength 10 --barcodeLength 16 From 5b6b7b43307bf20af67e6a684e43cc36775f5902 Mon Sep 17 00:00:00 2001 From: Rob Patro Date: Fri, 3 Jul 2020 21:59:47 -0400 Subject: [PATCH 50/52] change alevin log warning color --- src/AlevinUtils.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/src/AlevinUtils.cpp b/src/AlevinUtils.cpp index e320627c0..5130a5f75 100644 --- a/src/AlevinUtils.cpp +++ b/src/AlevinUtils.cpp @@ -563,6 +563,7 @@ namespace alevin { auto logPath = aopt.outputDirectory / "alevin.log"; auto fileSink = std::make_shared(logPath.string(), true); auto consoleSink = std::make_shared(); + consoleSink->set_color(spdlog::level::warn, consoleSink->magenta); std::vector sinks{consoleSink, fileSink}; aopt.jointLog = spdlog::create("alevinLog", std::begin(sinks), std::end(sinks)); From bfc1a4ae06f9be26ebd7996997e136ca0dc45d17 Mon Sep 17 00:00:00 2001 From: Rob Patro Date: Fri, 3 Jul 2020 23:08:06 -0400 Subject: [PATCH 51/52] chagne default again --- include/SalmonDefaults.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/SalmonDefaults.hpp b/include/SalmonDefaults.hpp index dccd7083d..1b4b079ca 100644 --- a/include/SalmonDefaults.hpp +++ b/include/SalmonDefaults.hpp @@ -30,7 +30,7 @@ namespace defaults { constexpr const double scoreExp{1.0}; constexpr const int16_t matchScore{2}; constexpr const int16_t mismatchPenalty{-4}; - constexpr const int16_t gapOpenPenalty{8}; + constexpr const int16_t gapOpenPenalty{6}; constexpr const int16_t gapExtendPenalty{2}; constexpr const int32_t dpBandwidth{15}; constexpr const uint32_t mismatchSeedSkip{3}; From 037106a7dab0632b40ad9da338b27b1015a66967 Mon Sep 17 00:00:00 2001 From: Rob Patro Date: Fri, 3 Jul 2020 23:20:12 -0400 Subject: [PATCH 52/52] pull from tagged pufferfish --- scripts/fetchPufferfish.sh | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/scripts/fetchPufferfish.sh b/scripts/fetchPufferfish.sh index 9f6cc880b..ca299e178 100755 --- a/scripts/fetchPufferfish.sh +++ b/scripts/fetchPufferfish.sh @@ -22,10 +22,10 @@ if [ -d ${INSTALL_DIR}/src/pufferfish ] ; then rm -fr ${INSTALL_DIR}/src/pufferfish fi -#SVER=salmon-v1.2.1 -SVER=develop +SVER=salmon-v1.3.0 +#SVER=develop -EXPECTED_SHA256=da51713e54cf426524a2a1da7de2273cea7bf1f4089abbce22fcaa8f59e493cc +EXPECTED_SHA256=0176b2ec5fc45bbf68c60b5845fead28e63db72f91ff93499d67e7a571167fdf mkdir -p ${EXTERNAL_DIR} curl -k -L https://github.com/COMBINE-lab/pufferfish/archive/${SVER}.zip -o ${EXTERNAL_DIR}/pufferfish.zip